Skip to content

Commit

Permalink
Check if iter_var contains the loop_var of the axis to be vectorized.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanghonggeng committed Nov 27, 2024
1 parent bdffaac commit 45cce6d
Showing 1 changed file with 140 additions and 22 deletions.
162 changes: 140 additions & 22 deletions paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,100 @@ bool UseContinuousDataTile(const ScheduleConfig& config) {
return false;
}

// Maps a loop variable name to the schedule-block iter vars whose bound iter
// values reference that loop variable.
using BoundVariableMap = std::unordered_map<std::string, std::vector<Var>>;

/*
 * Check whether the loop variable of the loop selected by `vectorize_axis`
 * is present in the iter values bound (via axis.bind) by the schedule blocks
 * inside that loop's body.
 *
 * Returns true when at least one iter value references the loop variable and
 * false otherwise. A false result means the loop cannot be vectorized; false
 * is also returned when `vectorize_axis` is out of range or the selected
 * expression is not a For loop.
 *
 * For example, the following loop cannot be vectorized, because the only
 * axis.bind argument does not use the loop variable
 * append_var_0_append_var_1_append_var_2_fused:
 *
 * serial for (append_var_0_append_var_1_append_var_2_fused, 0ll, 2304ll)
 * {
 *   ScheduleBlock(var_30) {
 *     i0_29 = axis.bind(append_var_3_i_append_var_4_append_var_5_fused)
 *     if (((append_var_3_i_append_var_4_append_var_5_fused / 32ll) == 0ll))
 *       var_30[i0_29] = ((0.899999976f * var_17[i0_29]) + (0.100000024f *
 *       var_2[0ll, i0_29, 0ll, 0ll]))
 *   }
 * }
 *
 * sch:            schedule that owns the block.
 * vectorize_axis: index into sch->GetLoops(block_id) of the loop to inspect.
 * block_id:       name of the schedule block whose loops are queried.
 */
bool ContainsVectorizableAxis(const ir::IRSchedule* sch,
                              const size_t vectorize_axis,
                              const std::string& block_id) {
  auto loops = sch->GetLoops(block_id);
  // Callers pass `loops.size() - 1`, which wraps around when the block has no
  // loops; treat any out-of-range axis as "not vectorizable" instead of
  // indexing out of bounds.
  if (vectorize_axis >= loops.size()) {
    return false;
  }
  auto vectorize_expr = loops[vectorize_axis];

  VLOG(4) << "Checking ContainsVectorizableAxis on block: [" << block_id
          << "], loop:\n"
          << sch->GetModule().GetExprs().front() << "\n vectorize expr:]\n"
          << vectorize_expr;

  // Only a For node carries a loop variable to look for.
  auto* for_node = vectorize_expr.As<ir::For>();
  if (for_node == nullptr) {
    return false;
  }

  // Get all the iter values in the axis bind that contain a loop var and the
  // corresponding iter var.
  auto get_bound_variables = [](const Expr& expr, const Expr& loop_var) {
    BoundVariableMap bound_variable_map;
    const auto& loop_var_name = loop_var.as_var()->name;

    // The teller always returns false: the traversal is used only for its
    // side effect of filling `bound_variable_map`, not for the collected set.
    ir::ir_utils::CollectIRNodes(
        expr,
        [&loop_var_name, &bound_variable_map](const Expr* x) {
          const auto* block_realize = x->As<ir::ScheduleBlockRealize>();
          if (block_realize == nullptr) {
            return false;
          }
          const auto* schedule_block =
              block_realize->schedule_block.As<ScheduleBlock>();
          if (schedule_block == nullptr) {
            return false;
          }
          // Bind by reference; the vectors are only read here.
          const auto& iter_values = block_realize->iter_values;
          const auto& iter_vars = schedule_block->iter_vars;

          for (std::size_t i = 0; i < iter_values.size(); ++i) {
            const auto& iter_value = iter_values[i];
            // NOTE(review): this branch matches by substring while the nested
            // traversal below matches by exact name; confirm the substring
            // match is intentional.
            if (iter_value.is_var() &&
                iter_value.as_var()->name.find(loop_var_name) !=
                    std::string::npos) {
              bound_variable_map[loop_var_name].emplace_back(iter_vars[i]);
            } else if (iter_value.is_index()) {
              // Compound index expression: look for the loop var anywhere
              // inside it.
              ir::ir_utils::CollectIRNodes(
                  iter_value,
                  [&loop_var_name, &bound_variable_map, &iter_vars, i](
                      const Expr* node) {
                    const auto* var = node->As<ir::_Var_>();
                    if (var != nullptr && var->name == loop_var_name) {
                      bound_variable_map[loop_var_name].emplace_back(
                          iter_vars[i]);
                    }
                    return false;
                  });
            }
          }
          return false;
        },
        true);

    return bound_variable_map;
  };

  auto bound_variable_map =
      get_bound_variables(vectorize_expr, for_node->loop_var);

  VLOG(5) << "bound_variable_map:\n";
  for (const auto& entry : bound_variable_map) {
    const auto& loop_var_name = entry.first;
    const auto& vars = entry.second;
    VLOG(5) << "Loop Variable: " << loop_var_name;
    for (const auto& var : vars) {
      VLOG(5) << "Var: " << var;
    }
  }

  // Non-empty map <=> some iter value references the loop variable.
  return !bound_variable_map.empty();
}

