From eafe635ec918264c395c17ceb537d3b7937747ab Mon Sep 17 00:00:00 2001
From: ZhangX-21
Date: Fri, 13 Dec 2024 11:10:34 +0000
Subject: [PATCH] [CINN] support FusionGroup blocks vectorize check

---
 .../hlir/framework/pir/trivial_op_impl.cc     | 162 ++++++++++++++++++
 .../cinn/hlir/framework/pir/trivial_op_impl.h |   4 +-
 .../config/group_tile_config.cc               |  13 +-
 .../group_schedule/config/group_tile_config.h |   2 +-
 .../tactic/compute_inline_tactic.cc           |   3 -
 .../tactic/tile_first_general_tactic.cc       |   4 +-
 paddle/cinn/optim/vectorize_for_trans.cc      |  14 +-
 7 files changed, 187 insertions(+), 15 deletions(-)

diff --git a/paddle/cinn/hlir/framework/pir/trivial_op_impl.cc b/paddle/cinn/hlir/framework/pir/trivial_op_impl.cc
index 88282c560f5a0..e4e45d4c158e1 100644
--- a/paddle/cinn/hlir/framework/pir/trivial_op_impl.cc
+++ b/paddle/cinn/hlir/framework/pir/trivial_op_impl.cc
@@ -26,6 +26,7 @@
 #include "paddle/cinn/ir/dim.h"
 #include "paddle/cinn/ir/group_schedule/base_group_scheduler.h"
 #include "paddle/cinn/ir/group_schedule/config/group_tile_util.h"
+#include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h"
 #include "paddle/cinn/ir/schedule/ir_schedule.h"
 #include "paddle/cinn/ir/schedule/ir_schedule_util.h"
 #include "paddle/cinn/lang/placeholder.h"
@@ -759,6 +760,165 @@ std::vector<int64_t> GetLoopStrides(const ir::Expr& body,
   return loop_strides;
 }
 
+bool GetCanApplyVectorize(const std::vector<ir::Expr>& op_compute_bodies) {
+  bool can_vectorize = true;
+  for (const auto& body : op_compute_bodies) {
+    ir::Expr expr_schedule_block_realize =
+        trivial_fusion_detail::ExprSetFinderUtils::ChildScheduleBlockRealizes
+            .GetSingle(body);
+    bool is_reduce =
+        ir::analyzer::IsReductionSBlock(expr_schedule_block_realize);
+    if (is_reduce) continue;
+    std::vector<ir::Expr> iter_values =
+        expr_schedule_block_realize.As<ir::ScheduleBlockRealize>()->iter_values;
+    const std::vector<ir::Var> for_iters =
+        trivial_fusion_detail::GetAllForIters(body);
+    std::unordered_map<ir::Var, ir::Expr> iter_var2value =
+        ir::analyzer::GetIterVarToValueOfSBlock(expr_schedule_block_realize);
+    std::unordered_map<std::string, std::vector<std::vector<ir::Expr>>>
+        load_tensors_index;
+    ir::ir_utils::CollectIRNodesWithoutTensor(
+        expr_schedule_block_realize,
+        [&](const ir::Expr* expr) {
+          if (expr->As<ir::Load>()) {
+            auto* node = expr->As<ir::Load>();
+            PADDLE_ENFORCE_NOT_NULL(
+                node,
+                ::common::errors::InvalidArgument(
+                    "Expected Load node, but received nullptr."));
+            auto* tensor = node->tensor.As<ir::_Tensor_>();
+            PADDLE_ENFORCE_NOT_NULL(
+                tensor,
+                ::common::errors::InvalidArgument(
+                    "Expected _Tensor_ node in load, but received nullptr."));
+            load_tensors_index[tensor->name].push_back(node->indices);
+            return true;
+          }
+          return false;
+        },
+        /* uniq_target = */ false);
+
+    std::unordered_map<std::string, std::vector<std::vector<ir::Expr>>>
+        store_tensors_index;
+    ir::ir_utils::CollectIRNodesWithoutTensor(
+        expr_schedule_block_realize,
+        [&](const ir::Expr* expr) {
+          if (expr->As<ir::Store>()) {
+            auto* node = expr->As<ir::Store>();
+            PADDLE_ENFORCE_NOT_NULL(
+                node,
+                ::common::errors::InvalidArgument(
+                    "Expected Store node, but received nullptr."));
+            auto* tensor = node->tensor.As<ir::_Tensor_>();
+            PADDLE_ENFORCE_NOT_NULL(
+                tensor,
+                ::common::errors::InvalidArgument(
+                    "Expected _Tensor_ node in store, but received nullptr."));
+            store_tensors_index[tensor->name].push_back(node->indices);
+            return true;
+          }
+          return false;
+        },
+        /* uniq_target = */ false);
+
+    auto CheckTensorIsBroadcastAndContinuous =
+        [&](const std::vector<ir::Expr>& indices) {
+          int loop_idx = 0;
+          bool is_broadcast = false;
+          for (int i = 0; i < indices.size(); ++i) {
+            const ir::Expr& index = indices[i];
+            if (index.is_constant()) {
+              is_broadcast = true;
+              continue;
+            }
+
+            if (!index.is_var()) return false;
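+            // A non-constant index must be a schedule-block iter var whose
+            // bound value matches a surrounding loop var, in loop order.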
+            ir::Var iter_var = index.as_var_ref();
+            if (!iter_var2value.count(iter_var)) {
+              return false;
+            }
+            ir::Expr iter_value = iter_var2value.at(iter_var);
+            PADDLE_ENFORCE_EQ(
+                iter_value.as_var() || iter_value.is_constant(),
+                true,
+                ::common::errors::PreconditionNotMet(
+                    "Required iter_value shall be var or constant type."));
+            for (; loop_idx < for_iters.size(); ++loop_idx) {
+              if (for_iters[loop_idx] == iter_value.as_var_ref()) {
+                break;
+              }
+            }
+
+            if (loop_idx == for_iters.size()) {
+              return false;
+            }
+          }
+          if (is_broadcast) return true;
+          return false;
+        };
+
+    auto CheckTensorIsContinuous =
+        [&](const std::vector<ir::Expr>& indices) {
+          for (int i = 0; i < indices.size(); ++i) {
+            const ir::Expr& index = indices[i];
+            if (index.is_constant()) return false;
+            if (!index.is_var()) return false;
+            ir::Var iter_var = index.as_var_ref();
+            if (!iter_var2value.count(iter_var)) {
+              return false;
+            }
+            ir::Expr iter_value = iter_var2value.at(iter_var);
+            PADDLE_ENFORCE_EQ(
+                iter_value.as_var() || iter_value.is_constant(),
+                true,
+                ::common::errors::PreconditionNotMet(
+                    "Required iter_value shall be var or constant type."));
+            if (for_iters[i] != iter_value.as_var_ref()) {
+              return false;
+            }
+          }
+          return true;
+        };
+
+    // load tensor information
+    std::unordered_set<std::string> is_broadcast_continuous_tensors;
+    std::unordered_set<std::string> is_continuous_tensors;
+    for (const auto& tensor_index : load_tensors_index) {
+      for (const auto& indices : tensor_index.second) {
+        if (CheckTensorIsBroadcastAndContinuous(indices)) {
+          is_broadcast_continuous_tensors.insert(tensor_index.first);
+          continue;
+        }
+        if (CheckTensorIsContinuous(indices)) {
+          is_continuous_tensors.insert(tensor_index.first);
+          continue;
+        }
+        can_vectorize = false;
+        break;
+      }
+    }
+
+    // store tensor information
+    for (const auto& tensor_index : store_tensors_index) {
+      for (const auto& indices : tensor_index.second) {
+        if (CheckTensorIsBroadcastAndContinuous(indices)) {
+          is_broadcast_continuous_tensors.insert(tensor_index.first);
+          continue;
+        }
+
+        if (CheckTensorIsContinuous(indices)) {
+          is_continuous_tensors.insert(tensor_index.first);
+          continue;
+        }
+        can_vectorize = false;
+        break;
+      }
+    }
+    if (!can_vectorize) break;
+  }
+
+  return can_vectorize;
+}
+
 std::shared_ptr<FusionGroupInfo> GetFusionGroupInfo(
     const std::vector<ir::Expr>& op_compute_bodies) {
   using trivial_fusion_detail::AppendBound;
@@ -841,6 +1001,8 @@ std::shared_ptr<FusionGroupInfo> GetFusionGroupInfo(
         GetCanApplyGridReduce(op_compute_bodies, group_info->reduce_axis);
   }
 
+  group_info->can_apply_vectorize = GetCanApplyVectorize(op_compute_bodies);
+
   VLOG(4) << group_info->DebugPrint();
   return group_info;
 }
diff --git a/paddle/cinn/hlir/framework/pir/trivial_op_impl.h b/paddle/cinn/hlir/framework/pir/trivial_op_impl.h
index 6be43e702adaa..31831234139a3 100644
--- a/paddle/cinn/hlir/framework/pir/trivial_op_impl.h
+++ b/paddle/cinn/hlir/framework/pir/trivial_op_impl.h
@@ -166,6 +166,7 @@ struct FusionGroupInfo {
   std::vector<int64_t> reduce_axis;
   std::vector<std::string> reduce_var_name;
   bool can_apply_grid_reduce;
+  bool can_apply_vectorize;
 
   std::string DebugPrint() {
     std::stringstream ss;
@@ -173,7 +174,8 @@
        << "\nloop_strides: " << cinn::utils::Join(loop_strides, ", ")
        << "\nreduce_axis: " << cinn::utils::Join(reduce_axis, " ")
        << "\nreduce_var_name: " << cinn::utils::Join(reduce_var_name, " ")
-       << "\ncan_apply_grid_reduce: " << can_apply_grid_reduce;
+       << "\ncan_apply_grid_reduce: " << can_apply_grid_reduce
+       << "\ncan_apply_vectorize: " << can_apply_vectorize;
     return ss.str();
   }
 };
diff --git a/paddle/cinn/ir/group_schedule/config/group_tile_config.cc b/paddle/cinn/ir/group_schedule/config/group_tile_config.cc
index 03f9acc8720bb..e45d254571f1d 100644
--- a/paddle/cinn/ir/group_schedule/config/group_tile_config.cc
+++ b/paddle/cinn/ir/group_schedule/config/group_tile_config.cc
@@ -153,6 +153,7 @@ std::shared_ptr<ScheduleConfig::BaseInfo> InitBasicInfo(
   base_info->data_rank = group_info->loop_ranges.size();
   base_info->loop_strides = group_info->loop_strides;
   base_info->can_apply_grid_reduce = group_info->can_apply_grid_reduce;
+  base_info->can_apply_vectorize = group_info->can_apply_vectorize;
 
   std::set<int64_t> reduce_dim_loc;
   for (int64_t dim : group_info->reduce_axis) {
@@ -184,6 +185,7 @@
 TileConfigMap BuildVectorizeConfig(
     const std::shared_ptr<ScheduleConfig::BaseInfo>& base_info,
     const common::Target& target) {
+  if (!base_info->can_apply_vectorize) return {};
   // current only support [S, R] and [S]
   const int iters_dim = base_info->iter_space_type.size();
   if (iters_dim > 2) return {};
@@ -194,12 +196,12 @@
   ReduceMethod reduce_method = NoneReduceMethod();
   int vectorize_factor = 1;
   const std::vector<int> vectorize_factors{4, 2};
-
+  bool can_vectorize = false;
   auto const CheckVectorize = [&](int nums, int threads, int factor) {
     const int deal_elements_in_warp = threads * factor;
     if (nums % deal_elements_in_warp == 0) {
       vectorize_factor = factor;
-      base_info->enable_vectorize = true;
+      can_vectorize = true;
       return true;
     }
     return false;
@@ -235,9 +237,10 @@
       }
     }
   }
-
-  if (!base_info->enable_vectorize) return {};
-
+  if (!can_vectorize) {
+    base_info->can_apply_vectorize = false;
+    return {};
+  }
   int64_t sp_inner_num = [&]() -> int64_t {
     if (rd_thread_num > 1) return 1;
     spatial_numel = spatial_numel / sp_thread_num / vectorize_factor;
diff --git a/paddle/cinn/ir/group_schedule/config/group_tile_config.h b/paddle/cinn/ir/group_schedule/config/group_tile_config.h
index 6e4c10648de93..288c0d3d68fe7 100644
--- a/paddle/cinn/ir/group_schedule/config/group_tile_config.h
+++ b/paddle/cinn/ir/group_schedule/config/group_tile_config.h
@@ -40,7 +40,7 @@ struct ScheduleConfig {
     bool has_dynamic_spatial{false};
     bool has_dynamic_reduce{false};
     bool can_apply_grid_reduce{false};
-    bool enable_vectorize{false};
+    bool can_apply_vectorize{false};
     IterSpaceType iter_space_type;
   };
 
diff --git a/paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.cc b/paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.cc
index 049d438cd2db8..d7dbd660c5aec 100644
--- a/paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.cc
+++ b/paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.cc
@@ -36,13 +36,11 @@ class ComputeInlineTactic final : public ScheduleTactic {
  private:
   std::unordered_set<std::string> output_names_;
   cinn::common::Target target_;
-  bool enable_vectorize_{false};
 };
 
 void ComputeInlineTactic::Init(ScheduleContext* context) {
   output_names_ = context->output_names;
   target_ = context->target;
-  enable_vectorize_ = context->config.base_info->enable_vectorize;
 }
 
 void ComputeInlineTactic::Apply(ir::IRSchedule* sch,
@@ -53,7 +51,6 @@ void ComputeInlineTactic::Apply(ir::IRSchedule* sch,
   //   return;
   // }
   // compute inline tactic not work, when apply vectorize in current schedule
-  if (enable_vectorize_) return;
   auto_schedule::AutoInline inliner(target_, output_names_);
   VLOG(6) << "try ComputeInline on: " << block_id
           << ", before ComputeInline, func body: "
diff --git a/paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc b/paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc
index 32ad05dbe9573..17cec48d39767 100644
--- a/paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc
+++ b/paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc
@@ -138,11 +138,11 @@ bool ContainsVectorizableAxis(const ir::IRSchedule* sch,
 
 bool ScheduleBlockEnableVectorize(const ScheduleConfig& config,
                                   const std::string& block_id) {
-  if (!config.base_info->enable_vectorize) return false;
+  if (!config.base_info->can_apply_vectorize) return false;
 
   // currently, dont't support vectorize for SplitedTransformTensor
   // ScheduleBlock
-  if (ir::IsSplitTransformTensorName(block_id)) return false;
+  // if (ir::IsSplitTransformTensorName(block_id)) return false;
 
   if (!UseContinuousDataTile(config)) return false;
 
diff --git a/paddle/cinn/optim/vectorize_for_trans.cc b/paddle/cinn/optim/vectorize_for_trans.cc
index 1c16c81918661..16f1756fd6dd3 100644
--- a/paddle/cinn/optim/vectorize_for_trans.cc
+++ b/paddle/cinn/optim/vectorize_for_trans.cc
@@ -225,15 +225,23 @@ class ScheduleBlockTensorVectorizeTeller : public ir::IRMutator {
       return false;
     }
 
-    cinn::optim::Simplify(&offset);
-    Expr origin_offset = ir::ir_utils::IRCopy(offset);
+    // only with vectorize axis offset
+    auto only_vectorize_axis_offset = ir::ir_utils::IRCopy(offset);
+    for (const auto &[key, value] : var_symbols) {
+      if (key == iter_var_->name) continue;
+      cinn::ir::ir_utils::IrReplaceVarBroadcast(
+          &only_vectorize_axis_offset, Expr(value), Expr(int32_t(0)));
+    }
+
+    cinn::optim::Simplify(&only_vectorize_axis_offset);
+    Expr origin_offset = ir::ir_utils::IRCopy(only_vectorize_axis_offset);
     cinn::ir::ir_utils::IrReplaceVarBroadcast(
         &origin_offset, Expr(iter_var_), Expr(int32_t(0)));
     cinn::optim::Simplify(&origin_offset);
     bool is_zero = true;
     bool is_continous = true;
     for (int i = 1; i < factor_; i++) {
-      Expr next = ir::ir_utils::IRCopy(offset);
+      Expr next = ir::ir_utils::IRCopy(only_vectorize_axis_offset);
       cinn::ir::ir_utils::IrReplaceVarBroadcast(
           &next, Expr(iter_var_), Expr(int32_t(i)));
       cinn::optim::Simplify(&next);
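
Note: for every Load/Store of a non-reduce schedule block, the new
GetCanApplyVectorize() check accepts only two index patterns: "continuous"
(index i is exactly the i-th loop var) and "broadcast + continuous" (constant
indices allowed, remaining loop vars still in loop order). Any other pattern
turns vectorization off for the whole fusion group. Below is a minimal
standalone sketch of that classification rule, using simplified hypothetical
types (Index, AccessKind, ClassifyAccess) rather than the CINN IR API:

  #include <string>
  #include <vector>

  enum class AccessKind { kContinuous, kBroadcastContinuous, kNotVectorizable };

  // An index is either a constant or the name of a surrounding loop variable.
  struct Index {
    bool is_constant;
    std::string loop_var;  // meaningful only when !is_constant
  };

  AccessKind ClassifyAccess(const std::vector<Index>& indices,
                            const std::vector<std::string>& for_iters) {
    // Continuous: no constants, and index i is exactly the i-th loop variable.
    if (indices.size() == for_iters.size()) {
      bool continuous = true;
      for (size_t i = 0; i < indices.size(); ++i) {
        if (indices[i].is_constant || indices[i].loop_var != for_iters[i]) {
          continuous = false;
          break;
        }
      }
      if (continuous) return AccessKind::kContinuous;
    }
    // Broadcast + continuous: constants may appear, but the remaining loop
    // variables must still occur in loop order.
    bool saw_constant = false;
    size_t loop_idx = 0;
    for (const Index& idx : indices) {
      if (idx.is_constant) {
        saw_constant = true;
        continue;
      }
      while (loop_idx < for_iters.size() &&
             for_iters[loop_idx] != idx.loop_var) {
        ++loop_idx;
      }
      if (loop_idx == for_iters.size()) return AccessKind::kNotVectorizable;
    }
    return saw_constant ? AccessKind::kBroadcastContinuous
                        : AccessKind::kNotVectorizable;
  }

For loops (i, j), A[i][j] classifies as continuous, A[0][j] as broadcast plus
continuous, and A[j][i] as not vectorizable, which disables vectorization for
the group.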