From e26973d14e49a020ee7ed27b072207d64bbd73b8 Mon Sep 17 00:00:00 2001 From: Stanley Winata Date: Sun, 20 Oct 2024 22:21:32 -0700 Subject: [PATCH 1/3] [TKW] Add support for multiple/local reduceOp Signed-off-by: Stanley Winata --- iree/turbine/kernel/ops/wave_ops.py | 28 ++++++++- .../kernel/wave/decompose_reduce_ops.py | 37 ++++++++--- iree/turbine/kernel/wave/utils.py | 6 ++ lit_tests/kernel/wave/codegen.py | 62 +++++++++++++++++++ 4 files changed, 121 insertions(+), 12 deletions(-) diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py index 19e3f64c..d0c96637 100644 --- a/iree/turbine/kernel/ops/wave_ops.py +++ b/iree/turbine/kernel/ops/wave_ops.py @@ -1236,19 +1236,41 @@ class ReduceOp(CustomOp, ABC): dim: which symbolic dim to reduce. """ - arg: fx.Node + arg: fx.Node | list[fx.Node] init: fx.Node = None dim: Optional[Any] = None @property def indexing_dims(self) -> list[IndexSymbol]: - src_indexing = get_custom(self.arg).indexing_dims + # Local import to break circular dep. + from ..wave.utils import all_equal + + if isinstance(self.arg, Sequence): + src_indexings = [get_custom(arg).indexing_dims for arg in self.arg] + if not all_equal(src_indexings): + raise NotImplementedError( + "NYI: Only support case where all inputs to ReduceOp to have same indexing dim." + ) + src_indexing = src_indexings[0] + else: + src_indexing = get_custom(self.arg).indexing_dims dst_indexing = [dim for dim in src_indexing if dim != self.dim] return dst_indexing @property def type(self) -> Memory: - src_type = get_custom(self.arg).type + if isinstance(self.arg, Sequence): + # Local import to break circular dep. + from ..wave.utils import all_equal + + src_types = [get_custom(arg).type for arg in self.arg] + if not all_equal(src_types): + raise NotImplementedError( + "NYI: Only support case where all inputs to ReduceOp to have same type." + ) + src_type = src_types[0] + else: + src_type = get_custom(self.arg).type reduced_dims = [dims for dims in src_type.symbolic_shape if dims != self.dim] dst_type = Register[*reduced_dims, src_type.dtype] return dst_type diff --git a/iree/turbine/kernel/wave/decompose_reduce_ops.py b/iree/turbine/kernel/wave/decompose_reduce_ops.py index 0be318ef..5c54d935 100644 --- a/iree/turbine/kernel/wave/decompose_reduce_ops.py +++ b/iree/turbine/kernel/wave/decompose_reduce_ops.py @@ -23,7 +23,7 @@ Reduction, ) -from .utils import DCE, subs_idxc +from .utils import DCE, subs_idxc, all_equal import torch.fx as fx import math from typing import Callable @@ -38,12 +38,16 @@ def get_graph_node(custom: CustomOp, graph: fx.Graph): def emit_local_reduction( - binary_fn: Callable, src: fx.Node, graph: fx.Graph, local_reduction_size: int + binary_fn: Callable, src: list[fx.Node], graph: fx.Graph, local_reduction_size: int ) -> fx.Node: - init = get_graph_node(Extract(src, [0]), graph) - for i in range(1, local_reduction_size): - cur_slice = get_graph_node(Extract(src, [i]), graph) - init = get_graph_node(binary_fn(init, cur_slice), graph) + init = None + for i in range(len(src)): + for j in range(local_reduction_size): + if init is None: + init = get_graph_node(Extract(src[i], [j]), graph) + continue + cur_slice = get_graph_node(Extract(src[i], [j]), graph) + init = get_graph_node(binary_fn(init, cur_slice), graph) return init @@ -98,9 +102,18 @@ def decompose_reduce_ops( raise ValueError( "No reduction dim specified, please specify a reduction dim." ) + if not isinstance(reduction_src, (list, tuple)): + reduction_src = [reduction_src] # Local Reduce - if reduction_dim is not get_custom(custom.arg).type.symbolic_shape[-1]: + src_fastest_dims = [ + get_custom(arg).type.symbolic_shape[-1] for arg in reduction_src + ] + if not all_equal(src_fastest_dims): + raise NotImplementedError( + "NYI: Expect all reduce_src to have same fastest dim." + ) + if reduction_dim is not src_fastest_dims[0]: raise NotImplementedError( "Only implemented reduction on fastest dimension." ) @@ -108,9 +121,15 @@ def decompose_reduce_ops( get_thread_shape = lambda index: max( subs_idxc(x.size) for x in index.values() ) - local_reduction_size = get_thread_shape(get_custom(custom.arg).index) + local_reduce_sizes = [ + get_thread_shape(get_custom(arg).index) for arg in reduction_src + ] + if not all_equal(local_reduce_sizes): + raise NotImplementedError( + "NYI: Expect all reduce_src to have same local reduce size." + ) local_reduction = emit_local_reduction( - binary_fn, reduction_src, custom.graph, local_reduction_size + binary_fn, reduction_src, custom.graph, local_reduce_sizes[0] ) # Global Reduce diff --git a/iree/turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py index f3c9201d..674b8d5b 100644 --- a/iree/turbine/kernel/wave/utils.py +++ b/iree/turbine/kernel/wave/utils.py @@ -739,3 +739,9 @@ def get_mfma_store_elems_per_thread(mfma_variant: MMAType) -> int: return 4 case MMAType.F32_32x32x16_F8: return 16 + + +def all_equal(input_list: list[Any]) -> bool: + if len(input_list) == 0: + return True + return all(elem == input_list[0] for elem in input_list) diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index 5778b689..d048f4cd 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -1211,6 +1211,68 @@ def test( # CHECK: arith.addf {{.*}} : vector<1xf16> +# Tests for multiple local reduction, and we to emit and iteratively slice and reduce over multiple variables correctly. +@run_test +def test_mutliple_local_reduce_sum(): + M = tkl.sym.M + N = tkl.sym.N + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + ELEMS_PER_THREAD = tkl.sym.ELEMS_PER_THREAD + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + + constraints: list[tkw.Constraint] = [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(1, 1, 1), + vector_shapes={M: 1, N: BLOCK_N}, + ) + ] + constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 0)] + constraints += [tkw.WaveConstraint(M, BLOCK_M)] + constraints += [tkw.WaveConstraint(N, BLOCK_N)] + + @tkw.wave(constraints) + def test( + a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], + b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[M, ADDRESS_SPACE, tkl.f16], + ): + lhs = tkw.read(a, elements_per_thread=ELEMS_PER_THREAD) + rhs = tkw.read(b, elements_per_thread=ELEMS_PER_THREAD) + res = tkw.sum([lhs, rhs], dim=N) + tkw.write(res, c, elements_per_thread=1) + + config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + + shape = (256, 128) + a = torch.randn(shape, dtype=torch.float16) + b = torch.randn(shape, dtype=torch.float16) + c = torch.zeros((shape[0],), dtype=torch.float16) + with tk.gen.TestLaunchContext( + { + M: shape[0], + N: shape[1], + BLOCK_M: 1, + BLOCK_N: 128, + ELEMS_PER_THREAD: 2, + ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value, + }, + canonicalize=True, + ): + print(test(a, b, c).module_op) + # CHECK: %[[LHS:.+]] = vector.load {{.*}} : memref<256x128xf16 + # CHECK: %[[RHS:.+]] = vector.load {{.*}} : memref<256x128xf16 + # CHECK: %[[LHS0:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0], sizes = [1], strides = [1]} + # CHECK: %[[LHS1:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [1], sizes = [1], strides = [1]} + # CHECK: %[[REDUC_0:.+]] = arith.addf %[[LHS0]], %[[LHS1]] : vector<1xf16> + # CHECK: %[[RHS0:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0], sizes = [1], strides = [1]} + # CHECK: %[[REDUC_1:.+]] = arith.addf %[[REDUC_0]], %[[RHS0]] : vector<1xf16> + # CHECK: %[[RHS1:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [1], sizes = [1], strides = [1]} + # CHECK: %[[REDUC_2:.+]] = arith.addf %[[REDUC_1]], %[[RHS1]] : vector<1xf16> + + # This test is to ensure that the propagation of indexing_dims between reduction and operations # outside the reduction is working properly. @run_test From 8ce9b098e33bf5dda0ed02743171635b8e8d8350 Mon Sep 17 00:00:00 2001 From: Stanley Winata Date: Wed, 23 Oct 2024 14:24:52 -0700 Subject: [PATCH 2/3] Add local sources reduction and update lit Signed-off-by: Stanley Winata --- .../kernel/wave/decompose_reduce_ops.py | 36 ++++++++++++------- lit_tests/kernel/wave/codegen.py | 15 ++++---- 2 files changed, 31 insertions(+), 20 deletions(-) diff --git a/iree/turbine/kernel/wave/decompose_reduce_ops.py b/iree/turbine/kernel/wave/decompose_reduce_ops.py index 5c54d935..6fecfeed 100644 --- a/iree/turbine/kernel/wave/decompose_reduce_ops.py +++ b/iree/turbine/kernel/wave/decompose_reduce_ops.py @@ -37,17 +37,23 @@ def get_graph_node(custom: CustomOp, graph: fx.Graph): return custom +def emit_sources_reduction( + binary_fn: Callable, src: list[fx.Node], graph: fx.Graph +) -> fx.Node: + init = src[0] + for i in range(1, len(src)): + init = get_graph_node(binary_fn(init, src[i]), graph) + init.index = src[0].index + return init + + def emit_local_reduction( - binary_fn: Callable, src: list[fx.Node], graph: fx.Graph, local_reduction_size: int + binary_fn: Callable, src: fx.Node, graph: fx.Graph, local_reduction_size: int ) -> fx.Node: - init = None - for i in range(len(src)): - for j in range(local_reduction_size): - if init is None: - init = get_graph_node(Extract(src[i], [j]), graph) - continue - cur_slice = get_graph_node(Extract(src[i], [j]), graph) - init = get_graph_node(binary_fn(init, cur_slice), graph) + init = get_graph_node(Extract(src, [0]), graph) + for i in range(1, local_reduction_size): + cur_slice = get_graph_node(Extract(src, [i]), graph) + init = get_graph_node(binary_fn(init, cur_slice), graph) return init @@ -71,11 +77,12 @@ def decompose_reduce_ops( ): """ The lowering for multi_reduction is done in two steps: - 1. Local Reduce: Each thread reduces all elements carried by it along + 1. Source Reduce: Each thread reduce locally all it's sources. + 2. Local Reduce: Each thread reduces all elements carried by it along the reduction dimensions. - 2. Thread Reduce: Each thread reduces result of step 1 across threads + 3. Thread Reduce: Each thread reduces result of step 1 across threads by doing a butterfly shuffle. - 3. Accumulator Reduce: Each thread reduces it's intermediate reduced + 4. Accumulator Reduce: Each thread reduces it's intermediate reduced results with the accumulator it holds. """ # Get reducte nodes. @@ -128,8 +135,11 @@ def decompose_reduce_ops( raise NotImplementedError( "NYI: Expect all reduce_src to have same local reduce size." ) + src_reduction = emit_sources_reduction( + binary_fn, reduction_src, custom.graph + ) local_reduction = emit_local_reduction( - binary_fn, reduction_src, custom.graph, local_reduce_sizes[0] + binary_fn, src_reduction, custom.graph, local_reduce_sizes[0] ) # Global Reduce diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index d048f4cd..61e08298 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -1264,13 +1264,14 @@ def test( print(test(a, b, c).module_op) # CHECK: %[[LHS:.+]] = vector.load {{.*}} : memref<256x128xf16 # CHECK: %[[RHS:.+]] = vector.load {{.*}} : memref<256x128xf16 - # CHECK: %[[LHS0:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0], sizes = [1], strides = [1]} - # CHECK: %[[LHS1:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [1], sizes = [1], strides = [1]} - # CHECK: %[[REDUC_0:.+]] = arith.addf %[[LHS0]], %[[LHS1]] : vector<1xf16> - # CHECK: %[[RHS0:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0], sizes = [1], strides = [1]} - # CHECK: %[[REDUC_1:.+]] = arith.addf %[[REDUC_0]], %[[RHS0]] : vector<1xf16> - # CHECK: %[[RHS1:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [1], sizes = [1], strides = [1]} - # CHECK: %[[REDUC_2:.+]] = arith.addf %[[REDUC_1]], %[[RHS1]] : vector<1xf16> + # Reduce all sources locally. + # CHECK: %[[SRC_REDUC:.+]] = arith.addf %[[LHS]], %[[RHS]] : vector<2xf16> + # Do Local Reductions. + # CHECK: %[[LOCAL_REDUC0:.+]] = vector.extract_strided_slice %[[SRC_REDUC]] {offsets = [0], sizes = [1], strides = [1]} + # CHECK: %[[LOCAL_REDUC1:.+]] = vector.extract_strided_slice %[[SRC_REDUC]] {offsets = [1], sizes = [1], strides = [1]} + # CHECK: %[[REDUC_0:.+]] = arith.addf %[[LOCAL_REDUC0]], %[[LOCAL_REDUC1]] : vector<1xf16> + # Expanded Global Max Reduction + # CHECK-COUNT-6: gpu.shuffle xor # This test is to ensure that the propagation of indexing_dims between reduction and operations From 5d4d0e33d3d2b29aa03bc151fbbf65ef990d53aa Mon Sep 17 00:00:00 2001 From: Stanley Winata Date: Thu, 24 Oct 2024 10:05:43 -0700 Subject: [PATCH 3/3] nit rename step Signed-off-by: Stanley Winata --- iree/turbine/kernel/wave/decompose_reduce_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/iree/turbine/kernel/wave/decompose_reduce_ops.py b/iree/turbine/kernel/wave/decompose_reduce_ops.py index 6fecfeed..bf972b75 100644 --- a/iree/turbine/kernel/wave/decompose_reduce_ops.py +++ b/iree/turbine/kernel/wave/decompose_reduce_ops.py @@ -80,7 +80,7 @@ def decompose_reduce_ops( 1. Source Reduce: Each thread reduce locally all it's sources. 2. Local Reduce: Each thread reduces all elements carried by it along the reduction dimensions. - 3. Thread Reduce: Each thread reduces result of step 1 across threads + 3. Thread Reduce: Each thread reduces result of step 2 across threads by doing a butterfly shuffle. 4. Accumulator Reduce: Each thread reduces it's intermediate reduced results with the accumulator it holds.