Skip to content

Commit

Permalink
Simplifying pure assignment statements in vectorization
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanghonggeng committed Dec 11, 2024
1 parent bc5968b commit 27603ca
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ bool ContainsVectorizableAxis(const ir::IRSchedule* sch,

VLOG(4) << "Checking ContainsVectorizableAxis on block: [" << block_id
<< "], loop:\n"
<< sch->GetModule().GetExprs().front() << "\n vectorize expr:]\n"
<< sch->GetModule().GetExprs().front() << "\n vectorize expr:\n"
<< vectorize_expr;

// Get all the iter values in the axis bind that contain a loop var and the
Expand Down
94 changes: 74 additions & 20 deletions paddle/cinn/optim/vectorize_for_trans.cc
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ class VectorizeForTransMutator : public ir::IRMutator<ir::Expr *> {
"Expected _Tensor_ node in Store, but received nullptr."));
if (in_vectorize_ && node->is_addr_tensor() &&
tensor_can_vectorized_.count(tensor->name)) {
is_assignment_ = IsAssignment(node->value);
TensorVectorized(node, &node->indices, true);
}

Expand Down Expand Up @@ -340,8 +341,11 @@ class VectorizeForTransMutator : public ir::IRMutator<ir::Expr *> {
body_stmts.assign(update_cast_stmts_.begin(), update_cast_stmts_.end());
update_cast_stmts_.clear();
}
body_stmts.insert(
body_stmts.end(), unroll_body.begin(), unroll_body.end());

if (!is_assignment_) {
body_stmts.insert(
body_stmts.end(), unroll_body.begin(), unroll_body.end());
}

