diff --git a/paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc b/paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc index f9cdb3f10c9407..82e6fcc7fe54fd 100644 --- a/paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc +++ b/paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc @@ -45,6 +45,100 @@ bool UseContinuousDataTile(const ScheduleConfig& config) { return false; } +using BoundVariableMap = std::unordered_map>; + +/* + * Check whether the loop variable of the loop at `vectorize_axis` + * is present in the iter values of the axis bind within the loop body. + * Returns true if present; otherwise the loop cannot be vectorized. + * For example, the following loop cannot be vectorized: + * + * serial for (append_var_0_append_var_1_append_var_2_fused, 0ll, 2304ll) + * { + * ScheduleBlock(var_30) { + * i0_29 = axis.bind(append_var_3_i_append_var_4_append_var_5_fused) + * if (((append_var_3_i_append_var_4_append_var_5_fused / 32ll) == 0ll)) + * var_30[i0_29] = ((0.899999976f * var_17[i0_29]) + (0.100000024f * + * var_2[0ll, i0_29, 0ll, 0ll])) + * } + * } + */ +bool ContainsVectorizableAxis(const ir::IRSchedule* sch, + const size_t vectorize_axis, + const std::string& block_id) { + auto loops = sch->GetLoops(block_id); + auto vectorize_expr = loops[vectorize_axis]; + + VLOG(4) << "Checking ContainsVectorizableAxis on block: [" << block_id + << "], loop:\n" + << sch->GetModule().GetExprs().front() << "\n vectorize expr:]\n" + << vectorize_expr; + + // Get all the iter values in the axis bind that contain a loop var and the + // corresponding iter var. NOTE(review): var match is by substring; confirm. 
+ auto get_bound_variables = [&vectorize_expr](const Expr& expr, + const Expr& loop_var) { + BoundVariableMap bound_variable_map; + auto loop_var_name = loop_var.as_var()->name; + + ir::ir_utils::CollectIRNodes( + expr, + [&loop_var_name, &bound_variable_map](const Expr* x) { + if (const auto block_realize = x->As()) { + auto* schedule_block = + block_realize->schedule_block.As(); + auto iter_values = block_realize->iter_values; + auto iter_vars = schedule_block->iter_vars; + + for (std::size_t i = 0; i < iter_values.size(); ++i) { + const auto& iter_value = iter_values[i]; + if (iter_value.is_var() && + iter_value.as_var()->name.find(loop_var_name) != + std::string::npos) { + bound_variable_map[loop_var_name].emplace_back(iter_vars[i]); + } else if (iter_value.is_index()) { + ir::ir_utils::CollectIRNodes( + iter_value, + [&loop_var_name, &bound_variable_map, &iter_vars, i]( + const Expr* x) { + if (const auto* var = x->As()) { + if (var->name == loop_var_name) { + bound_variable_map[loop_var_name].emplace_back( + iter_vars[i]); + } + } + return false; + }); + } + } + } + return false; + }, + true); + + return bound_variable_map; + }; + + auto bound_variable_map = get_bound_variables( + vectorize_expr, vectorize_expr.As()->loop_var); + + VLOG(5) << "bound_variable_map:\n"; + for (const auto& entry : bound_variable_map) { + const auto& loop_var_name = entry.first; + const auto& vars = entry.second; + VLOG(5) << "Loop Variable: " << loop_var_name; + for (const auto& var : vars) { + VLOG(5) << "Var: " << var; + } + } + + if (bound_variable_map.empty()) { + return false; + } + + return true; +} + bool ScheduleBlockEnableVectorize(const ScheduleConfig& config, const std::string& block_id) { if (!config.base_info->enable_vectorize) return false; @@ -555,18 +649,28 @@ void TileFirstGeneralTactic::ApplyVectorize(ir::IRSchedule* sch, }; auto loops = sch->GetLoops(block_id); - if (sp_loop > 1) { - sch->Split(loops[0], - std::vector{-1, sp_loop, sp_thread, vectorize_factor}); 
+ // The iter_value bound by axis_bind must contain the loop_var of the axis + // to be vectorized. + if (ContainsVectorizableAxis(sch, loops.size() - 1, block_id)) { + if (sp_loop > 1) { + sch->Split(loops[0], + std::vector{-1, sp_loop, sp_thread, vectorize_factor}); + } else { + sch->Split(loops[0], std::vector{-1, sp_thread, vectorize_factor}); + } + + // set vectorize schedule primitives + loops = sch->GetLoops(block_id); + auto vectorize_axis = loops.size() - 1; + sch->Vectorize(loops[vectorize_axis], vectorize_factor); } else { - sch->Split(loops[0], std::vector{-1, sp_thread, vectorize_factor}); + if (sp_loop > 1) { + sch->Split(loops[0], std::vector{-1, sp_loop, sp_thread}); + } else { + sch->Split(loops[0], std::vector{-1, sp_thread}); + } } - // set vectorize schedule primitives - loops = sch->GetLoops(block_id); - auto vectorize_axis = loops.size() - 1; - sch->Vectorize(loops[vectorize_axis], vectorize_factor); - loops = sch->GetLoops(block_id); DoBind(loops); return; @@ -576,20 +680,34 @@ void TileFirstGeneralTactic::ApplyVectorize(ir::IRSchedule* sch, // only deal with spatial block and don't support blockIdx.y if (!IsReductionSBlock(sch->GetBlock(block_id))) { auto loops = sch->GetLoops(block_id); - sch->Split(loops[1], std::vector{-1, rd_thread, vectorize_factor}); + // The iter_value bound by axis_bind must contain the loop_var of the axis + // to be vectorized. 
+ if (ContainsVectorizableAxis(sch, loops.size() - 1, block_id)) { + sch->Split(loops[1], std::vector{-1, rd_thread, vectorize_factor}); - // set vectorize schedule primitives - loops = sch->GetLoops(block_id); - auto vectorize_axis = loops.size() - 1; - sch->Vectorize(loops[vectorize_axis], vectorize_factor); - const auto DoBind = [&](const std::vector& loops) { - sch->Bind(loops[0], "blockIdx.x"); - auto threadsIdx_x_axis = vectorize_axis - 1; - sch->Bind(loops[threadsIdx_x_axis], "threadIdx.x"); - }; - loops = sch->GetLoops(block_id); - DoBind(loops); - return; + // set vectorize schedule primitives + loops = sch->GetLoops(block_id); + auto vectorize_axis = loops.size() - 1; + sch->Vectorize(loops[vectorize_axis], vectorize_factor); + const auto DoBind = [&](const std::vector& loops) { + sch->Bind(loops[0], "blockIdx.x"); + auto threadsIdx_x_axis = vectorize_axis - 1; + sch->Bind(loops[threadsIdx_x_axis], "threadIdx.x"); + }; + loops = sch->GetLoops(block_id); + DoBind(loops); + return; + } else { + sch->Split(loops[1], std::vector{-1, rd_thread}); + const auto DoBind = [&](const std::vector& loops) { + sch->Bind(loops[0], "blockIdx.x"); + auto threadsIdx_x_axis = loops.size() - 1; + sch->Bind(loops[threadsIdx_x_axis], "threadIdx.x"); + }; + loops = sch->GetLoops(block_id); + DoBind(loops); + return; + } } return; }