diff --git a/paddle/cinn/ir/group_schedule/config/group_tile_config.cc b/paddle/cinn/ir/group_schedule/config/group_tile_config.cc
index fd32b0a4a1c8a7..26c61c77845d5d 100644
--- a/paddle/cinn/ir/group_schedule/config/group_tile_config.cc
+++ b/paddle/cinn/ir/group_schedule/config/group_tile_config.cc
@@ -211,9 +211,16 @@ TileConfigMap BuildVectorizeConfig(
   };
 
   bool is_sm_fully_utilized = true;
-  auto CheckSmUtilization = [sm_count, max_threads_per_sm, max_blocks_per_sm](
-                                int total_threads, int block_size) -> bool {
-    int blocks_needed = CeilDiv(total_threads, block_size);
+  // Only proceed with vectorization if SM utilization exceeds 100%
+  auto CheckSmUtilization =
+      [&](int input_size, int block_size, std::string last_dim) -> bool {
+    if (last_dim != "S" && last_dim != "R") {
+      VLOG(5) << "Invalid last_dim in SmUtilization Check: " << last_dim;
+      return false;
+    }
+
+    int blocks_needed =
+        (last_dim == "S") ? CeilDiv(input_size, block_size) : input_size;
     int sms_needed = CeilDiv(blocks_needed, max_blocks_per_sm);
     float sm_utilization = static_cast<float>(sms_needed) / sm_count;
@@ -239,6 +246,8 @@ TileConfigMap BuildVectorizeConfig(
     if (warp_nums > 1 || spatial_numel < warp_nums * 64) {
       rd_thread_num = warp_nums * kWarpSize;
       if (CheckVectorize(reduce_numel, rd_thread_num, vectorize_factor)) {
+        is_sm_fully_utilized =
+            CheckSmUtilization(spatial_numel, rd_thread_num, "R");
         break;
       }
       reduce_method = BlockReduceMethod();
@@ -257,7 +266,8 @@ TileConfigMap BuildVectorizeConfig(
       // warp_nums = Trim(warp_nums, 1, 32);
       sp_thread_num = kWarpSize * warp_nums;
       if (CheckVectorize(spatial_numel, sp_thread_num, vectorize_factor)) {
-        is_sm_fully_utilized = CheckSmUtilization(spatial_numel, sp_thread_num);
+        is_sm_fully_utilized = CheckSmUtilization(
+            spatial_numel / vectorize_factor, sp_thread_num, "S");
         break;
       }
     }
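
Note on the reworked lambda: the caller tags the last dimension, and the block count is derived accordingly. For the spatial path ("S") the call site passes spatial_numel / vectorize_factor, so blocks_needed = CeilDiv(input_size, block_size); for the reduce path ("R") the call site passes spatial_numel directly as the block count. The snippet below is a minimal standalone sketch of that logic, not the Paddle implementation: the hardware limits, the CeilDiv helper, and the final ">= 1.0f" comparison are assumptions, since the hunk's trailing context does not show the lambda's actual return expression.

// Standalone sketch of the SM-utilization check performed by the patched
// lambda. Hardware numbers and the return condition are assumptions.
#include <iostream>
#include <string>

int CeilDiv(int a, int b) { return (a + b - 1) / b; }

bool CheckSmUtilizationSketch(int input_size,
                              int block_size,
                              const std::string& last_dim,
                              int max_blocks_per_sm = 32,  // assumed per-SM block limit
                              int sm_count = 108) {        // assumed SM count
  if (last_dim != "S" && last_dim != "R") return false;
  // "S": split input_size elements across blocks of block_size threads.
  // "R": input_size already equals the number of blocks to launch.
  int blocks_needed =
      (last_dim == "S") ? CeilDiv(input_size, block_size) : input_size;
  int sms_needed = CeilDiv(blocks_needed, max_blocks_per_sm);
  float sm_utilization = static_cast<float>(sms_needed) / sm_count;
  return sm_utilization >= 1.0f;  // vectorize only when every SM stays busy
}

int main() {
  // Spatial case: (1 << 20) / vectorize_factor elements, 256 threads/block.
  std::cout << CheckSmUtilizationSketch((1 << 20) / 4, 256, "S") << "\n";
  // Reduce case: spatial_numel blocks, one per spatial slice.
  std::cout << CheckSmUtilizationSketch(4096, 256, "R") << "\n";
}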