Skip to content

Commit

Permalink
[CINN] support ForOp with multiple schedule blocks in vectorize!
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhangX-21 committed Dec 17, 2024
1 parent 2dad2d3 commit b9a0da5
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 38 deletions.
71 changes: 36 additions & 35 deletions paddle/cinn/hlir/framework/pir/trivial_op_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -829,45 +829,46 @@ bool GetCanApplyVectorize(const std::vector<ir::Expr>& op_compute_bodies) {
},
/* uniq_target = */ false);

auto CheckTensorIsBroadcastAndContinuous =
[&](const std::vector<Expr>& indices) {
int loop_idx = 0;
bool is_broadcast = false;
for (int i = 0; i < indices.size(); ++i) {
const ir::Expr& index = indices[i];
if (index.is_constant()) {
is_broadcast = true;
continue;
}

if (!index.is_var()) return false;
ir::Var iter_var = index.as_var_ref();
if (!iter_var2value.count(iter_var)) {
return false;
}
ir::Expr iter_value = iter_var2value.at(iter_var);
PADDLE_ENFORCE_EQ(
iter_value.as_var() || iter_value.is_constant(),
true,
::common::errors::PreconditionNotMet(
"Required iter_value shall be var or constant type."));
for (; loop_idx < for_iters.size(); ++loop_idx) {
if (for_iters[loop_idx] == iter_value.as_var_ref()) {
break;
}
}

if (loop_idx == for_iters.size()) {
return false;
}
auto CheckTensorIsBroadcastAndContinuous = [&](std::vector<Expr>& indices) {
int loop_idx = 0;
bool is_broadcast = false;
for (int i = 0; i < indices.size(); ++i) {
ir::Expr& index = indices[i];
cinn::optim::Simplify(&index);
if (index.is_constant()) {
is_broadcast = true;
continue;
}

if (!index.is_var()) return false;
ir::Var iter_var = index.as_var_ref();
if (!iter_var2value.count(iter_var)) {
return false;
}
ir::Expr iter_value = iter_var2value.at(iter_var);
PADDLE_ENFORCE_EQ(
iter_value.as_var() || iter_value.is_constant(),
true,
::common::errors::PreconditionNotMet(
"Required iter_value shall be var or constant type."));
for (; loop_idx < for_iters.size(); ++loop_idx) {
if (for_iters[loop_idx] == iter_value.as_var_ref()) {
break;
}
if (is_broadcast) return true;
}

if (loop_idx == for_iters.size()) {
return false;
};
}
}
if (is_broadcast || indices.size() < for_iters.size()) return true;
return false;
};

auto CheckoutTensorIsContinuous = [&](const std::vector<Expr>& indices) {
auto CheckoutTensorIsContinuous = [&](std::vector<Expr>& indices) {
for (int i = 0; i < indices.size(); ++i) {
const ir::Expr& index = indices[i];
ir::Expr& index = indices[i];
cinn::optim::Simplify(&index);
if (index.is_constant()) return false;
if (!index.is_var()) return false;
ir::Var iter_var = index.as_var_ref();
Expand Down
8 changes: 5 additions & 3 deletions paddle/cinn/ir/group_schedule/config/group_tile_config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -230,22 +230,24 @@ TileConfigMap BuildVectorizeConfig(
vectorize_factor = factor;
const int elements_in_warp = kWarpSize * vectorize_factor;
warp_nums = CeilDiv(spatial_numel, elements_in_warp);
warp_nums = Trim(warp_nums, 1, 16);
warp_nums = Trim(warp_nums, 1, 32);
sp_thread_num = kWarpSize * warp_nums;
if (CheckVectorize(spatial_numel, sp_thread_num, vectorize_factor)) {
break;
}
}
}

if (!can_vectorize) {
base_info->can_apply_vectorize = false;
return {};
}

int64_t sp_inner_num = [&]() -> int64_t {
if (rd_thread_num > 1) return 1;
spatial_numel = spatial_numel / sp_thread_num / vectorize_factor;
int64_t expected = spatial_numel / (sm_count * 16);
return Trim(expected, 1, 4);
int64_t expected = spatial_numel / (sm_count * 64);
return Trim(expected, 1, 1);
}();

int64_t sp_upper_bound = base_info->spatial_numel > 1 ? kMaxNumel : 1;
Expand Down
57 changes: 57 additions & 0 deletions paddle/cinn/optim/vectorize_for_trans.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,59 @@ std::unordered_map<std::string, ir::Var> CollectExprSymbols(Expr *x) {
return std::move(mutator.GetSymbols());
}

// Rewrites a vectorized ir::For whose body realizes more than one schedule
// block: each collected ScheduleBlockRealize is re-wrapped in its own For
// (with a fresh loop variable substituted for the original one), and the
// vectorized For is replaced by a Block of those single-schedule-block Fors.
// This lets the later vectorize-for transformation assume one schedule block
// per vectorized loop.
class ForOpWithMultiScheduleBlockSupportVectorize
    : public ir::IRMutator<ir::Expr *> {
 public:
  // Entry point: mutates *expr in place.
  void operator()(ir::Expr *expr) { ir::IRMutator<>::Visit(expr, expr); }

 private:
  // Records a pointer to every ScheduleBlockRealize reached while we are
  // inside a vectorized For.
  // NOTE(review): for_op_blocks_ keeps raw Expr* into the IR tree; this
  // assumes no intervening mutation invalidates them before the enclosing
  // For's Visit rewrites the loop — TODO confirm.
  void Visit(const ir::ScheduleBlockRealize *op, Expr *expr) override {
    auto *node = expr->As<ir::ScheduleBlockRealize>();
    PADDLE_ENFORCE_NOT_NULL(
        node,
        ::common::errors::InvalidArgument("The input expr should be a Block"));
    // Recurse first so nested realizes (if any) are collected before this one.
    IRMutator<>::Visit(op, expr);
    if (in_vectorize_scope) {
      for_op_blocks_.push_back(expr);
    }
  }

  void Visit(const ir::For *op, ir::Expr *expr) override {
    auto *forloop = expr->As<ir::For>();
    // Entering a vectorized loop turns collection on for the subtree visit.
    if (forloop->is_vectorized()) in_vectorize_scope = true;

    IRMutator<>::Visit(op, expr);

    // Only split when this vectorized loop gathered multiple schedule blocks;
    // a single block needs no rewriting.
    if (for_op_blocks_.size() > 1 && in_vectorize_scope) {
      std::vector<Expr> stmts;
      for (auto block : for_op_blocks_) {
        // Fresh iterator per cloned loop, derived from the original loop var
        // name ("_s" suffix) to stay unique and recognizable.
        Var new_iterator(
            cinn::common::UniqName(forloop->loop_var->name + "_s"));

        // Rebind every use of the original loop var inside this schedule
        // block to the new iterator.
        cinn::ir::ir_utils::IrReplaceVarBroadcast(
            block, forloop->loop_var, Expr(new_iterator));

        // Clone the loop shell (bounds, for-type, device, vectorize/bind
        // info) around the single schedule block.
        ir::Expr f_expr = ir::For::Make(new_iterator,
                                        forloop->min,
                                        forloop->extent,
                                        forloop->for_type(),
                                        forloop->device_api,
                                        ir::Block::Make({*block}),
                                        forloop->vectorize_info(),
                                        forloop->bind_info());
        stmts.push_back(f_expr);
      }
      // Replace the original multi-block For with the sequence of
      // per-schedule-block Fors.
      Expr block_expr = ir::Block::Make(stmts);
      *expr = block_expr;
    }
    // State is reset unconditionally at the end of EVERY For visit.
    // NOTE(review): a non-vectorized For nested inside a vectorized one would
    // clear the outer loop's collection here — assumes vectorized loops are
    // innermost; TODO confirm.
    in_vectorize_scope = false;
    for_op_blocks_.clear();
  }

  // True while the visitor is under a For marked vectorized.
  bool in_vectorize_scope{false};
  // ScheduleBlockRealize exprs collected under the current vectorized For.
  std::vector<ir::Expr *> for_op_blocks_;
};

class ScheduleBlockTensorVectorizeTeller : public ir::IRMutator<Expr *> {
public:
ScheduleBlockTensorVectorizeTeller(Var iter_var, const int factor)
Expand Down Expand Up @@ -539,6 +592,10 @@ class VectorizeForTransMutator : public ir::IRMutator<ir::Expr *> {
} // namespace

void VectorizeForTrans(Expr *expr) {
ForOpWithMultiScheduleBlockSupportVectorize update;
VLOG(5) << "before multi schedule block deal with vectorize " << *expr;
update(expr);
VLOG(5) << "after multi schedule block deal with vectorize " << *expr;
VectorizeForTransMutator collector;
VLOG(5) << "before vectorize for trans " << *expr;
collector(expr);
Expand Down

0 comments on commit b9a0da5

Please sign in to comment.