diff --git a/paddle/cinn/ir/group_schedule/config/group_tile_config.cc b/paddle/cinn/ir/group_schedule/config/group_tile_config.cc index a546d03d344c23..03f9acc8720bb4 100644 --- a/paddle/cinn/ir/group_schedule/config/group_tile_config.cc +++ b/paddle/cinn/ir/group_schedule/config/group_tile_config.cc @@ -23,7 +23,7 @@ using TileConfigMap = std::unordered_map; namespace { - +const int kWarpSize = 32; const int kMaxNumel = INT32_MAX; int64_t CeilPow2(int64_t n) { @@ -181,6 +181,83 @@ std::shared_ptr InitBasicInfo( return base_info; } +TileConfigMap BuildVectorizeConfig( + const std::shared_ptr& base_info, + const common::Target& target) { + // current only support [S, R] and [S] + const int iters_dim = base_info->iter_space_type.size(); + if (iters_dim > 2) return {}; + const auto& last_dim = base_info->iter_space_type.back().first; + const int sm_count = target.get_multi_processor_count(); + int64_t spatial_numel = base_info->spatial_numel; + int64_t reduce_numel = base_info->reduce_numel; + ReduceMethod reduce_method = NoneReduceMethod(); + int vectorize_factor = 1; + const std::vector vectorize_factors{4, 2}; + + auto const CheckVectorize = [&](int nums, int threads, int factor) { + const int deal_elements_in_warp = threads * factor; + if (nums % deal_elements_in_warp == 0) { + vectorize_factor = factor; + base_info->enable_vectorize = true; + return true; + } + return false; + }; + + int64_t sp_thread_num = 1; + int64_t rd_thread_num = 1; + int64_t warp_nums = 1; + // Reduce Region + if (last_dim == "R") { + for (auto factor : vectorize_factors) { + vectorize_factor = factor; + const int elements_in_warp = kWarpSize * vectorize_factor; + warp_nums = CeilDiv(reduce_numel, elements_in_warp); + warp_nums = Trim(warp_nums, 1, 8); + if (warp_nums > 1 || spatial_numel < warp_nums * 64) { + rd_thread_num = warp_nums * kWarpSize; + if (CheckVectorize(reduce_numel, rd_thread_num, vectorize_factor)) { + break; + } + reduce_method = BlockReduceMethod(); + } + } + } else if 
(iters_dim == 1 && last_dim == "S") { // Spatial Region + for (auto factor : vectorize_factors) { + vectorize_factor = factor; + const int elements_in_warp = kWarpSize * vectorize_factor; + warp_nums = CeilDiv(spatial_numel, elements_in_warp); + warp_nums = Trim(warp_nums, 1, 16); + sp_thread_num = kWarpSize * warp_nums; + if (CheckVectorize(spatial_numel, sp_thread_num, vectorize_factor)) { + break; + } + } + } + + if (!base_info->enable_vectorize) return {}; + + int64_t sp_inner_num = [&]() -> int64_t { + if (rd_thread_num > 1) return 1; + spatial_numel = spatial_numel / sp_thread_num / vectorize_factor; + int64_t expected = spatial_numel / (sm_count * 16); + return Trim(expected, 1, 4); + }(); + + int64_t sp_upper_bound = base_info->spatial_numel > 1 ? kMaxNumel : 1; + int64_t rd_upper_bound = base_info->reduce_numel > 1 ? kMaxNumel : 1; + BucketInfo bucket_info{1, sp_upper_bound, 1, rd_upper_bound}; + warp_nums = Trim(sp_thread_num * rd_thread_num / kWarpSize, 1, 32); + TileConfig tile_config{warp_nums, + /* tree_reduce_num = */ rd_thread_num, + /* grid_reduce_num = */ 1, + /* spatial_inner_num = */ sp_inner_num, + /* vectorize_factor = */ vectorize_factor, + reduce_method}; + return {{bucket_info, tile_config}}; +} + TileConfigMap BuildPureStaticShapeConfig( const std::shared_ptr& base_info, const common::Target& target) { @@ -190,6 +267,10 @@ TileConfigMap BuildPureStaticShapeConfig( int64_t reduce_numel = base_info->reduce_numel; ReduceMethod reduce_method = NoneReduceMethod(); + // Try to use vectorization first + auto config_map = BuildVectorizeConfig(base_info, target); + if (!config_map.empty()) return std::move(config_map); + // 1. 
Allocate spatial/reduce threads // Principals: // 1) The low 32 threads are assigned to the last dimension to ensure @@ -259,6 +340,7 @@ TileConfigMap BuildPureStaticShapeConfig( /* tree_reduce_num = */ rd_thread_num, /* grid_reduce_num = */ rd_block_num, /* spatial_inner_num = */ sp_inner_num, + /* vectorize_factor = */ 1, reduce_method}; return {{bucket_info, tile_config}}; } @@ -277,6 +359,7 @@ TileConfigMap BuildStaticSpatialConfig( /* tree_reduce_num = */ 256, /* grid_reduce_num = */ 1, /* spatial_inner_num = */ 1, + /* vectorize_factor = */ 1, BlockReduceMethod()}; return {{bucket_info, tile_config}}; } else { @@ -290,6 +373,7 @@ TileConfigMap BuildStaticSpatialConfig( /* tree_reduce_num = */ 32, /* grid_reduce_num = */ 1, /* spatial_inner_num = */ 1, + /* vectorize_factor = */ 1, WarpReduceMethod()}; BucketInfo bucket_info_257_2048{/* sp_lower_bound = */ 1, @@ -302,6 +386,7 @@ TileConfigMap BuildStaticSpatialConfig( /* tree_reduce_num = */ 128, /* grid_reduce_num = */ 1, /* spatial_inner_num = */ 1, + /* vectorize_factor = */ 1, BlockReduceMethod()}; BucketInfo bucket_info_2049_INF{/* sp_lower_bound = */ 1, @@ -314,6 +399,7 @@ TileConfigMap BuildStaticSpatialConfig( /* tree_reduce_num = */ 256, /* grid_reduce_num = */ 1, /* spatial_inner_num = */ 1, + /* vectorize_factor = */ 1, BlockReduceMethod()}; return {{bucket_info_1_256, tile_config_1_256}, @@ -336,6 +422,7 @@ TileConfigMap BuildStaticReduceConfig( /* tree_reduce_num = */ 1, /* grid_reduce_num = */ 1, /* spatial_inner_num = */ 1, + /* vectorize_factor = */ 1, NoneReduceMethod()}; BucketInfo bucket_info__1024_1M{/* sp_lower_bound = */ 1024, /* sp_upper_bound = */ 1024 * 1024 - 1, @@ -347,6 +434,7 @@ TileConfigMap BuildStaticReduceConfig( /* tree_reduce_num = */ 1, /* grid_reduce_num = */ 1, /* spatial_inner_num = */ 4, + /* vectorize_factor = */ 1, NoneReduceMethod()}; BucketInfo bucket_info__1M_INF{/* sp_lower_bound = */ 1024 * 1024, /* sp_upper_bound = */ kMaxNumel, @@ -358,6 +446,7 @@ TileConfigMap 
BuildStaticReduceConfig( /* tree_reduce_num = */ 1, /* grid_reduce_num = */ 1, /* spatial_inner_num = */ 4, + /* vectorize_factor = */ 1, NoneReduceMethod()}; return {{bucket_info__1_1023, tile_config__1_1023}, {bucket_info__1024_1M, tile_config__1024_1M}, @@ -374,6 +463,7 @@ TileConfigMap BuildStaticReduceConfig( /* tree_reduce_num = */ 32, /* grid_reduce_num = */ 1, /* spatial_inner_num = */ (256 / CeilPow2(base_info->reduce_numel)), + /* vectorize_factor = */ 1, WarpReduceMethod()}; return {{bucket_info, tile_config}}; } else if (base_info->reduce_numel <= 2048) { @@ -392,6 +482,7 @@ TileConfigMap BuildStaticReduceConfig( tree_reduce_num, /* grid_reduce_num = */ 1, /* spatial_inner_num */ 1, + /* vectorize_factor = */ 1, BlockReduceMethod()}; return {{bucket_info, tile_config}}; } else { @@ -405,6 +496,7 @@ TileConfigMap BuildStaticReduceConfig( /* tree_reduce_num = */ 1024, /* grid_reduce_num = */ 1, /* spatial_inner_num = */ 1, + /* vectorize_factor = */ 1, BlockReduceMethod()}; return {{bucket_info, tile_config}}; } @@ -423,6 +515,7 @@ TileConfigMap BuildDynamicShapeConfig( /* tree_reduce_num = */ 1024, /* grid_reduce_num = */ 1, /* spatial_inner_num = */ 1, + /* vectorize_factor = */ 1, BlockReduceMethod()}; return {{bucket_info, tile_config}}; } diff --git a/paddle/cinn/ir/group_schedule/config/group_tile_config.h b/paddle/cinn/ir/group_schedule/config/group_tile_config.h index d0ac0e1107bdca..6e4c10648de939 100644 --- a/paddle/cinn/ir/group_schedule/config/group_tile_config.h +++ b/paddle/cinn/ir/group_schedule/config/group_tile_config.h @@ -40,6 +40,7 @@ struct ScheduleConfig { bool has_dynamic_spatial{false}; bool has_dynamic_reduce{false}; bool can_apply_grid_reduce{false}; + bool enable_vectorize{false}; IterSpaceType iter_space_type; }; @@ -48,6 +49,7 @@ struct ScheduleConfig { int64_t tree_reduce_num{1}; int64_t grid_reduce_num{1}; int64_t spatial_inner_num{1}; + int64_t vectorize_factor{1}; ReduceMethod reduce_method{NoneReduceMethod()}; }; diff 
--git a/paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.cc b/paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.cc index 5076d1ded1e69f..049d438cd2db80 100644 --- a/paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.cc +++ b/paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.cc @@ -36,11 +36,13 @@ class ComputeInlineTactic final : public ScheduleTactic { private: std::unordered_set output_names_; cinn::common::Target target_; + bool enable_vectorize_{false}; }; void ComputeInlineTactic::Init(ScheduleContext* context) { output_names_ = context->output_names; target_ = context->target; + enable_vectorize_ = context->config.base_info->enable_vectorize; } void ComputeInlineTactic::Apply(ir::IRSchedule* sch, @@ -50,6 +52,8 @@ void ComputeInlineTactic::Apply(ir::IRSchedule* sch, // if (IsProhibitScheduleExternCallBlock(node->Block())) { // return; // } + // compute inline tactic not work, when apply vectorize in current schedule + if (enable_vectorize_) return; auto_schedule::AutoInline inliner(target_, output_names_); VLOG(6) << "try ComputeInline on: " << block_id << ", before ComputeInline, func body: " diff --git a/paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc b/paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc index 1ee6b50fabc0ae..f9cdb3f10c9407 100644 --- a/paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc +++ b/paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc @@ -25,6 +25,14 @@ namespace ir { using cinn::ir::analyzer::IsReductionSBlock; +bool IsSpatialRegion(const ScheduleConfig& config) { + if (config.base_info->iter_space_type.size() == 1 && + config.base_info->iter_space_type.back().first == "S") { + return true; + } + return false; +} + bool UseContinuousDataTile(const ScheduleConfig& config) { // use continuous data tile for [S] and [...R] if (config.base_info->iter_space_type.size() == 1 && @@ -37,6 +45,21 @@ bool UseContinuousDataTile(const 
ScheduleConfig& config) { return false; } +bool ScheduleBlockEnableVectorize(const ScheduleConfig& config, + const std::string& block_id) { + if (!config.base_info->enable_vectorize) return false; + + // currently, don't support vectorize for SplitTransformTensor + // ScheduleBlock + if (ir::IsSplitTransformTensorName(block_id)) return false; + + if (!UseContinuousDataTile(config)) return false; + + // TODO(ZhangX): check tensor indices contains vectorize axis + + return true; +} + class TileFirstGeneralTactic final : public ScheduleTactic { public: void Init(ScheduleContext* context) override; @@ -44,6 +67,7 @@ class TileFirstGeneralTactic final : public ScheduleTactic { void Apply(ir::IRSchedule* sch, const std::string& block_id) override; void ApplyContinuousDataTile(ir::IRSchedule* sch, const std::string& block_id); + void ApplyVectorize(ir::IRSchedule* sch, const std::string& block_id); std::string TacticName() const override { return "TileFirstGeneralTactic"; } @@ -112,6 +136,12 @@ void TileFirstGeneralTactic::Apply(ir::IRSchedule* sch, const std::string& block_id) { if (ir::IsReduceInitTensorName(block_id)) return; + + // loops tiling with vectorize + if (ScheduleBlockEnableVectorize(context_->config, block_id)) { + ApplyVectorize(sch, block_id); + return; + } + AlignToReduceInput(sch, block_id); VLOG(6) << "After AlignToReduceInput on block: [" << block_id << "], loop nest:\n" << sch->GetModule().GetExprs().front(); @@ -488,6 +518,82 @@ void TileFirstGeneralTactic::BindCudaInfo(ir::IRSchedule* sch, } } + +void TileFirstGeneralTactic::ApplyVectorize(ir::IRSchedule* sch, + const std::string& block_id) { + const auto sp_thread = context_->config.tile_config.warp_num * 32 / + context_->config.tile_config.tree_reduce_num; + const auto sp_loop = context_->config.tile_config.spatial_inner_num; + const auto vectorize_factor = context_->config.tile_config.vectorize_factor; + const auto rd_thread = context_->config.tile_config.tree_reduce_num; + const auto rd_block = 
context_->config.tile_config.grid_reduce_num; + VLOG(4) << "ApplyVectorize sp_thread=" << sp_thread; + VLOG(4) << "ApplyVectorize sp_loop=" << sp_loop; + VLOG(4) << "ApplyVectorize rd_thread=" << rd_thread; + VLOG(4) << "ApplyVectorize rd_block=" << rd_block; + // Merge reduce axes + MergeReduceAxis(sch, block_id); + VLOG(4) << "After MergeReduceAxis on block: [" << block_id + << "], loop nest:\n" + << sch->GetModule().GetExprs().front(); + + // Merge spatial axes + MergeFlattenAxis(sch, block_id); + VLOG(4) << "After MergeFlattenAxis on block: [" << block_id + << "], loop nest:\n" + << sch->GetModule().GetExprs().front(); + + // Spatial situation + // deal with spatial block + if (IsSpatialRegion(context_->config)) { + const auto DoBind = [&](const std::vector& loops) { + sch->Bind(loops[0], "blockIdx.x"); + if (sp_loop > 1) { + sch->Bind(loops[2], "threadIdx.x"); + } else { + sch->Bind(loops[1], "threadIdx.x"); + } + }; + + auto loops = sch->GetLoops(block_id); + if (sp_loop > 1) { + sch->Split(loops[0], + std::vector{-1, sp_loop, sp_thread, vectorize_factor}); + } else { + sch->Split(loops[0], std::vector{-1, sp_thread, vectorize_factor}); + } + + // set vectorize schedule primitives + loops = sch->GetLoops(block_id); + auto vectorize_axis = loops.size() - 1; + sch->Vectorize(loops[vectorize_axis], vectorize_factor); + + loops = sch->GetLoops(block_id); + DoBind(loops); + return; + } + + // Reduce situation + // only deal with spatial block and don't support blockIdx.y + if (!IsReductionSBlock(sch->GetBlock(block_id))) { + auto loops = sch->GetLoops(block_id); + sch->Split(loops[1], std::vector{-1, rd_thread, vectorize_factor}); + + // set vectorize schedule primitives + loops = sch->GetLoops(block_id); + auto vectorize_axis = loops.size() - 1; + sch->Vectorize(loops[vectorize_axis], vectorize_factor); + const auto DoBind = [&](const std::vector& loops) { + sch->Bind(loops[0], "blockIdx.x"); + auto threadsIdx_x_axis = 
vectorize_axis - 1; + sch->Bind(loops[threadsIdx_x_axis], "threadIdx.x"); + }; + loops = sch->GetLoops(block_id); + DoBind(loops); + return; + } + return; +} + std::unique_ptr CreateTileFirstGeneralTactic() { return std::make_unique(); } diff --git a/paddle/cinn/ir/tensor.cc b/paddle/cinn/ir/tensor.cc index 1f64709435ce28..d81475c6f66583 100644 --- a/paddle/cinn/ir/tensor.cc +++ b/paddle/cinn/ir/tensor.cc @@ -687,6 +687,10 @@ bool IsReduceInitTensorName(const std::string &tensor_name) { reduce_init_suffix.size()) == reduce_init_suffix; } +bool IsSplitTransformTensorName(const std::string &tensor_name) { + return tensor_name.find("_split_transform") != std::string::npos; +} + std::string GetOriginalReduceTensorName(const std::string &tensor_name) { std::string reduce_init_suffix(kReduceInitSuffix); if (IsReduceInitTensorName(tensor_name)) { diff --git a/paddle/cinn/ir/tensor.h b/paddle/cinn/ir/tensor.h index d154b8c3bb57a2..18f1abb8f2e286 100644 --- a/paddle/cinn/ir/tensor.h +++ b/paddle/cinn/ir/tensor.h @@ -113,6 +113,8 @@ std::string GenReduceInitTensorNameOf(const std::string& tensor_name); bool IsReduceInitTensorName(const std::string& tensor_name); +bool IsSplitTransformTensorName(const std::string& tensor_name); + std::string GetOriginalReduceTensorName(const std::string& tensor_name); class ComputeOp; diff --git a/paddle/cinn/optim/optimize.cc b/paddle/cinn/optim/optimize.cc index f848d30515e21e..61e79787dea038 100644 --- a/paddle/cinn/optim/optimize.cc +++ b/paddle/cinn/optim/optimize.cc @@ -58,9 +58,9 @@ ir::LoweredFunc Optimize(ir::LoweredFunc fn, ReplaceConstParamToInteger(&copied->body); // Simplify already contains CastSimplify - Simplify(&copied->body); - EliminateInvariantLoop(&copied->body); - VLOG(4) << "After Optimize EliminateInvariantLoop:" << copied; + // Simplify(&copied->body); + // EliminateInvariantLoop(&copied->body); + // VLOG(4) << "After Optimize EliminateInvariantLoop:" << copied; ReplaceCrossThreadReduction(copied); VLOG(4) << "After 
Optimize ReplaceCrossThreadReduction:" << copied; ReplaceCrossBlockReduction(copied);