diff --git a/paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc b/paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc index 1b0518ca1c1573..32ad05dbe9573f 100644 --- a/paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc +++ b/paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc @@ -68,7 +68,7 @@ bool ContainsVectorizableAxis(const ir::IRSchedule* sch, VLOG(4) << "Checking ContainsVectorizableAxis on block: [" << block_id << "], loop:\n" - << sch->GetModule().GetExprs().front() << "\n vectorize expr:]\n" + << sch->GetModule().GetExprs().front() << "\n vectorize expr:\n" << vectorize_expr; // Get all the lter values in the axis bind that contain a loop var and the diff --git a/paddle/cinn/optim/vectorize_for_trans.cc b/paddle/cinn/optim/vectorize_for_trans.cc index fdbdeb00a92631..2ff9006ae7b84e 100644 --- a/paddle/cinn/optim/vectorize_for_trans.cc +++ b/paddle/cinn/optim/vectorize_for_trans.cc @@ -294,6 +294,7 @@ class VectorizeForTransMutator : public ir::IRMutator { "Expected _Tensor_ node in Store, but received nullptr.")); if (in_vectorize_ && node->is_addr_tensor() && tensor_can_vectorized_.count(tensor->name)) { + only_assignment_store_ = IsOnlyAssignment(node->value); TensorVectorized(node, &node->indices, true); } @@ -340,8 +341,11 @@ class VectorizeForTransMutator : public ir::IRMutator { body_stmts.assign(update_cast_stmts_.begin(), update_cast_stmts_.end()); update_cast_stmts_.clear(); } - body_stmts.insert( - body_stmts.end(), unroll_body.begin(), unroll_body.end()); + + if (!only_assignment_store_) { + body_stmts.insert( + body_stmts.end(), unroll_body.begin(), unroll_body.end()); + } if (!update_store_stmts_.empty()) { body_stmts.insert(body_stmts.end(), @@ -355,6 +359,7 @@ class VectorizeForTransMutator : public ir::IRMutator { tensor2vectorized_vars_.clear(); tensor_can_vectorized_.clear(); in_vectorize_ = false; + only_assignment_store_ = false; } private: @@ -393,16 +398,18 @@ class VectorizeForTransMutator : public ir::IRMutator { AppendCast(node->tensor, *indices, is_store); } - auto vectorized_var = tensor2vectorized_vars_.at(tensor->name); - // substitute a new tensor with the vector name and dtype - auto t = vectorized_var->type().is_cpp_handle() - ? node->tensor->type().PointerOf() - : node->tensor->type(); - node->tensor = ir::Tensor(vectorized_var->name, - t, - {ir::Expr(vectorize_factor_)}, - {ir::Expr(vectorize_factor_)}, - tensor->operation); + if (!tensor2vectorized_vars_.empty()) { + auto vectorized_var = tensor2vectorized_vars_.at(tensor->name); + // substitute a new tensor with the vector name and dtype + auto t = vectorized_var->type().is_cpp_handle() + ? node->tensor->type().PointerOf() + : node->tensor->type(); + node->tensor = ir::Tensor(vectorized_var->name, + t, + {ir::Expr(vectorize_factor_)}, + {ir::Expr(vectorize_factor_)}, + tensor->operation); + } // remain the last iterative indice indices->assign({loop_var_}); } @@ -411,7 +418,6 @@ class VectorizeForTransMutator : public ir::IRMutator { const std::vector &indices, bool is_store) { auto *node = tensor.As(); - // generate the corresponding vector type Type scalar_type = tensor->type().ElementOf(); Type vector_type_ptr( @@ -429,7 +435,9 @@ class VectorizeForTransMutator : public ir::IRMutator { std::string vectorized_name = "vectorized_" + node->name + "_" + std::to_string(var_index_++); Var vectorized_var = ir::_Var_::Make(vectorized_name, vector_type); - tensor2vectorized_vars_.emplace(node->name, vectorized_var); + if (!only_assignment_store_) { + tensor2vectorized_vars_.emplace(node->name, vectorized_var); + } // generate a get_addr expr to get the address of the tensor Expr converted_tensor = ir::Load::Make(tensor, indices); @@ -444,22 +452,65 @@ class VectorizeForTransMutator : public ir::IRMutator { auto let = ir::Let::Make(vectorized_var, load); update_cast_stmts_.emplace_back(let); } else { - Var vectorized_ptr = - ir::_Var_::Make(vectorized_name + "_ptr", vector_type_ptr); - auto let1 = ir::Let::Make(vectorized_ptr, cast); - auto let2 = ir::Let::Make(vectorized_var, ir::Expr(0)); - update_cast_stmts_.emplace_back(let1); - update_cast_stmts_.emplace_back(let2); - - auto t = ir::Tensor(vectorized_ptr->name, - node->type().PointerOf(), - {ir::Expr(vectorize_factor_)}, - {ir::Expr(vectorize_factor_)}, - node->operation); - auto store = - ir::Store::Make(t, vectorized_var, {cinn::common::make_const(0)}); - update_store_stmts_.emplace_back(store); - VLOG(5) << "Append a vectorized expr:" << store; + if (only_assignment_store_) { + Var vectorized_ptr = + ir::_Var_::Make(vectorized_name + "_ptr", vector_type_ptr); + auto let = ir::Let::Make(vectorized_ptr, cast); + update_cast_stmts_.emplace_back(let); + + auto t = ir::Tensor(vectorized_ptr->name, + node->type().PointerOf(), + {ir::Expr(vectorize_factor_)}, + {ir::Expr(vectorize_factor_)}, + node->operation); + + std::string load_vectorized_name = "vectorized_" + + only_assignment_load_name_ + "_" + + std::to_string(var_index_); + Var load_vectorized_var = + ir::_Var_::Make(load_vectorized_name, vector_type); + auto store = ir::Store::Make( + t, load_vectorized_var, {cinn::common::make_const(0)}); + update_store_stmts_.emplace_back(store); + VLOG(5) << "Append a IsOnlyAssignment vectorized expr:" << store; + } else { + Var vectorized_ptr = + ir::_Var_::Make(vectorized_name + "_ptr", vector_type_ptr); + auto let1 = ir::Let::Make(vectorized_ptr, cast); + auto let2 = ir::Let::Make(vectorized_var, ir::Expr(0)); + update_cast_stmts_.emplace_back(let1); + update_cast_stmts_.emplace_back(let2); + + auto t = ir::Tensor(vectorized_ptr->name, + node->type().PointerOf(), + {ir::Expr(vectorize_factor_)}, + {ir::Expr(vectorize_factor_)}, + node->operation); + auto store = + ir::Store::Make(t, vectorized_var, {cinn::common::make_const(0)}); + update_store_stmts_.emplace_back(store); + VLOG(5) << "Append a vectorized expr:" << store; + } + } + } + + bool IsOnlyAssignment(ir::Expr &value) { // NOLINT + if (auto *cast_op = value.As()) { + return IsOnlyAssignment(cast_op->v()); + } else if (auto *load_op = value.As()) { + auto tensor_load = load_op->tensor.As(); + PADDLE_ENFORCE_NOT_NULL( + tensor_load, + ::common::errors::InvalidArgument( + "Expected _Tensor_ node in Store, but received nullptr.")); + if (tensor_can_vectorized_.count(tensor_load->name)) { + only_assignment_store_ = true; + only_assignment_load_name_ = tensor_load->name; + return true; + } + return false; + } else { + return false; } } @@ -472,6 +523,8 @@ class VectorizeForTransMutator : public ir::IRMutator { ir::Var loop_var_; bool in_vectorize_{false}; int var_index_{0}; + bool only_assignment_store_{false}; + std::string only_assignment_load_name_; }; } // namespace