Skip to content

Commit

Permalink
[CINN] Support vectorize-availability check for FusionGroup blocks
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhangX-21 committed Dec 13, 2024
1 parent 27603ca commit eafe635
Show file tree
Hide file tree
Showing 7 changed files with 187 additions and 15 deletions.
162 changes: 162 additions & 0 deletions paddle/cinn/hlir/framework/pir/trivial_op_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "paddle/cinn/ir/dim.h"
#include "paddle/cinn/ir/group_schedule/base_group_scheduler.h"
#include "paddle/cinn/ir/group_schedule/config/group_tile_util.h"
#include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/schedule/ir_schedule_util.h"
#include "paddle/cinn/lang/placeholder.h"
Expand Down Expand Up @@ -759,6 +760,165 @@ std::vector<int64_t> GetLoopStrides(const ir::Expr& body,
return loop_strides;
}

// Checks whether vectorized memory access can be applied to a fusion group.
// The group is vectorizable when, in every non-reduction compute body, each
// Load/Store indexes its tensor either:
//   * fully contiguously -- the indices are exactly the surrounding loop
//     vars, one per dimension, in loop-nest order; or
//   * as a contiguous broadcast -- at least one index is a constant and the
//     remaining variable indices appear in loop-nest order.
// Any other access pattern disables vectorization for the whole group.
bool GetCanApplyVectorize(const std::vector<ir::Expr>& op_compute_bodies) {
  for (const auto& body : op_compute_bodies) {
    ir::Expr expr_schedule_block_realize =
        trivial_fusion_detail::ExprSetFinderUtils::ChildScheduleBlockRealizes
            .GetSingle(body);
    // Reduction blocks are scheduled by a dedicated path; skip them here.
    if (ir::analyzer::IsReductionSBlock(expr_schedule_block_realize)) {
      continue;
    }

    const std::vector<ir::Var> for_iters =
        trivial_fusion_detail::GetAllForIters(body);
    std::unordered_map<ir::Var, ir::Expr> iter_var2value =
        ir::analyzer::GetIterVarToValueOfSBlock(expr_schedule_block_realize);

    // Indices of every Load, grouped by the loaded tensor's name.
    std::unordered_map<std::string, std::vector<std::vector<Expr>>>
        load_tensors_index;
    ir::ir_utils::CollectIRNodesWithoutTensor(
        expr_schedule_block_realize,
        [&](const ir::Expr* expr) {
          auto* node = expr->As<ir::Load>();
          if (node == nullptr) return false;
          auto* tensor = node->tensor.As<ir::_Tensor_>();
          PADDLE_ENFORCE_NOT_NULL(
              tensor,
              ::common::errors::InvalidArgument(
                  "Expected _Tensor_ node in load, but received nullptr."));
          load_tensors_index[tensor->name].push_back(node->indices);
          return true;
        },
        /* uniq_target = */ false);

    // Indices of every Store, grouped by the stored tensor's name.
    std::unordered_map<std::string, std::vector<std::vector<Expr>>>
        store_tensors_index;
    ir::ir_utils::CollectIRNodesWithoutTensor(
        expr_schedule_block_realize,
        [&](const ir::Expr* expr) {
          auto* node = expr->As<ir::Store>();
          if (node == nullptr) return false;
          auto* tensor = node->tensor.As<ir::_Tensor_>();
          PADDLE_ENFORCE_NOT_NULL(
              tensor,
              ::common::errors::InvalidArgument(
                  "Expected _Tensor_ node in store, but received nullptr."));
          store_tensors_index[tensor->name].push_back(node->indices);
          return true;
        },
        /* uniq_target = */ false);

    // True iff `indices` is a broadcast access that is still contiguous:
    // at least one constant dimension, and every variable dimension bound
    // to a surrounding loop var, encountered in loop-nest order.
    auto IsBroadcastAndContiguous = [&](const std::vector<Expr>& indices) {
      size_t loop_idx = 0;
      bool has_constant_axis = false;
      for (const ir::Expr& index : indices) {
        if (index.is_constant()) {
          has_constant_axis = true;
          continue;
        }
        if (!index.is_var()) return false;
        const ir::Var iter_var = index.as_var_ref();
        const auto it = iter_var2value.find(iter_var);
        if (it == iter_var2value.end()) return false;
        const ir::Expr iter_value = it->second;
        PADDLE_ENFORCE_EQ(
            iter_value.as_var() || iter_value.is_constant(),
            true,
            ::common::errors::PreconditionNotMet(
                "Required iter_value shall be var or constant type."));
        // Scan forward for the loop var this index is bound to; resuming
        // from `loop_idx` enforces loop-nest order across dimensions.
        for (; loop_idx < for_iters.size(); ++loop_idx) {
          if (for_iters[loop_idx] == iter_value.as_var_ref()) break;
        }
        if (loop_idx == for_iters.size()) return false;
      }
      return has_constant_axis;
    };

    // True iff `indices` maps one-to-one onto the loop vars in loop order,
    // i.e. the access walks the tensor contiguously.
    auto IsContiguous = [&](const std::vector<Expr>& indices) {
      // Extra dimensions can never match; this also guards the
      // `for_iters[i]` access below against going out of bounds.
      if (indices.size() > for_iters.size()) return false;
      for (size_t i = 0; i < indices.size(); ++i) {
        const ir::Expr& index = indices[i];
        if (index.is_constant() || !index.is_var()) return false;
        const ir::Var iter_var = index.as_var_ref();
        const auto it = iter_var2value.find(iter_var);
        if (it == iter_var2value.end()) return false;
        const ir::Expr iter_value = it->second;
        PADDLE_ENFORCE_EQ(
            iter_value.as_var() || iter_value.is_constant(),
            true,
            ::common::errors::PreconditionNotMet(
                "Required iter_value shall be var or constant type."));
        if (for_iters[i] != iter_value.as_var_ref()) return false;
      }
      return true;
    };

    // Every access of every tensor must match one of the two patterns; a
    // single irregular access disables vectorization for the whole group.
    auto AllAccessesVectorizable =
        [&](const std::unordered_map<std::string,
                                     std::vector<std::vector<Expr>>>&
                tensors_index) {
          for (const auto& tensor_index : tensors_index) {
            for (const auto& indices : tensor_index.second) {
              if (!IsBroadcastAndContiguous(indices) &&
                  !IsContiguous(indices)) {
                return false;
              }
            }
          }
          return true;
        };

    if (!AllAccessesVectorizable(load_tensors_index) ||
        !AllAccessesVectorizable(store_tensors_index)) {
      return false;
    }
  }

  return true;
}

