Skip to content

Commit

Permalink
Fixed optimal_remainder beggining value
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Jul 28, 2023
1 parent 2c346d4 commit 6ad3a4a
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions src/common/snippets/src/pass/common_optimizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,19 +137,25 @@ void CommonOptimizations::SplitDimensionM(const std::shared_ptr<ov::snippets::op
// Heuristic value for a quick exit from the algorithm.
// The value shows the number of threads in percentages that perform the most equal work
const auto optimal_thread_num_percent = 0.8;
size_t opt_remainder = batch_m_dim % optimal_parallelism_work_amount;
size_t optimal_remainder = 1;
auto get_remainder = [batch_dim, optimal_parallelism_work_amount](const size_t potential_batch_dim) {
return (batch_dim * potential_batch_dim) % optimal_parallelism_work_amount;
};

auto update_optimal_params = [&](size_t divisor_0, size_t divisor_1) {
if (divisor_1 < optimal_m_dim)
return;
const auto remainder = batch_dim * divisor_0 % optimal_parallelism_work_amount;
if (remainder > opt_remainder || remainder == 0) {
opt_remainder = remainder;
if (remainder > optimal_remainder || remainder == 0) {
optimal_remainder = remainder;
batch_m_dim = divisor_0;
new_m_dim = divisor_1;
}
};

// Firstly we have shape [batch, 1, m_dim, smth].
// So at the beginning we have parallel_work_amount = batch x 1
optimal_remainder = get_remainder(1);
const auto root = std::sqrt(m_dim) + 1;
for (size_t divisor_0 = 2; divisor_0 < root; ++divisor_0) {
const size_t divisor_1 = m_dim / divisor_0;
Expand All @@ -158,8 +164,8 @@ void CommonOptimizations::SplitDimensionM(const std::shared_ptr<ov::snippets::op

update_optimal_params(divisor_0, divisor_1);
update_optimal_params(divisor_1, divisor_0);
if ((static_cast<float>(opt_remainder) / static_cast<float>(optimal_parallelism_work_amount) > optimal_thread_num_percent) ||
(opt_remainder == 0)) {
if ((static_cast<float>(optimal_remainder) / static_cast<float>(optimal_parallelism_work_amount) > optimal_thread_num_percent) ||
(optimal_remainder == 0)) {
break;
}
}
Expand Down

0 comments on commit 6ad3a4a

Please sign in to comment.