From 6d066772148492b3e6219344c49663c9dbd53fa2 Mon Sep 17 00:00:00 2001 From: Zhou Xin Date: Thu, 26 Dec 2024 14:49:37 +0800 Subject: [PATCH] [CINN][Add Backend Pass Comment No.6] Add comment for longlong2int (#70457) --- paddle/cinn/optim/longlong2int.h | 68 +++++++++++++++++++++++++++++++- 1 file changed, 67 insertions(+), 1 deletion(-) diff --git a/paddle/cinn/optim/longlong2int.h b/paddle/cinn/optim/longlong2int.h index b72e70df603a82..e1d598c33abe04 100644 --- a/paddle/cinn/optim/longlong2int.h +++ b/paddle/cinn/optim/longlong2int.h @@ -18,7 +18,73 @@ namespace cinn { namespace optim { -// Try to change the type of longlong to int in the expr. +/** + * Converts int64 (long long) types to int32 in a block where possible. + * + * This pass is applicable in scenarios where the IR contains int64 types that + * can be safely represented as int32 without overflow. + * + * When applied, this pass will traverse the IR and convert int64 types to int32 + * in various constructs, including: + * - Tensor shapes and indices + * - Loop variables and bounds + * - Buffer metadata (shapes, strides, offsets) + * - Comparison operations + * + * Overflow checking: + * The pass performs overflow checking primarily for nested for-loops. This + * focus on nested loops is based on the assumption that they are the most + * common source of potential overflows in typical computational kernels. The + * check considers: + * - The product of loop extents (iteration counts) + * - Whether loop bounds are constant and of index type + * + * + * Examples: + * 1. 
Loop variable conversion: + * Before conversion: + * { + * ScheduleBlock(root_12) + * { + * attrs(tile_method:TileFirstGeneralTactic) + * thread_bind[blockIdx.x] for (blockIdx.x, 0, 352) + * { + * thread_bind[threadIdx.x] for (threadIdx.x, 0, 256) + * { + * ScheduleBlock(var_2) + * { + * i0, i1, i2, i3 = axis.bind(idx / 4096, (idx % 4096) / 256, (idx % + * 256) / 16, idx % 16) read_buffers(_var[i0(0:22ll), i2(0:16ll)]) + * write_buffers(_var_2[i0(0:22ll), i1(0:16ll), i2(0:16ll), + * i3(0:16ll)]) + * var_2[i0, i1, i2, i3] = var[i0, i2, i3 + i1 * 16ll] + * } + * } + * } + * } + * } + * + * After conversion: + * { + * ScheduleBlock(root_12) + * { + * attrs(tile_method:TileFirstGeneralTactic) + * thread_bind[blockIdx.x] for (blockIdx.x, 0, 352) + * { + * thread_bind[threadIdx.x] for (threadIdx.x, 0, 256) + * { + * ScheduleBlock(var_2) + * { + * i0, i1, i2, i3 = axis.bind(idx / 4096, (idx % 4096) / 256, (idx % + * 256) / 16, idx % 16) read_buffers(_var[i0(0:22), i2(0:16)]) + * write_buffers(_var_2[i0(0:22), i1(0:16), i2(0:16), i3(0:16)]) + * var_2[i0, i1, i2, i3] = var[i0, i2, i3 + i1 * 16] + * } + * } + * } + * } + * } + */ void TryCastLonglong2Int(Expr* expr); } // namespace optim } // namespace cinn