std::shared_ptr<FusionGroupInfo> GetFusionGroupInfo(
const std::vector<ir::Expr>& op_compute_bodies) {
using trivial_fusion_detail::AppendBound;
Expand Down Expand Up @@ -841,6 +1001,8 @@ std::shared_ptr<FusionGroupInfo> GetFusionGroupInfo(
GetCanApplyGridReduce(op_compute_bodies, group_info->reduce_axis);
}

group_info->can_apply_vectorize = GetCanApplyVectorize(op_compute_bodies);

VLOG(4) << group_info->DebugPrint();
return group_info;
}
Expand Down
4 changes: 3 additions & 1 deletion paddle/cinn/hlir/framework/pir/trivial_op_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,14 +166,16 @@ struct FusionGroupInfo {
std::vector<int64_t> reduce_axis;
std::vector<std::string> reduce_var_name;
bool can_apply_grid_reduce;
bool can_apply_vectorize;

// Human-readable dump of the group info, emitted via VLOG for debugging.
// NOTE: the scraped diff showed both the pre- and post-commit tail of this
// stream chain; this is the coherent post-commit version, which also prints
// the new `can_apply_vectorize` flag.
std::string DebugPrint() {
  std::stringstream ss;
  ss << "GroupInfo\nloop_ranges: " << cinn::utils::Join(loop_ranges, " ")
     << "\nloop_strides: " << cinn::utils::Join(loop_strides, ", ")
     << "\nreduce_axis: " << cinn::utils::Join(reduce_axis, " ")
     << "\nreduce_var_name: " << cinn::utils::Join(reduce_var_name, " ")
     << "\ncan_apply_grid_reduce: " << can_apply_grid_reduce
     << "\ncan_apply_vectorize: " << can_apply_vectorize;
  return ss.str();
}
};
Expand Down
13 changes: 8 additions & 5 deletions paddle/cinn/ir/group_schedule/config/group_tile_config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ std::shared_ptr<ScheduleConfig::BaseInfo> InitBasicInfo(
base_info->data_rank = group_info->loop_ranges.size();
base_info->loop_strides = group_info->loop_strides;
base_info->can_apply_grid_reduce = group_info->can_apply_grid_reduce;
base_info->can_apply_vectorize = group_info->can_apply_vectorize;

std::set<int64_t> reduce_dim_loc;
for (int64_t dim : group_info->reduce_axis) {
Expand Down Expand Up @@ -184,6 +185,7 @@ std::shared_ptr<ScheduleConfig::BaseInfo> InitBasicInfo(
TileConfigMap BuildVectorizeConfig(
const std::shared_ptr<ScheduleConfig::BaseInfo>& base_info,
const common::Target& target) {
if (!base_info->can_apply_vectorize) return {};
// current only support [S, R] and [S]
const int iters_dim = base_info->iter_space_type.size();
if (iters_dim > 2) return {};
Expand All @@ -194,12 +196,12 @@ TileConfigMap BuildVectorizeConfig(
ReduceMethod reduce_method = NoneReduceMethod();
int vectorize_factor = 1;
const std::vector<int> vectorize_factors{4, 2};

bool can_vectorize = false;
auto const CheckVectorize = [&](int nums, int threads, int factor) {
const int deal_elements_in_warp = threads * factor;
if (nums % deal_elements_in_warp == 0) {
vectorize_factor = factor;
base_info->enable_vectorize = true;
can_vectorize = true;
return true;
}
return false;
Expand Down Expand Up @@ -235,9 +237,10 @@ TileConfigMap BuildVectorizeConfig(
}
}
}

if (!base_info->enable_vectorize) return {};

if (!can_vectorize) {
base_info->can_apply_vectorize = false;
return {};
}
int64_t sp_inner_num = [&]() -> int64_t {
if (rd_thread_num > 1) return 1;
spatial_numel = spatial_numel / sp_thread_num / vectorize_factor;
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/ir/group_schedule/config/group_tile_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ struct ScheduleConfig {
bool has_dynamic_spatial{false};
bool has_dynamic_reduce{false};
bool can_apply_grid_reduce{false};
bool enable_vectorize{false};
bool can_apply_vectorize{false};
IterSpaceType iter_space_type;
};

Expand Down
3 changes: 0 additions & 3 deletions paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,11 @@ class ComputeInlineTactic final : public ScheduleTactic {
private:
std::unordered_set<std::string> output_names_;
cinn::common::Target target_;
bool enable_vectorize_{false};
};

// Caches the schedule context's output names and compilation target.
// The stale `enable_vectorize_` assignment is dropped: this commit removes
// that member (and renames the config flag to `can_apply_vectorize`), so
// compute-inline no longer reads the vectorize flag at all.
void ComputeInlineTactic::Init(ScheduleContext* context) {
  output_names_ = context->output_names;
  target_ = context->target;
}

void ComputeInlineTactic::Apply(ir::IRSchedule* sch,
Expand All @@ -53,7 +51,6 @@ void ComputeInlineTactic::Apply(ir::IRSchedule* sch,
// return;
// }
// compute inline tactic not work, when apply vectorize in current schedule
if (enable_vectorize_) return;
auto_schedule::AutoInline inliner(target_, output_names_);
VLOG(6) << "try ComputeInline on: " << block_id
<< ", before ComputeInline, func body: "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,11 @@ bool ContainsVectorizableAxis(const ir::IRSchedule* sch,

bool ScheduleBlockEnableVectorize(const ScheduleConfig& config,
const std::string& block_id) {
if (!config.base_info->enable_vectorize) return false;
if (!config.base_info->can_apply_vectorize) return false;

// currently, dont't support vectorize for SplitedTransformTensor
// ScheduleBlock
if (ir::IsSplitTransformTensorName(block_id)) return false;
// if (ir::IsSplitTransformTensorName(block_id)) return false;

if (!UseContinuousDataTile(config)) return false;

Expand Down
14 changes: 11 additions & 3 deletions paddle/cinn/optim/vectorize_for_trans.cc
Original file line number Diff line number Diff line change
Expand Up @@ -225,15 +225,23 @@ class ScheduleBlockTensorVectorizeTeller : public ir::IRMutator<Expr *> {
return false;
}

cinn::optim::Simplify(&offset);
Expr origin_offset = ir::ir_utils::IRCopy(offset);
// only with vectorize axis offset
auto only_vectorize_axis_offset = ir::ir_utils::IRCopy(offset);
for (const auto &[key, value] : var_symbols) {
if (key == iter_var_->name) continue;
cinn::ir::ir_utils::IrReplaceVarBroadcast(
&only_vectorize_axis_offset, Expr(value), Expr(int32_t(0)));
}

cinn::optim::Simplify(&only_vectorize_axis_offset);
Expr origin_offset = ir::ir_utils::IRCopy(only_vectorize_axis_offset);
cinn::ir::ir_utils::IrReplaceVarBroadcast(
&origin_offset, Expr(iter_var_), Expr(int32_t(0)));
cinn::optim::Simplify(&origin_offset);
bool is_zero = true;
bool is_continous = true;
for (int i = 1; i < factor_; i++) {
Expr next = ir::ir_utils::IRCopy(offset);
Expr next = ir::ir_utils::IRCopy(only_vectorize_axis_offset);
cinn::ir::ir_utils::IrReplaceVarBroadcast(
&next, Expr(iter_var_), Expr(int32_t(i)));
cinn::optim::Simplify(&next);
Expand Down

0 comments on commit eafe635

Please sign in to comment.