bool ScheduleBlockEnableVectorize(const ScheduleConfig& config,
const std::string& block_id) {
if (!config.base_info->enable_vectorize) return false;
Expand Down Expand Up @@ -555,18 +649,28 @@ void TileFirstGeneralTactic::ApplyVectorize(ir::IRSchedule* sch,
};

auto loops = sch->GetLoops(block_id);
if (sp_loop > 1) {
sch->Split(loops[0],
std::vector<int>{-1, sp_loop, sp_thread, vectorize_factor});
// The iter_value bound by axis_bind must contain the loop_var of the axis
// to be vectorized.
if (ContainsVectorizableAxis(sch, loops.size() - 1, block_id)) {
if (sp_loop > 1) {
sch->Split(loops[0],
std::vector<int>{-1, sp_loop, sp_thread, vectorize_factor});
} else {
sch->Split(loops[0], std::vector<int>{-1, sp_thread, vectorize_factor});
}

// set vectorize schedule primitives
loops = sch->GetLoops(block_id);
auto vectorize_axis = loops.size() - 1;
sch->Vectorize(loops[vectorize_axis], vectorize_factor);
} else {
sch->Split(loops[0], std::vector<int>{-1, sp_thread, vectorize_factor});
if (sp_loop > 1) {
sch->Split(loops[0], std::vector<int>{-1, sp_loop, sp_thread});
} else {
sch->Split(loops[0], std::vector<int>{-1, sp_thread});
}
}

// set vectorize schedule primitives
loops = sch->GetLoops(block_id);
auto vectorize_axis = loops.size() - 1;
sch->Vectorize(loops[vectorize_axis], vectorize_factor);

loops = sch->GetLoops(block_id);
DoBind(loops);
return;
Expand All @@ -576,20 +680,34 @@ void TileFirstGeneralTactic::ApplyVectorize(ir::IRSchedule* sch,
// only deal with spatial block and don't support blockIdx.y
if (!IsReductionSBlock(sch->GetBlock(block_id))) {
auto loops = sch->GetLoops(block_id);
sch->Split(loops[1], std::vector<int>{-1, rd_thread, vectorize_factor});
// The iter_value bound by axis_bind must contain the loop_var of the axis
// to be vectorized.
if (ContainsVectorizableAxis(sch, loops.size() - 1, block_id)) {
sch->Split(loops[1], std::vector<int>{-1, rd_thread, vectorize_factor});

// set vectorize schedule primitives
loops = sch->GetLoops(block_id);
auto vectorize_axis = loops.size() - 1;
sch->Vectorize(loops[vectorize_axis], vectorize_factor);
const auto DoBind = [&](const std::vector<ir::Expr>& loops) {
sch->Bind(loops[0], "blockIdx.x");
auto threadsIdx_x_axis = vectorize_axis - 1;
sch->Bind(loops[threadsIdx_x_axis], "threadIdx.x");
};
loops = sch->GetLoops(block_id);
DoBind(loops);
return;
// set vectorize schedule primitives
loops = sch->GetLoops(block_id);
auto vectorize_axis = loops.size() - 1;
sch->Vectorize(loops[vectorize_axis], vectorize_factor);
const auto DoBind = [&](const std::vector<ir::Expr>& loops) {
sch->Bind(loops[0], "blockIdx.x");
auto threadsIdx_x_axis = vectorize_axis - 1;
sch->Bind(loops[threadsIdx_x_axis], "threadIdx.x");
};
loops = sch->GetLoops(block_id);
DoBind(loops);
return;
} else {
sch->Split(loops[1], std::vector<int>{-1, rd_thread});
const auto DoBind = [&](const std::vector<ir::Expr>& loops) {
sch->Bind(loops[0], "blockIdx.x");
auto threadsIdx_x_axis = loops.size() - 1;
sch->Bind(loops[threadsIdx_x_axis], "threadIdx.x");
};
loops = sch->GetLoops(block_id);
DoBind(loops);
return;
}
}
return;
}
Expand Down

0 comments on commit 45cce6d

Please sign in to comment.