Skip to content

Commit

Permalink
Skip assignment optimization with local_buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanghonggeng committed Dec 16, 2024
1 parent 1ef2368 commit 4dc38df
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions paddle/cinn/optim/vectorize_for_trans.cc
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ class VectorizeForTransMutator : public ir::IRMutator<ir::Expr *> {
"Expected _Tensor_ node in Store, but received nullptr."));
if (in_vectorize_ && node->is_addr_tensor() &&
tensor_can_vectorized_.count(tensor->name)) {
is_assignment_ = IsAssignment(node->value);
is_assignment_ = IsAssignment(node->value, tensor->name);
TensorVectorized(node, &node->indices, true);
}

Expand Down Expand Up @@ -500,7 +500,12 @@ class VectorizeForTransMutator : public ir::IRMutator<ir::Expr *> {

// A store is considered to be a pure assignment statement only if the store
// value is load or cast(load).
bool IsAssignment(ir::Expr &value) { // NOLINT
bool IsAssignment(ir::Expr &value, // NOLINT
std::string tensor_store_name = "") {
if (tensor_store_name.find("_local") != std::string::npos) {
return false;
}

if (auto *cast_op = value.As<ir::Cast>()) {
return IsAssignment(cast_op->v());
}
Expand All @@ -515,6 +520,9 @@ class VectorizeForTransMutator : public ir::IRMutator<ir::Expr *> {
tensor_load,
::common::errors::InvalidArgument(
"Expected _Tensor_ node in Store, but received nullptr."));
if (tensor_load->name.find("_local") != std::string::npos) {
return false;
}
if (tensor_can_vectorized_.count(tensor_load->name) == 0) {
return false;
}
Expand Down

0 comments on commit 4dc38df

Please sign in to comment.