diff --git a/paddle/cinn/hlir/framework/pir/trivial_op_impl.cc b/paddle/cinn/hlir/framework/pir/trivial_op_impl.cc index 20027809cc95b..0937dde5044da 100644 --- a/paddle/cinn/hlir/framework/pir/trivial_op_impl.cc +++ b/paddle/cinn/hlir/framework/pir/trivial_op_impl.cc @@ -829,45 +829,46 @@ bool GetCanApplyVectorize(const std::vector& op_compute_bodies) { }, /* uniq_target = */ false); - auto CheckTensorIsBroadcastAndContinuous = - [&](const std::vector& indices) { - int loop_idx = 0; - bool is_broadcast = false; - for (int i = 0; i < indices.size(); ++i) { - const ir::Expr& index = indices[i]; - if (index.is_constant()) { - is_broadcast = true; - continue; - } - - if (!index.is_var()) return false; - ir::Var iter_var = index.as_var_ref(); - if (!iter_var2value.count(iter_var)) { - return false; - } - ir::Expr iter_value = iter_var2value.at(iter_var); - PADDLE_ENFORCE_EQ( - iter_value.as_var() || iter_value.is_constant(), - true, - ::common::errors::PreconditionNotMet( - "Required iter_value shall be var or constant type.")); - for (; loop_idx < for_iters.size(); ++loop_idx) { - if (for_iters[loop_idx] == iter_value.as_var_ref()) { - break; - } - } - - if (loop_idx == for_iters.size()) { - return false; - } + auto CheckTensorIsBroadcastAndContinuous = [&](std::vector& indices) { + int loop_idx = 0; + bool is_broadcast = false; + for (int i = 0; i < indices.size(); ++i) { + ir::Expr& index = indices[i]; + cinn::optim::Simplify(&index); + if (index.is_constant()) { + is_broadcast = true; + continue; + } + + if (!index.is_var()) return false; + ir::Var iter_var = index.as_var_ref(); + if (!iter_var2value.count(iter_var)) { + return false; + } + ir::Expr iter_value = iter_var2value.at(iter_var); + PADDLE_ENFORCE_EQ( + iter_value.as_var() || iter_value.is_constant(), + true, + ::common::errors::PreconditionNotMet( + "Required iter_value shall be var or constant type.")); + for (; loop_idx < for_iters.size(); ++loop_idx) { + if (for_iters[loop_idx] == 
iter_value.as_var_ref()) { + break; } - if (is_broadcast) return true; + } + + if (loop_idx == for_iters.size()) { return false; - }; + } + } + if (is_broadcast || indices.size() < for_iters.size()) return true; + return false; + }; - auto CheckoutTensorIsContinuous = [&](const std::vector& indices) { + auto CheckoutTensorIsContinuous = [&](std::vector& indices) { for (int i = 0; i < indices.size(); ++i) { - const ir::Expr& index = indices[i]; + ir::Expr& index = indices[i]; + cinn::optim::Simplify(&index); if (index.is_constant()) return false; if (!index.is_var()) return false; ir::Var iter_var = index.as_var_ref(); diff --git a/paddle/cinn/ir/group_schedule/config/group_tile_config.cc b/paddle/cinn/ir/group_schedule/config/group_tile_config.cc index e45d254571f1d..57c7b1c10cf10 100644 --- a/paddle/cinn/ir/group_schedule/config/group_tile_config.cc +++ b/paddle/cinn/ir/group_schedule/config/group_tile_config.cc @@ -230,22 +230,24 @@ TileConfigMap BuildVectorizeConfig( vectorize_factor = factor; const int elements_in_warp = kWarpSize * vectorize_factor; warp_nums = CeilDiv(spatial_numel, elements_in_warp); - warp_nums = Trim(warp_nums, 1, 16); + warp_nums = Trim(warp_nums, 1, 32); sp_thread_num = kWarpSize * warp_nums; if (CheckVectorize(spatial_numel, sp_thread_num, vectorize_factor)) { break; } } } + if (!can_vectorize) { base_info->can_apply_vectorize = false; return {}; } + int64_t sp_inner_num = [&]() -> int64_t { if (rd_thread_num > 1) return 1; spatial_numel = spatial_numel / sp_thread_num / vectorize_factor; - int64_t expected = spatial_numel / (sm_count * 16); - return Trim(expected, 1, 4); + int64_t expected = spatial_numel / (sm_count * 64); + return Trim(expected, 1, 1); }(); int64_t sp_upper_bound = base_info->spatial_numel > 1 ? 
kMaxNumel : 1;
diff --git a/paddle/cinn/optim/vectorize_for_trans.cc b/paddle/cinn/optim/vectorize_for_trans.cc
index 16f1756fd6dd3..d1ac5be0ee6a1 100644
--- a/paddle/cinn/optim/vectorize_for_trans.cc
+++ b/paddle/cinn/optim/vectorize_for_trans.cc
@@ -56,6 +56,79 @@ std::unordered_map CollectExprSymbols(Expr *x) {
   return std::move(mutator.GetSymbols());
 }
 
+// Splits a vectorized `for` whose body holds several ScheduleBlockRealize
+// nodes into a Block of vectorized `for` loops, one per schedule block, each
+// with its own freshly named loop variable. VectorizeForTrans applies this
+// pass before running VectorizeForTransMutator.
+class ForOpWithMultiScheduleBlockSupportVectorize : public ir::IRMutator<> {
+ public:
+  // Rewrites `*expr` in place.
+  void operator()(ir::Expr *expr) { ir::IRMutator<>::Visit(expr, expr); }
+
+ private:
+  // Visits the children first, then records this schedule block if a
+  // vectorized `for` is currently being visited.
+  void Visit(const ir::ScheduleBlockRealize *op, Expr *expr) override {
+    auto *node = expr->As<ir::ScheduleBlockRealize>();
+    PADDLE_ENFORCE_NOT_NULL(
+        node,
+        ::common::errors::InvalidArgument(
+            "The input expr should be a ScheduleBlockRealize"));
+    IRMutator<>::Visit(op, expr);
+    if (in_vectorize_scope_) {
+      for_op_blocks_.push_back(expr);
+    }
+  }
+
+  // Replaces a vectorized `for` that collected more than one schedule block
+  // with a Block holding one copy of the loop per schedule block.
+  void Visit(const ir::For *op, ir::Expr *expr) override {
+    auto *forloop = expr->As<ir::For>();
+    PADDLE_ENFORCE_NOT_NULL(
+        forloop,
+        ::common::errors::InvalidArgument("The input expr should be a For"));
+    if (forloop->is_vectorized()) in_vectorize_scope_ = true;
+
+    IRMutator<>::Visit(op, expr);
+
+    if (in_vectorize_scope_ && for_op_blocks_.size() > 1) {
+      std::vector<ir::Expr> stmts;
+      stmts.reserve(for_op_blocks_.size());
+      for (ir::Expr *block : for_op_blocks_) {
+        // Give every split loop a uniquely named iterator and rewrite the
+        // block in place to use it instead of the original loop variable.
+        Var new_iterator(
+            cinn::common::UniqName(forloop->loop_var->name + "_s"));
+        cinn::ir::ir_utils::IrReplaceVarBroadcast(
+            block, forloop->loop_var, Expr(new_iterator));
+
+        stmts.push_back(ir::For::Make(new_iterator,
+                                      forloop->min,
+                                      forloop->extent,
+                                      forloop->for_type(),
+                                      forloop->device_api,
+                                      ir::Block::Make({*block}),
+                                      forloop->vectorize_info(),
+                                      forloop->bind_info()));
+      }
+      // NOTE(review): only the recorded schedule blocks are re-emitted; any
+      // other statement of the original loop body is not carried over.
+      // Confirm a vectorized loop body never holds anything else.
+      *expr = ir::Block::Make(stmts);
+    }
+    // NOTE(review): the state is reset after every `for`, so a loop nested
+    // inside a vectorized loop ends the outer loop's collection. Confirm
+    // vectorized loops are always innermost.
+    in_vectorize_scope_ = false;
+    for_op_blocks_.clear();
+  }
+
+  // Set on entering a vectorized `for`; cleared when any `for` visit ends.
+  bool in_vectorize_scope_{false};
+  // Schedule blocks recorded inside the vectorized `for` being visited.
+  std::vector<ir::Expr *> for_op_blocks_;
+};
+
 class ScheduleBlockTensorVectorizeTeller : public ir::IRMutator {
  public:
   ScheduleBlockTensorVectorizeTeller(Var iter_var, const int factor)
@@ -539,6 +612,10 @@ class VectorizeForTransMutator : public ir::IRMutator {
 } // namespace
void VectorizeForTrans(Expr *expr) { + ForOpWithMultiScheduleBlockSupportVectorize update; + VLOG(5) << "before multi schedule block deal with vectorize " << *expr; + update(expr); + VLOG(5) << "after multi schedule block deal with vectorize " << *expr; VectorizeForTransMutator collector; VLOG(5) << "before vectorize for trans " << *expr; collector(expr);