Skip to content

Commit

Permalink
[CINN] apply vectorize Primitive in IRSchedule
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhangX-21 committed Nov 27, 2024
1 parent b72c1ce commit bdffaac
Show file tree
Hide file tree
Showing 7 changed files with 215 additions and 4 deletions.
95 changes: 94 additions & 1 deletion paddle/cinn/ir/group_schedule/config/group_tile_config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ using TileConfigMap =
std::unordered_map<BucketInfo, TileConfig, BucketInfoHash>;

namespace {

const int kWarpSize = 32;
const int kMaxNumel = INT32_MAX;

int64_t CeilPow2(int64_t n) {
Expand Down Expand Up @@ -181,6 +181,83 @@ std::shared_ptr<ScheduleConfig::BaseInfo> InitBasicInfo(
return base_info;
}

TileConfigMap BuildVectorizeConfig(
    const std::shared_ptr<ScheduleConfig::BaseInfo>& base_info,
    const common::Target& target) {
  // Builds a single-bucket tile config that enables vectorized load/store.
  // Currently only iteration spaces of the form [S, R] and [S] are supported.
  const int iters_dim = base_info->iter_space_type.size();
  if (iters_dim > 2) return {};
  const auto& last_dim = base_info->iter_space_type.back().first;
  const int sm_count = target.get_multi_processor_count();
  int64_t spatial_numel = base_info->spatial_numel;
  int64_t reduce_numel = base_info->reduce_numel;
  ReduceMethod reduce_method = NoneReduceMethod();
  int vectorize_factor = 1;
  // Try the wider vector width first, then fall back to the narrower one.
  const std::vector<int> vectorize_factors{4, 2};

  // Vectorization is only legal when the innermost extent divides evenly
  // into (threads * factor) elements per iteration.
  // NOTE: use int64_t params — the numels and thread counts are int64_t, and
  // narrowing them to int could truncate values larger than INT32_MAX before
  // the divisibility check.
  auto const CheckVectorize = [&](int64_t nums, int64_t threads, int factor) {
    const int64_t deal_elements_in_warp = threads * factor;
    if (nums % deal_elements_in_warp == 0) {
      vectorize_factor = factor;
      base_info->enable_vectorize = true;
      return true;
    }
    return false;
  };

  int64_t sp_thread_num = 1;
  int64_t rd_thread_num = 1;
  int64_t warp_nums = 1;
  // Reduce Region: vectorize along the innermost (reduce) dimension.
  if (last_dim == "R") {
    for (auto factor : vectorize_factors) {
      vectorize_factor = factor;
      const int elements_in_warp = kWarpSize * vectorize_factor;
      warp_nums = CeilDiv(reduce_numel, elements_in_warp);
      warp_nums = Trim(warp_nums, 1, 8);
      if (warp_nums > 1 || spatial_numel < warp_nums * 64) {
        rd_thread_num = warp_nums * kWarpSize;
        if (CheckVectorize(reduce_numel, rd_thread_num, vectorize_factor)) {
          break;
        }
        // NOTE(review): reduce_method is switched even when this factor is
        // rejected, so it also applies if a later factor succeeds — confirm
        // this is intended.
        reduce_method = BlockReduceMethod();
      }
    }
  } else if (iters_dim == 1 && last_dim == "S") {  // Spatial Region
    for (auto factor : vectorize_factors) {
      vectorize_factor = factor;
      const int elements_in_warp = kWarpSize * vectorize_factor;
      warp_nums = CeilDiv(spatial_numel, elements_in_warp);
      warp_nums = Trim(warp_nums, 1, 16);
      sp_thread_num = kWarpSize * warp_nums;
      if (CheckVectorize(spatial_numel, sp_thread_num, vectorize_factor)) {
        break;
      }
    }
  }

  if (!base_info->enable_vectorize) return {};

  // Spatial-inner factor; only meaningful for the pure-spatial case
  // (rd_thread_num == 1). Targets roughly 16 blocks per SM, clamped to [1, 4].
  int64_t sp_inner_num = [&]() -> int64_t {
    if (rd_thread_num > 1) return 1;
    spatial_numel = spatial_numel / sp_thread_num / vectorize_factor;
    int64_t expected = spatial_numel / (sm_count * 16);
    return Trim(expected, 1, 4);
  }();

  // One bucket covering the full range of each dimension that is non-trivial.
  int64_t sp_upper_bound = base_info->spatial_numel > 1 ? kMaxNumel : 1;
  int64_t rd_upper_bound = base_info->reduce_numel > 1 ? kMaxNumel : 1;
  BucketInfo bucket_info{1, sp_upper_bound, 1, rd_upper_bound};
  warp_nums = Trim(sp_thread_num * rd_thread_num / kWarpSize, 1, 32);
  TileConfig tile_config{warp_nums,
                         /* tree_reduce_num = */ rd_thread_num,
                         /* grid_reduce_num = */ 1,
                         /* spatial_inner_num = */ sp_inner_num,
                         /* vectorize_factor = */ vectorize_factor,
                         reduce_method};
  return {{bucket_info, tile_config}};
}

TileConfigMap BuildPureStaticShapeConfig(
const std::shared_ptr<ScheduleConfig::BaseInfo>& base_info,
const common::Target& target) {
Expand All @@ -190,6 +267,10 @@ TileConfigMap BuildPureStaticShapeConfig(
int64_t reduce_numel = base_info->reduce_numel;
ReduceMethod reduce_method = NoneReduceMethod();

// Try to use vectorization first
auto config_map = BuildVectorizeConfig(base_info, target);
if (!config_map.empty()) return std::move(config_map);

// 1. Allocate spatial/reduce threads
// Principals:
// 1) The low 32 threads are assigned to the last dimension to ensure
Expand Down Expand Up @@ -259,6 +340,7 @@ TileConfigMap BuildPureStaticShapeConfig(
/* tree_reduce_num = */ rd_thread_num,
/* grid_reduce_num = */ rd_block_num,
/* spatial_inner_num = */ sp_inner_num,
/* vectorize_factor = */ 1,
reduce_method};
return {{bucket_info, tile_config}};
}
Expand All @@ -277,6 +359,7 @@ TileConfigMap BuildStaticSpatialConfig(
/* tree_reduce_num = */ 256,
/* grid_reduce_num = */ 1,
/* spatial_inner_num = */ 1,
/* vectorize_factor = */ 1,
BlockReduceMethod()};
return {{bucket_info, tile_config}};
} else {
Expand All @@ -290,6 +373,7 @@ TileConfigMap BuildStaticSpatialConfig(
/* tree_reduce_num = */ 32,
/* grid_reduce_num = */ 1,
/* spatial_inner_num = */ 1,
/* vectorize_factor = */ 1,
WarpReduceMethod()};

BucketInfo bucket_info_257_2048{/* sp_lower_bound = */ 1,
Expand All @@ -302,6 +386,7 @@ TileConfigMap BuildStaticSpatialConfig(
/* tree_reduce_num = */ 128,
/* grid_reduce_num = */ 1,
/* spatial_inner_num = */ 1,
/* vectorize_factor = */ 1,
BlockReduceMethod()};

BucketInfo bucket_info_2049_INF{/* sp_lower_bound = */ 1,
Expand All @@ -314,6 +399,7 @@ TileConfigMap BuildStaticSpatialConfig(
/* tree_reduce_num = */ 256,
/* grid_reduce_num = */ 1,
/* spatial_inner_num = */ 1,
/* vectorize_factor = */ 1,
BlockReduceMethod()};

return {{bucket_info_1_256, tile_config_1_256},
Expand All @@ -336,6 +422,7 @@ TileConfigMap BuildStaticReduceConfig(
/* tree_reduce_num = */ 1,
/* grid_reduce_num = */ 1,
/* spatial_inner_num = */ 1,
/* vectorize_factor = */ 1,
NoneReduceMethod()};
BucketInfo bucket_info__1024_1M{/* sp_lower_bound = */ 1024,
/* sp_upper_bound = */ 1024 * 1024 - 1,
Expand All @@ -347,6 +434,7 @@ TileConfigMap BuildStaticReduceConfig(
/* tree_reduce_num = */ 1,
/* grid_reduce_num = */ 1,
/* spatial_inner_num = */ 4,
/* vectorize_factor = */ 1,
NoneReduceMethod()};
BucketInfo bucket_info__1M_INF{/* sp_lower_bound = */ 1024 * 1024,
/* sp_upper_bound = */ kMaxNumel,
Expand All @@ -358,6 +446,7 @@ TileConfigMap BuildStaticReduceConfig(
/* tree_reduce_num = */ 1,
/* grid_reduce_num = */ 1,
/* spatial_inner_num = */ 4,
/* vectorize_factor = */ 1,
NoneReduceMethod()};
return {{bucket_info__1_1023, tile_config__1_1023},
{bucket_info__1024_1M, tile_config__1024_1M},
Expand All @@ -374,6 +463,7 @@ TileConfigMap BuildStaticReduceConfig(
/* tree_reduce_num = */ 32,
/* grid_reduce_num = */ 1,
/* spatial_inner_num = */ (256 / CeilPow2(base_info->reduce_numel)),
/* vectorize_factor = */ 1,
WarpReduceMethod()};
return {{bucket_info, tile_config}};
} else if (base_info->reduce_numel <= 2048) {
Expand All @@ -392,6 +482,7 @@ TileConfigMap BuildStaticReduceConfig(
tree_reduce_num,
/* grid_reduce_num = */ 1,
/* spatial_inner_num */ 1,
/* vectorize_factor = */ 1,
BlockReduceMethod()};
return {{bucket_info, tile_config}};
} else {
Expand All @@ -405,6 +496,7 @@ TileConfigMap BuildStaticReduceConfig(
/* tree_reduce_num = */ 1024,
/* grid_reduce_num = */ 1,
/* spatial_inner_num = */ 1,
/* vectorize_factor = */ 1,
BlockReduceMethod()};
return {{bucket_info, tile_config}};
}
Expand All @@ -423,6 +515,7 @@ TileConfigMap BuildDynamicShapeConfig(
/* tree_reduce_num = */ 1024,
/* grid_reduce_num = */ 1,
/* spatial_inner_num = */ 1,
/* vectorize_factor = */ 1,
BlockReduceMethod()};
return {{bucket_info, tile_config}};
}
Expand Down
2 changes: 2 additions & 0 deletions paddle/cinn/ir/group_schedule/config/group_tile_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ struct ScheduleConfig {
bool has_dynamic_spatial{false};
bool has_dynamic_reduce{false};
bool can_apply_grid_reduce{false};
bool enable_vectorize{false};
IterSpaceType iter_space_type;
};

Expand All @@ -48,6 +49,7 @@ struct ScheduleConfig {
int64_t tree_reduce_num{1};
int64_t grid_reduce_num{1};
int64_t spatial_inner_num{1};
int64_t vectorize_factor{1};
ReduceMethod reduce_method{NoneReduceMethod()};
};

Expand Down
4 changes: 4 additions & 0 deletions paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,13 @@ class ComputeInlineTactic final : public ScheduleTactic {
private:
std::unordered_set<std::string> output_names_;
cinn::common::Target target_;
bool enable_vectorize_{false};
};

// Caches the schedule-context fields that Apply() later reads.
void ComputeInlineTactic::Init(ScheduleContext* context) {
  output_names_ = context->output_names;
  target_ = context->target;
  // When vectorization is enabled for this group, Apply() returns without
  // inlining, since compute-inline conflicts with the vectorized schedule.
  enable_vectorize_ = context->config.base_info->enable_vectorize;
}

void ComputeInlineTactic::Apply(ir::IRSchedule* sch,
Expand All @@ -50,6 +52,8 @@ void ComputeInlineTactic::Apply(ir::IRSchedule* sch,
// if (IsProhibitScheduleExternCallBlock(node->Block())) {
// return;
// }
// compute inline tactic not work, when apply vectorize in current schedule
if (enable_vectorize_) return;
auto_schedule::AutoInline inliner(target_, output_names_);
VLOG(6) << "try ComputeInline on: " << block_id
<< ", before ComputeInline, func body: "
Expand Down
106 changes: 106 additions & 0 deletions paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ namespace ir {

using cinn::ir::analyzer::IsReductionSBlock;

// Returns true iff the iteration space is a single spatial axis, i.e. [S].
bool IsSpatialRegion(const ScheduleConfig& config) {
  const auto& iter_space = config.base_info->iter_space_type;
  return iter_space.size() == 1 && iter_space.back().first == "S";
}

bool UseContinuousDataTile(const ScheduleConfig& config) {
// use continuous data tile for [S] and [...R]
if (config.base_info->iter_space_type.size() == 1 &&
Expand All @@ -37,13 +45,29 @@ bool UseContinuousDataTile(const ScheduleConfig& config) {
return false;
}

// Decides whether the vectorized tiling path may be applied to the schedule
// block named `block_id` under the given config.
bool ScheduleBlockEnableVectorize(const ScheduleConfig& config,
                                  const std::string& block_id) {
  // Vectorization must have been chosen during tile-config construction.
  if (!config.base_info->enable_vectorize) return false;

  // Currently, don't support vectorize for SplitedTransformTensor
  // ScheduleBlock.
  if (ir::IsSplitTransformTensorName(block_id)) return false;

  // Vectorization relies on the continuous-data-tile layout ([S] or [...R]).
  if (!UseContinuousDataTile(config)) return false;

  // TODO(ZhangX): check that tensor indices contain the vectorize axis

  return true;
}

class TileFirstGeneralTactic final : public ScheduleTactic {
public:
void Init(ScheduleContext* context) override;

void Apply(ir::IRSchedule* sch, const std::string& block_id) override;
void ApplyContinuousDataTile(ir::IRSchedule* sch,
const std::string& block_id);
void ApplyVectorize(ir::IRSchedule* sch, const std::string& block_id);

std::string TacticName() const override { return "TileFirstGeneralTactic"; }

Expand Down Expand Up @@ -112,6 +136,12 @@ void TileFirstGeneralTactic::Apply(ir::IRSchedule* sch,
const std::string& block_id) {
if (ir::IsReduceInitTensorName(block_id)) return;

// loops tiling with vectorize
if (ScheduleBlockEnableVectorize(context_->config, block_id)) {
ApplyVectorize(sch, block_id);
return;
}

AlignToReduceInput(sch, block_id);
VLOG(6) << "After AlignToReduceInput on block: [" << block_id
<< "], loop nest:\n"
Expand Down Expand Up @@ -488,6 +518,82 @@ void TileFirstGeneralTactic::BindCudaInfo(ir::IRSchedule* sch,
}
}

// Tiles the loop nest for vectorized load/store and binds CUDA axes.
// Resulting layouts:
//   [S]    -> blockIdx.x, (optional serial sp_loop,) threadIdx.x, vectorized
//   [S, R] -> blockIdx.x (spatial), threadIdx.x (reduce), vectorized
void TileFirstGeneralTactic::ApplyVectorize(ir::IRSchedule* sch,
                                            const std::string& block_id) {
  const auto sp_thread = context_->config.tile_config.warp_num * 32 /
                         context_->config.tile_config.tree_reduce_num;
  const auto sp_loop = context_->config.tile_config.spatial_inner_num;
  const auto vectorize_factor = context_->config.tile_config.vectorize_factor;
  const auto rd_thread = context_->config.tile_config.tree_reduce_num;
  const auto rd_block = context_->config.tile_config.grid_reduce_num;
  // Fix: these logs previously carried the "ApplyContinuousDataTile" tag,
  // copy-pasted from the sibling method.
  VLOG(4) << "ApplyVectorize sp_thread=" << sp_thread;
  VLOG(4) << "ApplyVectorize sp_loop=" << sp_loop;
  VLOG(4) << "ApplyVectorize rd_thread=" << rd_thread;
  VLOG(4) << "ApplyVectorize rd_block=" << rd_block;
  // Merge reduce axes
  MergeReduceAxis(sch, block_id);
  VLOG(4) << "After MergeReduceAxis on block: [" << block_id
          << "], loop nest:\n"
          << sch->GetModule().GetExprs().front();

  // Merge spatial axes
  MergeFlattenAxis(sch, block_id);
  VLOG(4) << "After MergeFlattenAxis on block: [" << block_id
          << "], loop nest:\n"
          << sch->GetModule().GetExprs().front();

  // Spatial situation
  // deal with pure spatial block
  if (IsSpatialRegion(context_->config)) {
    const auto DoBind = [&](const std::vector<ir::Expr>& loops) {
      sch->Bind(loops[0], "blockIdx.x");
      if (sp_loop > 1) {
        // Loop layout: [blocks, sp_loop, threads, vector].
        sch->Bind(loops[2], "threadIdx.x");
      } else {
        // Loop layout: [blocks, threads, vector].
        sch->Bind(loops[1], "threadIdx.x");
      }
    };

    auto loops = sch->GetLoops(block_id);
    if (sp_loop > 1) {
      sch->Split(loops[0],
                 std::vector<int>{-1, sp_loop, sp_thread, vectorize_factor});
    } else {
      sch->Split(loops[0], std::vector<int>{-1, sp_thread, vectorize_factor});
    }

    // set vectorize schedule primitive on the innermost axis
    loops = sch->GetLoops(block_id);
    auto vectorize_axis = loops.size() - 1;
    sch->Vectorize(loops[vectorize_axis], vectorize_factor);

    loops = sch->GetLoops(block_id);
    DoBind(loops);
    return;
  }

  // Reduce situation
  // only deal with non-reduction schedule blocks here; grid reduce
  // (blockIdx.y) is not supported yet
  if (!IsReductionSBlock(sch->GetBlock(block_id))) {
    auto loops = sch->GetLoops(block_id);
    sch->Split(loops[1], std::vector<int>{-1, rd_thread, vectorize_factor});

    // set vectorize schedule primitive on the innermost axis
    loops = sch->GetLoops(block_id);
    auto vectorize_axis = loops.size() - 1;
    sch->Vectorize(loops[vectorize_axis], vectorize_factor);
    const auto DoBind = [&](const std::vector<ir::Expr>& loops) {
      sch->Bind(loops[0], "blockIdx.x");
      auto threadsIdx_x_axis = vectorize_axis - 1;
      sch->Bind(loops[threadsIdx_x_axis], "threadIdx.x");
    };
    loops = sch->GetLoops(block_id);
    DoBind(loops);
  }
}

std::unique_ptr<ScheduleTactic> CreateTileFirstGeneralTactic() {
return std::make_unique<TileFirstGeneralTactic>();
}
Expand Down
4 changes: 4 additions & 0 deletions paddle/cinn/ir/tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,10 @@ bool IsReduceInitTensorName(const std::string &tensor_name) {
reduce_init_suffix.size()) == reduce_init_suffix;
}

// Returns true when the tensor name contains the "_split_transform" marker,
// which identifies tensors produced by the split-transform pass.
bool IsSplitTransformTensorName(const std::string &tensor_name) {
  static const std::string kSplitTransformTag = "_split_transform";
  return tensor_name.find(kSplitTransformTag) != std::string::npos;
}

std::string GetOriginalReduceTensorName(const std::string &tensor_name) {
std::string reduce_init_suffix(kReduceInitSuffix);
if (IsReduceInitTensorName(tensor_name)) {
Expand Down
2 changes: 2 additions & 0 deletions paddle/cinn/ir/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ std::string GenReduceInitTensorNameOf(const std::string& tensor_name);

bool IsReduceInitTensorName(const std::string& tensor_name);

bool IsSplitTransformTensorName(const std::string& tensor_name);

std::string GetOriginalReduceTensorName(const std::string& tensor_name);

class ComputeOp;
Expand Down
Loading

0 comments on commit bdffaac

Please sign in to comment.