[CINN] apply vectorize Primitive in IRSchedule #69732

Open: wants to merge 18 commits into base: develop

Commits (18):
fe5a015  [CINN] apply vectorize Primitive in IRSchedule (ZhangX-21, Nov 26, 2024)
fb9602b  Check if iter_var contains the loop_var of the axis to be vectorized (zhanghonggeng, Nov 27, 2024)
0d88203  update tiling with vectorize information (ZhangX-21, Nov 27, 2024)
0082c5f  [CINN] check tensor is continuous in vectorize axis loop (ZhangX-21, Dec 4, 2024)
e8d3fba  [CINN] vectorize tensor supports more data types (ZhangX-21, Dec 5, 2024)
bc5968b  [CINN] support select op situation in vectorize and search vectorize … (ZhangX-21, Dec 9, 2024)
27603ca  Simplify pure assignment statements in vectorization (zhanghonggeng, Dec 10, 2024)
eafe635  [CINN] support FusionGroup blocks vectorize check (ZhangX-21, Dec 13, 2024)
1ef2368  skip GetCanApplyVectorize check when ChildScheduleBlockRealizes size … (zhanghonggeng, Dec 16, 2024)
2dad2d3  Skip assignment optimization with local_buffer (zhanghonggeng, Dec 16, 2024)
b9a0da5  [CINN] support ForOp with multi schedule blocks in vectorize (ZhangX-21, Dec 17, 2024)
6f4f6cc  [CINN] add FLAGS_cinn_enable_vectorize flag and kernel config (ZhangX-21, Dec 18, 2024)
7171810  Revert "Skip assignment optimization with local_buffer" (zhanghonggeng, Dec 18, 2024)
884d49e  Assignment with cast cannot be simplified (zhanghonggeng, Dec 18, 2024)
9c46320  fix reduce in GetCanApplyVectorize (zhanghonggeng, Dec 20, 2024)
7526a53  Add CheckSmUtilization in BuildVectorizeConfig (zhanghonggeng, Dec 23, 2024)
84ed69f  skip if iter_value is not var or constant in CheckTensorIsBroadcastAn… (zhanghonggeng, Dec 25, 2024)
acd9902  add SmUtilization check for Reduce Region (zhanghonggeng, Dec 26, 2024)
paddle/cinn/hlir/framework/pir/trivial_op_impl.cc (184 additions, 0 deletions)
@@ -26,6 +26,7 @@
#include "paddle/cinn/ir/dim.h"
#include "paddle/cinn/ir/group_schedule/base_group_scheduler.h"
#include "paddle/cinn/ir/group_schedule/config/group_tile_util.h"
#include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/schedule/ir_schedule_util.h"
#include "paddle/cinn/lang/placeholder.h"
@@ -37,6 +38,7 @@
#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h"

PD_DECLARE_bool(cinn_enable_grid_reduce);
PD_DECLARE_bool(cinn_enable_vectorize);

namespace cinn {
namespace hlir {
@@ -759,6 +761,184 @@ std::vector<int64_t> GetLoopStrides(const ir::Expr& body,
return loop_strides;
}

VectorizeInfo GetCanApplyVectorize(
const std::vector<ir::Expr>& op_compute_bodies) {
bool can_vectorize = true;
bool has_if_else_op = false;
for (const auto& body : op_compute_bodies) {
using trivial_fusion_detail::ExprSetFinderUtils::ChildScheduleBlockRealizes;
using trivial_fusion_detail::ExprSetFinderUtils::ExprSetFinder;
ExprSetFinder finder =
ChildScheduleBlockRealizes * ExprSetFinder::GetIdentity();
const auto& o = finder(body);

if (o.size() != 1) {
continue;
}

ir::Expr expr_schedule_block_realize = *o.begin();
bool is_reduce =
ir::analyzer::IsReductionSBlock(expr_schedule_block_realize);
if (is_reduce) continue;
std::vector<ir::Expr> iter_values =
expr_schedule_block_realize.As<ir::ScheduleBlockRealize>()->iter_values;
const std::vector<ir::Var> for_iters =
trivial_fusion_detail::GetAllForIters(body);
std::unordered_map<ir::Var, ir::Expr> iter_var2value =
ir::analyzer::GetIterVarToValueOfSBlock(expr_schedule_block_realize);
std::unordered_map<std::string, std::vector<std::vector<Expr>>>
load_tensors_index;
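// Collect the index expressions of every Load node in the schedule block.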
ir::ir_utils::CollectIRNodesWithoutTensor(
expr_schedule_block_realize,
[&](const ir::Expr* expr) {
if (expr->As<ir::Load>()) {
auto* node = expr->As<ir::Load>();
PADDLE_ENFORCE_NOT_NULL(
node,
::common::errors::InvalidArgument(
"Expected Load node, but received nullptr."));
auto* tensor = node->tensor.As<ir::_Tensor_>();
PADDLE_ENFORCE_NOT_NULL(
tensor,
::common::errors::InvalidArgument(
"Expected _Tensor_ node in load, but received nullptr."));
load_tensors_index[tensor->name].push_back(node->indices);
return true;
}
return false;
},
/* uniq_target = */ false);

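// Record whether the schedule block contains any IfThenElse op (e.g. one
// lowered from a select).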
ir::ir_utils::CollectIRNodesWithoutTensor(
expr_schedule_block_realize,
[&](const ir::Expr* expr) {
if (expr->As<ir::IfThenElse>()) {
auto* node = expr->As<ir::IfThenElse>();
PADDLE_ENFORCE_NOT_NULL(
node,
::common::errors::InvalidArgument(
"Expected IfThenElse node, but received nullptr."));
has_if_else_op = true;
return true;
}
return false;
},
/* uniq_target = */ false);

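// Collect the index expressions of every Store node in the schedule block.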
std::unordered_map<std::string, std::vector<std::vector<Expr>>>
store_tensors_index;
ir::ir_utils::CollectIRNodesWithoutTensor(
expr_schedule_block_realize,
[&](const ir::Expr* expr) {
if (expr->As<ir::Store>()) {
auto* node = expr->As<ir::Store>();
PADDLE_ENFORCE_NOT_NULL(
node,
::common::errors::InvalidArgument(
"Expected Store node, but received nullptr."));
auto* tensor = node->tensor.As<ir::_Tensor_>();
PADDLE_ENFORCE_NOT_NULL(
tensor,
::common::errors::InvalidArgument(
"Expected _Tensor_ node in store, but received nullptr."));
store_tensors_index[tensor->name].push_back(node->indices);
return true;
}
return false;
},
/* uniq_target = */ false);

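// An access is "broadcast and continuous" if every index is either a
// constant or a loop variable, the loop variables that do appear occur in
// the same relative order as the surrounding loops, and at least one
// dimension is broadcast (constant index) or the access uses fewer indices
// than there are loops.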
auto CheckTensorIsBroadcastAndContinuous = [&](std::vector<Expr>& indices) {
int loop_idx = 0;
bool is_broadcast = false;
for (int i = 0; i < indices.size(); ++i) {
ir::Expr& index = indices[i];
cinn::optim::Simplify(&index);
if (index.is_constant()) {
is_broadcast = true;
continue;
}

if (!index.is_var()) return false;
ir::Var iter_var = index.as_var_ref();
if (!iter_var2value.count(iter_var)) {
return false;
}
ir::Expr iter_value = iter_var2value.at(iter_var);
if (!iter_value.as_var() && !iter_value.is_constant()) return false;
for (; loop_idx < for_iters.size(); ++loop_idx) {
if (for_iters[loop_idx] == iter_value.as_var_ref()) {
break;
}
}

if (loop_idx == for_iters.size()) {
return false;
}
}
if (is_broadcast || indices.size() < for_iters.size()) return true;
return false;
};

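// An access is fully continuous if its i-th index is exactly the i-th
// surrounding loop variable (no constants, no permutation).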
auto CheckTensorIsContinuous = [&](std::vector<Expr>& indices) {
for (int i = 0; i < indices.size(); ++i) {
ir::Expr& index = indices[i];
cinn::optim::Simplify(&index);
if (index.is_constant()) return false;
if (!index.is_var()) return false;
ir::Var iter_var = index.as_var_ref();
if (!iter_var2value.count(iter_var)) {
return false;
}
ir::Expr iter_value = iter_var2value.at(iter_var);
if (!iter_value.as_var() && !iter_value.is_constant()) return false;
if (for_iters[i] != iter_value.as_var_ref()) {
return false;
}
}
return true;
};

// Classify each loaded tensor's access pattern.
std::unordered_set<std::string> is_broadcast_continuous_tensors;
std::unordered_set<std::string> is_continuous_tensors;
for (const auto& tensor_index : load_tensors_index) {
for (auto indices : tensor_index.second) {
if (CheckTensorIsBroadcastAndContinuous(indices)) {
is_broadcast_continuous_tensors.insert(tensor_index.first);
continue;
}
if (CheckTensorIsContinuous(indices)) {
is_continuous_tensors.insert(tensor_index.first);
continue;
}
can_vectorize = false;
break;
}
}
// Classify each stored tensor's access pattern in the same way.
for (const auto& tensor_index : store_tensors_index) {
for (auto indices : tensor_index.second) {
if (CheckTensorIsBroadcastAndContinuous(indices)) {
is_broadcast_continuous_tensors.insert(tensor_index.first);
continue;
}

if (CheckTensorIsContinuous(indices)) {
is_continuous_tensors.insert(tensor_index.first);
continue;
}
can_vectorize = false;
break;
}
}
if (!can_vectorize) break;
}

return {can_vectorize, has_if_else_op};
}

std::shared_ptr<FusionGroupInfo> GetFusionGroupInfo(
const std::vector<ir::Expr>& op_compute_bodies) {
using trivial_fusion_detail::AppendBound;
@@ -841,6 +1021,10 @@ std::shared_ptr<FusionGroupInfo> GetFusionGroupInfo(
GetCanApplyGridReduce(op_compute_bodies, group_info->reduce_axis);
}

if (FLAGS_cinn_enable_vectorize) {
group_info->vectorize_info = GetCanApplyVectorize(op_compute_bodies);
}

VLOG(4) << group_info->DebugPrint();
return group_info;
}
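To make the checks above easier to follow, here is a minimal, self-contained sketch of the two access-pattern tests in GetCanApplyVectorize over a toy index representation: a std::optional<int> names a surrounding loop axis, and std::nullopt stands for a constant index. The helper names and the main driver are illustrative assumptions, not code from this PR; the real lambdas additionally resolve ScheduleBlockRealize iter values and run cinn::optim::Simplify on each index before classifying it.

#include <cstdio>
#include <optional>
#include <vector>

// std::nullopt models a constant (broadcast) index; a value i models the
// i-th surrounding loop variable.
using Index = std::optional<int>;

// Broadcast-and-continuous: constant indices may appear anywhere; the loop
// axes that do appear must occur in the same relative order as the
// surrounding loops, and at least one dimension must be broadcast or the
// access must use fewer indices than there are loops.
bool IsBroadcastAndContinuous(const std::vector<Index>& indices,
                              int num_loops) {
  int loop_idx = 0;
  bool saw_constant = false;
  for (const Index& index : indices) {
    if (!index.has_value()) {  // constant index => broadcast dimension
      saw_constant = true;
      continue;
    }
    // Advance monotonically through the loop nest looking for this axis.
    while (loop_idx < num_loops && loop_idx != *index) ++loop_idx;
    if (loop_idx == num_loops) return false;  // axis out of order
  }
  return saw_constant || static_cast<int>(indices.size()) < num_loops;
}

// Fully continuous: the i-th index must be exactly the i-th loop variable.
bool IsContinuous(const std::vector<Index>& indices) {
  for (int i = 0; i < static_cast<int>(indices.size()); ++i) {
    if (!indices[i].has_value() || *indices[i] != i) return false;
  }
  return true;
}

int main() {
  // Loops: for i0 { for i1 { for i2 { ... } } }
  std::printf("%d\n", IsContinuous({0, 1, 2}));                            // A[i0][i1][i2]: 1
  std::printf("%d\n", IsBroadcastAndContinuous({0, std::nullopt, 2}, 3));  // B[i0][0][i2]: 1
  std::printf("%d\n", IsBroadcastAndContinuous({2, 0}, 3));                // C[i2][i0]: 0
  return 0;
}

A tensor access that passes neither test makes can_vectorize false for the whole group. Note that GetFusionGroupInfo only computes this information when FLAGS_cinn_enable_vectorize is set; for Paddle-style gflags the switch can usually also be toggled through an environment variable of the same name (e.g. FLAGS_cinn_enable_vectorize=true), though that is an assumption about deployment rather than something shown in this diff.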
paddle/cinn/hlir/framework/pir/trivial_op_impl.h (8 additions, 1 deletion)
@@ -159,21 +159,28 @@ std::vector<ir::Var> GetAllIterVars(const ir::Expr& expr);
std::vector<ir::Var> GetAllForIters(const ir::Expr& expr);

} // namespace trivial_fusion_detail
struct VectorizeInfo {
bool can_apply_vectorize;
bool has_if_else_op;
};

struct FusionGroupInfo {
std::vector<int64_t> loop_ranges;
std::vector<int64_t> loop_strides;
std::vector<int64_t> reduce_axis;
std::vector<std::string> reduce_var_name;
bool can_apply_grid_reduce;
VectorizeInfo vectorize_info;

std::string DebugPrint() {
std::stringstream ss;
ss << "GroupInfo\nloop_ranges: " << cinn::utils::Join(loop_ranges, " ")
<< "\nloop_strides: " << cinn::utils::Join(loop_strides, ", ")
<< "\nreduce_axis: " << cinn::utils::Join(reduce_axis, " ")
<< "\nreduce_var_name: " << cinn::utils::Join(reduce_var_name, " ")
<< "\ncan_apply_grid_reduce: " << can_apply_grid_reduce
<< "\ncan_apply_vectorize: " << vectorize_info.can_apply_vectorize
<< "\nhas_if_else_op: " << vectorize_info.has_if_else_op;
return ss.str();
}
};
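With the new fields, DebugPrint for a simple elementwise group would emit output shaped like the following; the concrete values are illustrative, assuming a 32x128 iteration space where vectorization applies and no IfThenElse was found:

GroupInfo
loop_ranges: 32 128
loop_strides: 128, 1
reduce_axis:
reduce_var_name:
can_apply_grid_reduce: 0
can_apply_vectorize: 1
has_if_else_op: 0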