diff --git a/paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc b/paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc
index 1b0518ca1c157..32ad05dbe9573 100644
--- a/paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc
+++ b/paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc
@@ -68,7 +68,7 @@ bool ContainsVectorizableAxis(const ir::IRSchedule* sch,
   VLOG(4) << "Checking ContainsVectorizableAxis on block: [" << block_id
           << "], loop:\n"
-          << sch->GetModule().GetExprs().front() << "\n vectorize expr:]\n"
+          << sch->GetModule().GetExprs().front() << "\n vectorize expr:\n"
           << vectorize_expr;
 
   // Get all the iter values in the axis bind that contain a loop var and the
diff --git a/paddle/cinn/optim/vectorize_for_trans.cc b/paddle/cinn/optim/vectorize_for_trans.cc
index fdbdeb00a9263..1c16c81918661 100644
--- a/paddle/cinn/optim/vectorize_for_trans.cc
+++ b/paddle/cinn/optim/vectorize_for_trans.cc
@@ -294,6 +294,7 @@ class VectorizeForTransMutator : public ir::IRMutator<> {
             "Expected _Tensor_ node in Store, but received nullptr."));
     if (in_vectorize_ && node->is_addr_tensor() &&
         tensor_can_vectorized_.count(tensor->name)) {
+      is_assignment_ = IsAssignment(node->value);
       TensorVectorized(node, &node->indices, true);
     }
 
@@ -340,8 +341,11 @@ class VectorizeForTransMutator : public ir::IRMutator<> {
       body_stmts.assign(update_cast_stmts_.begin(), update_cast_stmts_.end());
       update_cast_stmts_.clear();
     }
-    body_stmts.insert(
-        body_stmts.end(), unroll_body.begin(), unroll_body.end());
+
+    if (!is_assignment_) {
+      body_stmts.insert(
+          body_stmts.end(), unroll_body.begin(), unroll_body.end());
+    }
 
     if (!update_store_stmts_.empty()) {
       body_stmts.insert(body_stmts.end(),
@@ -355,6 +359,7 @@ class VectorizeForTransMutator : public ir::IRMutator<> {
     tensor2vectorized_vars_.clear();
     tensor_can_vectorized_.clear();
     in_vectorize_ = false;
+    is_assignment_ = false;
   }
 
  private:
@@ -393,16 +398,18 @@ class VectorizeForTransMutator : public ir::IRMutator<> {
       AppendCast(node->tensor, *indices, is_store);
     }
 
-    auto vectorized_var = tensor2vectorized_vars_.at(tensor->name);
-    // substitute a new tensor with the vector name and dtype
-    auto t = vectorized_var->type().is_cpp_handle()
-                 ? node->tensor->type().PointerOf()
-                 : node->tensor->type();
-    node->tensor = ir::Tensor(vectorized_var->name,
-                              t,
-                              {ir::Expr(vectorize_factor_)},
-                              {ir::Expr(vectorize_factor_)},
-                              tensor->operation);
+    if (!is_assignment_) {
+      auto vectorized_var = tensor2vectorized_vars_.at(tensor->name);
+      // substitute a new tensor with the vector name and dtype
+      auto t = vectorized_var->type().is_cpp_handle()
+                   ? node->tensor->type().PointerOf()
+                   : node->tensor->type();
+      node->tensor = ir::Tensor(vectorized_var->name,
+                                t,
+                                {ir::Expr(vectorize_factor_)},
+                                {ir::Expr(vectorize_factor_)},
+                                tensor->operation);
+    }
     // remain the last iterative indice
     indices->assign({loop_var_});
   }
@@ -411,7 +418,6 @@ class VectorizeForTransMutator : public ir::IRMutator<> {
                   const std::vector<ir::Expr> &indices,
                   bool is_store) {
     auto *node = tensor.As<ir::_Tensor_>();
-
     // generate the corresponding vector type
     Type scalar_type = tensor->type().ElementOf();
     Type vector_type_ptr(
@@ -429,7 +435,9 @@ class VectorizeForTransMutator : public ir::IRMutator<> {
     std::string vectorized_name =
        "vectorized_" + node->name + "_" + std::to_string(var_index_++);
     Var vectorized_var = ir::_Var_::Make(vectorized_name, vector_type);
-    tensor2vectorized_vars_.emplace(node->name, vectorized_var);
+    if (!is_assignment_) {
+      tensor2vectorized_vars_.emplace(node->name, vectorized_var);
+    }
 
     // generate a get_addr expr to get the address of the tensor
     Expr converted_tensor = ir::Load::Make(tensor, indices);
@@ -447,20 +455,64 @@ class VectorizeForTransMutator : public ir::IRMutator<> {
       Var vectorized_ptr =
          ir::_Var_::Make(vectorized_name + "_ptr", vector_type_ptr);
       auto let1 = ir::Let::Make(vectorized_ptr, cast);
-      auto let2 = ir::Let::Make(vectorized_var, ir::Expr(0));
       update_cast_stmts_.emplace_back(let1);
-      update_cast_stmts_.emplace_back(let2);
 
       auto t = ir::Tensor(vectorized_ptr->name,
                           node->type().PointerOf(),
                           {ir::Expr(vectorize_factor_)},
                           {ir::Expr(vectorize_factor_)},
                           node->operation);
-      auto store =
-          ir::Store::Make(t, vectorized_var, {cinn::common::make_const(0)});
-      update_store_stmts_.emplace_back(store);
-      VLOG(5) << "Append a vectorized expr:" << store;
+
+      if (is_assignment_) {
+        std::string load_vectorized_name = "vectorized_" +
+                                           assignment_tensor_name_ + "_" +
+                                           std::to_string(var_index_);
+        Var load_vectorized_var =
+            ir::_Var_::Make(load_vectorized_name, vector_type);
+        auto store = ir::Store::Make(
+            t, load_vectorized_var, {cinn::common::make_const(0)});
+        update_store_stmts_.emplace_back(store);
+        VLOG(5) << "Append an assignment vectorized expr:" << store;
+      } else {
+        auto let2 = ir::Let::Make(vectorized_var, ir::Expr(0));
+        update_cast_stmts_.emplace_back(let2);
+
+        auto t = ir::Tensor(vectorized_ptr->name,
+                            node->type().PointerOf(),
+                            {ir::Expr(vectorize_factor_)},
+                            {ir::Expr(vectorize_factor_)},
+                            node->operation);
+        auto store =
+            ir::Store::Make(t, vectorized_var, {cinn::common::make_const(0)});
+        update_store_stmts_.emplace_back(store);
+        VLOG(5) << "Append a vectorized expr:" << store;
+      }
+    }
+  }
+
+  // A store is considered to be a pure assignment statement only if the store
+  // value is a load or a cast(load).
+  bool IsAssignment(ir::Expr &value) {  // NOLINT
+    if (auto *cast_op = value.As<ir::Cast>()) {
+      return IsAssignment(cast_op->v());
     }
+
+    auto *load_op = value.As<ir::Load>();
+    if (!load_op) {
+      return false;
+    }
+
+    auto tensor_load = load_op->tensor.As<ir::_Tensor_>();
+    PADDLE_ENFORCE_NOT_NULL(
+        tensor_load,
+        ::common::errors::InvalidArgument(
+            "Expected _Tensor_ node in Store, but received nullptr."));
+    if (tensor_can_vectorized_.count(tensor_load->name) == 0) {
+      return false;
+    }
+    is_assignment_ = true;
+    assignment_tensor_name_ = tensor_load->name;
+    return true;
   }
 
   std::vector<ir::Expr> update_cast_stmts_;
@@ -472,6 +524,8 @@ class VectorizeForTransMutator : public ir::IRMutator<> {
   ir::Var loop_var_;
   bool in_vectorize_{false};
   int var_index_{0};
+  bool is_assignment_{false};
+  std::string assignment_tensor_name_;
 };
 
 }  // namespace
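The core of this change is the pure-assignment check: a vectorized store whose value is just a Load (possibly wrapped in a Cast) of another vectorizable tensor does not need the unrolled scalar body; the mutator can instead write the already-loaded vector variable straight back through the vectorized pointer. The snippet below is a minimal, self-contained sketch of that detection idea only; ToyExpr, ToyKind, and IsPureAssignment are hypothetical stand-ins for illustration and are not part of the CINN ir::Expr API.

// Toy sketch of the "pure assignment store" predicate introduced by this
// patch. ToyExpr, ToyKind, and IsPureAssignment are illustrative
// placeholders, not CINN types.
#include <iostream>
#include <memory>
#include <set>
#include <string>

enum class ToyKind { Load, Cast, Add };

struct ToyExpr {
  ToyKind kind;
  std::string tensor_name;           // used only when kind == Load
  std::shared_ptr<ToyExpr> operand;  // used only when kind == Cast
};

// Unwrap casts, then require the value to be a load of a tensor that the
// pass has already marked as vectorizable -- the same shape as IsAssignment.
bool IsPureAssignment(const ToyExpr& value,
                      const std::set<std::string>& vectorizable,
                      std::string* loaded_name) {
  if (value.kind == ToyKind::Cast && value.operand) {
    return IsPureAssignment(*value.operand, vectorizable, loaded_name);
  }
  if (value.kind != ToyKind::Load) return false;
  if (vectorizable.count(value.tensor_name) == 0) return false;
  *loaded_name = value.tensor_name;
  return true;
}

int main() {
  const std::set<std::string> vectorizable = {"A", "B"};
  std::string name;

  // B[i] = cast<float>(A[i]): a pure assignment, so the vector lanes can be
  // copied directly and the unrolled scalar body is skipped.
  ToyExpr load{ToyKind::Load, "A", nullptr};
  ToyExpr cast{ToyKind::Cast, "", std::make_shared<ToyExpr>(load)};
  std::cout << std::boolalpha << IsPureAssignment(cast, vectorizable, &name)
            << " " << name << "\n";  // prints: true A

  // B[i] = A[i] + 1: not a pure assignment, keep the unrolled body.
  ToyExpr add{ToyKind::Add, "", nullptr};
  std::cout << IsPureAssignment(add, vectorizable, &name) << "\n";  // false
  return 0;
}

When the predicate holds, the patch skips appending unroll_body in the loop rewrite and emits a single store of the loaded vector variable, which is the behavioral change visible in the hunks above.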