if (!update_store_stmts_.empty()) {
body_stmts.insert(body_stmts.end(),
Expand All @@ -355,6 +359,7 @@ class VectorizeForTransMutator : public ir::IRMutator<ir::Expr *> {
tensor2vectorized_vars_.clear();
tensor_can_vectorized_.clear();
in_vectorize_ = false;
is_assignment_ = false;
}

private:
Expand Down Expand Up @@ -393,16 +398,18 @@ class VectorizeForTransMutator : public ir::IRMutator<ir::Expr *> {
AppendCast(node->tensor, *indices, is_store);
}

auto vectorized_var = tensor2vectorized_vars_.at(tensor->name);
// substitute a new tensor with the vector name and dtype
auto t = vectorized_var->type().is_cpp_handle()
? node->tensor->type().PointerOf()
: node->tensor->type();
node->tensor = ir::Tensor(vectorized_var->name,
t,
{ir::Expr(vectorize_factor_)},
{ir::Expr(vectorize_factor_)},
tensor->operation);
if (!is_assignment_) {
auto vectorized_var = tensor2vectorized_vars_.at(tensor->name);
// substitute a new tensor with the vector name and dtype
auto t = vectorized_var->type().is_cpp_handle()
? node->tensor->type().PointerOf()
: node->tensor->type();
node->tensor = ir::Tensor(vectorized_var->name,
t,
{ir::Expr(vectorize_factor_)},
{ir::Expr(vectorize_factor_)},
tensor->operation);
}
// keep only the last iteration index
indices->assign({loop_var_});
}
Expand All @@ -411,7 +418,6 @@ class VectorizeForTransMutator : public ir::IRMutator<ir::Expr *> {
const std::vector<ir::Expr> &indices,
bool is_store) {
auto *node = tensor.As<ir::_Tensor_>();

// generate the corresponding vector type
Type scalar_type = tensor->type().ElementOf();
Type vector_type_ptr(
Expand All @@ -429,7 +435,9 @@ class VectorizeForTransMutator : public ir::IRMutator<ir::Expr *> {
std::string vectorized_name =
"vectorized_" + node->name + "_" + std::to_string(var_index_++);
Var vectorized_var = ir::_Var_::Make(vectorized_name, vector_type);
tensor2vectorized_vars_.emplace(node->name, vectorized_var);
if (!is_assignment_) {
tensor2vectorized_vars_.emplace(node->name, vectorized_var);
}

// generate a get_addr expr to get the address of the tensor
Expr converted_tensor = ir::Load::Make(tensor, indices);
Expand All @@ -447,20 +455,64 @@ class VectorizeForTransMutator : public ir::IRMutator<ir::Expr *> {
Var vectorized_ptr =
ir::_Var_::Make(vectorized_name + "_ptr", vector_type_ptr);
auto let1 = ir::Let::Make(vectorized_ptr, cast);
auto let2 = ir::Let::Make(vectorized_var, ir::Expr(0));
update_cast_stmts_.emplace_back(let1);
update_cast_stmts_.emplace_back(let2);

auto t = ir::Tensor(vectorized_ptr->name,
node->type().PointerOf(),
{ir::Expr(vectorize_factor_)},
{ir::Expr(vectorize_factor_)},
node->operation);
auto store =
ir::Store::Make(t, vectorized_var, {cinn::common::make_const(0)});
update_store_stmts_.emplace_back(store);
VLOG(5) << "Append a vectorized expr:" << store;

if (is_assignment_) {
std::string load_vectorized_name = "vectorized_" +
assignment_tensor_name_ + "_" +
std::to_string(var_index_);
Var load_vectorized_var =
ir::_Var_::Make(load_vectorized_name, vector_type);
auto store = ir::Store::Make(
t, load_vectorized_var, {cinn::common::make_const(0)});
update_store_stmts_.emplace_back(store);
VLOG(5) << "Append a assignment vectorized expr:" << store;
} else {
auto let2 = ir::Let::Make(vectorized_var, ir::Expr(0));
update_cast_stmts_.emplace_back(let2);

auto t = ir::Tensor(vectorized_ptr->name,
node->type().PointerOf(),
{ir::Expr(vectorize_factor_)},
{ir::Expr(vectorize_factor_)},
node->operation);
auto store =
ir::Store::Make(t, vectorized_var, {cinn::common::make_const(0)});
update_store_stmts_.emplace_back(store);
VLOG(5) << "Append a vectorized expr:" << store;
}
}
}

// Reports whether `value` is the source of a pure assignment: a Load,
// optionally wrapped in one or more Casts, whose tensor name is present in
// tensor_can_vectorized_.
//
// Side effects on a match: sets is_assignment_ to true and records the loaded
// tensor's name in assignment_tensor_name_. On a `false` return neither member
// is touched by this function.
//
// A Load whose tensor is not a _Tensor_ node is rejected through
// PADDLE_ENFORCE_NOT_NULL with an InvalidArgument error.
bool IsAssignment(ir::Expr &value) {  // NOLINT
  // Look through Cast wrappers: cast(load) is treated the same as load.
  if (auto *cast_op = value.As<ir::Cast>()) {
    return IsAssignment(cast_op->v());
  }

  // Anything other than a Load (after stripping casts) is a real computation.
  auto *load_op = value.As<ir::Load>();
  if (!load_op) {
    return false;
  }

  auto *tensor_load = load_op->tensor.As<ir::_Tensor_>();
  PADDLE_ENFORCE_NOT_NULL(
      tensor_load,
      ::common::errors::InvalidArgument(
          "Expected _Tensor_ node in Load, but received nullptr."));

  // Only loads from tensors already marked vectorizable qualify.
  if (tensor_can_vectorized_.count(tensor_load->name) == 0) {
    return false;
  }

  is_assignment_ = true;
  assignment_tensor_name_ = tensor_load->name;
  return true;
}

std::vector<ir::Expr> update_cast_stmts_;
Expand All @@ -472,6 +524,8 @@ class VectorizeForTransMutator : public ir::IRMutator<ir::Expr *> {
ir::Var loop_var_;
bool in_vectorize_{false};
int var_index_{0};
bool is_assignment_{false};
std::string assignment_tensor_name_;
};

} // namespace
Expand Down

0 comments on commit 27603ca

Please sign in to comment.