From db1ec5793848c242cce6c648def19086a1d9d597 Mon Sep 17 00:00:00 2001 From: harsh-nod Date: Wed, 6 Nov 2024 11:46:25 -0800 Subject: [PATCH] Add type inference (#252) This PR adds a type inference pass to wave. Previously, the types were inferred by looking up types from neighbors, resulting in inefficient type inference. Instead, we now introduce a pass that infers the types for all operators in the graph, and the inferred type is then stored in the node. New nodes that are constructed in downstream passes are responsible for annotating types for the new operators. --------- Signed-off-by: Harsh Menon --- iree/turbine/kernel/ops/wave_ops.py | 94 +++++---- iree/turbine/kernel/wave/expansion.py | 7 +- iree/turbine/kernel/wave/type_inference.py | 21 ++ iree/turbine/kernel/wave/wave.py | 4 + lit_tests/kernel/wave/barriers.py | 3 + lit_tests/kernel/wave/expansion.py | 10 + .../kernel/wave/index_sequence_analysis.py | 2 + .../kernel/wave/minimize_global_loads.py | 2 + lit_tests/kernel/wave/promotion.py | 6 +- lit_tests/kernel/wave/scheduling.py | 2 + tests/kernel/wave/scheduling_test.py | 2 + tests/kernel/wave/type_inference_test.py | 199 ++++++++++++++++++ tests/kernel/wave/visualization_test.py | 2 + 13 files changed, 311 insertions(+), 43 deletions(-) create mode 100644 iree/turbine/kernel/wave/type_inference.py create mode 100644 tests/kernel/wave/type_inference_test.py diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py index 89b96ccf..6e38ea66 100644 --- a/iree/turbine/kernel/ops/wave_ops.py +++ b/iree/turbine/kernel/ops/wave_ops.py @@ -338,7 +338,7 @@ def custom_string(self, value_map: dict[str, str]) -> str: vars_str = ", ".join(vars_list) return f"{self.tkw_op_name}({vars_str})" - def add_to_graph(self, region_graph: RegionGraph) -> fx.Node: + def add_to_graph(self, region_graph: RegionGraph, type: Any = None) -> fx.Node: arg_list = tuple([value for _, value in vars(self).items()]) self.graph = region_graph self.fx_node = region_graph.create_node( @@ -350,6 +350,10 @@ def add_to_graph(self, region_graph: RegionGraph) -> fx.Node: self.fx_node.tkw_op = self.__class__ self.fx_node.tkw_op_name = self.tkw_op_name self.fx_node.index = None + if type is None: + get_custom(self.fx_node).infer_type() + else: + self.fx_node.type = type return self.fx_node def _add_proxy_to_graph(self, region_graph: RegionGraph): @@ -556,6 +560,23 @@ def vector_shapes(self) -> dict[IndexSymbol, int]: def vector_shapes(self, value: dict[IndexSymbol, int]): self.fx_node.vector_shapes = value + @property + def type(self) -> Any: + if hasattr(self.fx_node, "type"): + return self.fx_node.type + return None + + @type.setter + def type(self, value: Any): + self.fx_node.type = value + + def infer_type(self): + """ + Infer the type of this operator using the types + of its arguments. + """ + pass + def align_index(self, constraints: list["Constraint"]) -> None: """ Align index to WG/Tile sizes. 
@@ -602,13 +623,13 @@ def indexing_dims(self) -> list[IndexSymbol]: def py_operator(self) -> str: return self.tkw_op_name - @property - def type(self) -> Memory: + def infer_type(self): lhs_type = get_custom(self.lhs).type rhs_type = get_custom(self.rhs).type has_same_type = has_same_custom_type(lhs_type, rhs_type) if has_same_type: - return lhs_type + self.type = lhs_type + return lhs_dim_set = set(lhs_type.symbolic_shape) rhs_dim_set = set(rhs_type.symbolic_shape) if lhs_dim_set.isdisjoint(rhs_dim_set): @@ -616,7 +637,7 @@ def type(self) -> Memory: "BinaryPyOp requires lhs and rhs shape to be at least broadcastable." ) broadcasted_type = lhs_type if lhs_dim_set > rhs_dim_set else rhs_type - return broadcasted_type + self.type = broadcasted_type @define_interface_op("exp2") @@ -637,10 +658,9 @@ def indexing_dims(self) -> list[IndexSymbol]: def py_operator(self) -> str: return self.tkw_op_name - @property - def type(self) -> Memory: + def infer_type(self): src_type = get_custom(self.arg).type - return src_type + self.type = src_type @final @@ -868,9 +888,8 @@ def rhs_type(self) -> Memory: def acc_type(self) -> Memory: return get_custom(self.acc).type - @property - def type(self) -> Memory: - return self.acc_type + def infer_type(self): + self.type = self.acc_type def operand_index( self, operand_map: dict[IndexSymbol, int], shape: list[IndexExpr] @@ -925,6 +944,7 @@ def reduction_dim(self, value: IndexSymbol): @define_op("read") @dataclass class Read(CustomOp): + memory: fx.Proxy elements_per_thread: Optional[Any] = None mapping: Optional[IndexMapping] = None @@ -937,10 +957,9 @@ def indexing_dims(self) -> list[IndexSymbol]: # TODO: This could contain ints. return list(self.memory_type.symbolic_shape) - @property - def type(self) -> "Register": + def infer_type(self): dtype = self.memory_type.dtype - return Register[*self.indexing_dims, dtype] + self.type = Register[*self.indexing_dims, dtype] @property def memory_type(self) -> "Memory": @@ -1052,12 +1071,11 @@ def captured_vars(self, graph: fx.Graph) -> list[fx.Node]: captured_vars.append(nested_node) return captured_vars - @property - def type(self) -> Memory | Register | list[Memory | Register]: + def infer_type(self): res_types = [get_custom(x).type for x in self.init_args] if len(res_types) == 1: res_types = res_types[0] - return res_types + self.type = res_types def outputs(self, graph: fx.Graph) -> list[fx.Node]: for node in graph.nodes: @@ -1110,11 +1128,12 @@ def indexing_dims(self) -> list[IndexSymbol]: if self.mapping is not None: return list(self.mapping.input_shape) # TODO: This could contain ints. 
- return list(self.type.symbolic_shape) + return list(self.memory_type.symbolic_shape) - @property - def type(self) -> "Memory": - return get_custom(self.memory).type + def infer_type(self): + address_space = self.memory_type.address_space + dtype = self.memory_type.dtype + self.type = Memory[*self.indexing_dims, address_space, dtype] @property def memory_type(self) -> "Memory": @@ -1144,13 +1163,12 @@ class GetResult(CustomOp): value: fx.Node res_idx: int - @property - def type(self) -> "Memory": + def infer_type(self): src_type = get_custom(self.value).type if isinstance(src_type, list): - return src_type[self.res_idx] + self.type = src_type[self.res_idx] else: - return src_type + self.type = src_type @property def indexing_dims(self) -> list[IndexExpr]: @@ -1200,14 +1218,14 @@ class Extract(CustomOp): register_: fx.Proxy offset: IndexExpr | int - @property - def type(self) -> "Register": + def infer_type(self): # Intuition here is we are trying to extract an element # from fastest dim => we reduce the fastest dim. src_type = get_custom(self.register_).type # Return itself if just 0-D/1-D symbolic. if len(src_type.symbolic_shape) <= 1: - return src_type + self.type = src_type + return # Typically fastest dim is the last dimension, # If non-unit dim exists => non-unit dim is fastest dim. @@ -1220,7 +1238,7 @@ def type(self) -> "Register": dim_to_remove = dst_shape[-1] if not non_unit_dim else non_unit_dim[0] dst_shape.remove(dim_to_remove) dst_type = Register[*dst_shape, src_type.dtype] - return dst_type + self.type = dst_type @define_op("extract_slice") @@ -1297,12 +1315,8 @@ def indexing_dims(self) -> list[IndexSymbol]: dst_indexing = [dim for dim in src_indexing if dim != self.dim] return dst_indexing - @property - def type(self) -> Memory: + def infer_type(self): if isinstance(self.arg, Sequence): - # Local import to break circular dep. 
- from ..wave.utils import all_equal - src_types = [get_custom(arg).type for arg in self.arg] ref_shape = src_types[0].symbolic_shape ref_dtype = src_types[0].dtype @@ -1318,7 +1332,7 @@ def type(self) -> Memory: src_type = get_custom(self.arg).type reduced_dims = [dims for dims in src_type.symbolic_shape if dims != self.dim] dst_type = Register[*reduced_dims, src_type.dtype] - return dst_type + self.type = dst_type @property def num_reduction_dims(self) -> int: @@ -1376,10 +1390,9 @@ class CastOp(CustomOp, ABC): def indexing_dims(self) -> list[IndexSymbol]: return get_custom(self.arg).indexing_dims - @property - def type(self) -> Memory: + def infer_type(self): src_shape = get_custom(self.arg).type.symbolic_shape - return Register[*src_shape, self.dtype] + self.type = Register[*src_shape, self.dtype] @define_op("permute") @@ -1397,13 +1410,12 @@ class Permute(CustomOp, ABC): def indexing_dims(self) -> list[IndexExpr]: return self.target_shape - @property - def type(self) -> Register: + def infer_type(self): src_type = get_custom(self.arg).type assert set(src_type.symbolic_shape) == set( self.target_shape ), f"Target shape {self.target_shape} must be a permutation of source shape {src_type.symbolic_shape}" - return Register[*self.target_shape, src_type.dtype] + self.type = Register[*self.target_shape, src_type.dtype] def _to_sequence(input: Any | Sequence[Any]) -> Sequence[Any]: diff --git a/iree/turbine/kernel/wave/expansion.py b/iree/turbine/kernel/wave/expansion.py index 55c1db85..ed777bc5 100644 --- a/iree/turbine/kernel/wave/expansion.py +++ b/iree/turbine/kernel/wave/expansion.py @@ -336,7 +336,12 @@ def _expand_reduction( # Add GetResult nodes for the corresponding dimensions reduction.graph.inserting_after(reduction.fx_node) new_node = GetResult(reduction.fx_node, len(new_output_args)) - new_node.add_to_graph(reduction.graph) + # Usually we would rely on infer_types inside add_to_graph to figure out + # the type of the new node. However, in this case, the logic to determine + # the type requires the reduction node to have its init_args set, which has + # not happened yet (it happens later). So instead, since we have access to + # arg, we just set the type directly. + new_node.add_to_graph(reduction.graph, arg.type) new_node.fx_node.name = get_expanded_name(new_node, dims) context[ (reduction, get_indexed_dims(dims, expand_dims), arg_idx) diff --git a/iree/turbine/kernel/wave/type_inference.py b/iree/turbine/kernel/wave/type_inference.py new file mode 100644 index 00000000..db574cfc --- /dev/null +++ b/iree/turbine/kernel/wave/type_inference.py @@ -0,0 +1,21 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ..ops.wave_ops import * +from .._support.tracing import CapturedTrace +import torch.fx as fx +from ...support.logging import get_logger + +logger = get_logger("turbine.wave.type_inference") + + +def infer_types(trace: CapturedTrace | fx.Graph): + # Infer and set the types for all nodes in the graph. 
+ for subgraph in trace.region_graph.subgraphs.values(): + for node in subgraph.nodes: + custom = get_custom(node) + custom.infer_type() + logger.debug(f"Setting type for {custom.fx_node} = {custom.type}") diff --git a/iree/turbine/kernel/wave/wave.py b/iree/turbine/kernel/wave/wave.py index 07ca9ab1..7574f032 100644 --- a/iree/turbine/kernel/wave/wave.py +++ b/iree/turbine/kernel/wave/wave.py @@ -49,6 +49,7 @@ from .thread_shape_analysis import determine_thread_shapes from .scheduling.schedule import schedule_graph from .._support.indexing import IndexingContext, IndexExpr +from .type_inference import infer_types import iree.turbine.kernel.lang as tkl from .._support.tracing import ( CapturedTrace, @@ -224,6 +225,9 @@ def _trace_and_get_kernel_signature( # Initialize Vector shapes self.hardware_constraints[0].subs_vector_shapes(idxc.subs) + # Do type inference. + infer_types(graph) + # Promote the placeholders to the appropriate address space. promote_placeholders(graph, self.constraints) hoist_allocs(graph) diff --git a/lit_tests/kernel/wave/barriers.py b/lit_tests/kernel/wave/barriers.py index c4c02ccd..6a67cfb9 100644 --- a/lit_tests/kernel/wave/barriers.py +++ b/lit_tests/kernel/wave/barriers.py @@ -14,6 +14,7 @@ from iree.turbine.kernel.wave.barriers import add_shared_memory_barriers from iree.turbine.kernel.wave.hoisting import hoist_allocs from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel.wave.type_inference import infer_types from iree.turbine.kernel.lang.global_symbols import * from iree.turbine.kernel._support.tracing import CapturedTrace from iree.turbine.kernel._support.indexing import IndexingContext @@ -86,6 +87,7 @@ def test_read_write_equal_sizes(): graph: fx.Graph = trace.get_root_graph() read_node = get_read_nodes(graph)[0] IndexingContext.current().finalize() + infer_types(trace) promote_node(read_node, SHARED_ADDRESS_SPACE, constraints) set_node_indices(trace, constraints) expand_graph(trace, constraints) @@ -171,6 +173,7 @@ def test_gemm(): trace: CapturedTrace = gemm() graph: fx.Graph = trace.get_subgraph("region_0") IndexingContext.current().finalize() + infer_types(trace) read_nodes = get_read_nodes(graph) for read_node in read_nodes: promote_node(read_node, SHARED_ADDRESS_SPACE, constraints) diff --git a/lit_tests/kernel/wave/expansion.py b/lit_tests/kernel/wave/expansion.py index 4f00c54e..1c86de37 100644 --- a/lit_tests/kernel/wave/expansion.py +++ b/lit_tests/kernel/wave/expansion.py @@ -6,6 +6,7 @@ import iree.turbine.kernel.lang as tkl import iree.turbine.kernel.wave as tkw from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel.wave.type_inference import infer_types from iree.turbine.kernel.wave.index_sequence_analysis import ( set_node_indices, set_post_expansion_indices, @@ -69,6 +70,7 @@ def test_read_write_equal_sizes(): ): graph = read_write_same_size() IndexingContext.current().finalize() + infer_types(graph) set_node_indices(graph, constraints) expand_graph(graph, constraints) set_post_expansion_indices(graph, constraints) @@ -150,6 +152,7 @@ def test_read_write(): ): graph = read_write_different_dims() IndexingContext.current().finalize() + infer_types(graph) set_node_indices(graph, constraints) expand_graph(graph, constraints) set_post_expansion_indices(graph, constraints) @@ -227,6 +230,7 @@ def test_gemm(): ): graph = gemm() IndexingContext.current().finalize() + infer_types(graph) set_node_indices(graph, constraints) expand_graph(graph, constraints) 
set_post_expansion_indices(graph, constraints) @@ -413,6 +417,7 @@ def test_batched_gemm(): ): graph = batched_gemm() IndexingContext.current().finalize() + infer_types(graph) set_node_indices(graph, constraints) expand_graph(graph, constraints) set_post_expansion_indices(graph, constraints) @@ -591,6 +596,7 @@ def test_gemm_non_direct_acc(): ): graph = gemm_non_direct_acc() IndexingContext.current().finalize() + infer_types(graph) set_node_indices(graph, constraints) expand_graph(graph, constraints) set_post_expansion_indices(graph, constraints) @@ -657,6 +663,7 @@ def test_tiled_max(): ): graph = tiled_max() IndexingContext.current().finalize() + infer_types(graph) set_node_indices(graph, constraints) expand_graph(graph, constraints) set_post_expansion_indices(graph, constraints) @@ -688,6 +695,7 @@ def test_gemm_reduction_expansion_only(): ): graph = gemm() IndexingContext.current().finalize() + infer_types(graph) set_node_indices(graph, constraints) expand_graph(graph, constraints) set_post_expansion_indices(graph, constraints) @@ -791,6 +799,7 @@ def py_arithmetic_different_dims(): ): graph = py_arithmetic_different_dims() IndexingContext.current().finalize() + infer_types(graph) set_node_indices(graph, constraints) expand_graph(graph, constraints) set_post_expansion_indices(graph, constraints) @@ -896,6 +905,7 @@ def test_chained_gemm_32x32x8(): ): graph = chained_gemm_32x32x8() IndexingContext.current().finalize() + infer_types(graph) set_node_indices(graph, constraints) expand_graph(graph, constraints) set_post_expansion_indices(graph, constraints) diff --git a/lit_tests/kernel/wave/index_sequence_analysis.py b/lit_tests/kernel/wave/index_sequence_analysis.py index 36594f22..812edd6f 100644 --- a/lit_tests/kernel/wave/index_sequence_analysis.py +++ b/lit_tests/kernel/wave/index_sequence_analysis.py @@ -8,6 +8,7 @@ from iree.turbine.kernel.wave.promotion import promote_placeholders from iree.turbine.kernel.wave.hoisting import hoist_allocs from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel.wave.type_inference import infer_types from iree.turbine.kernel.lang.global_symbols import * from iree.turbine.kernel._support.tracing import CapturedTrace from iree.turbine.kernel._support.indexing import IndexingContext @@ -84,6 +85,7 @@ def test_gemm(): ): trace: CapturedTrace = gemm() IndexingContext.current().finalize() + infer_types(trace) promote_placeholders(trace, constraints) hoist_allocs(trace) set_node_indices(trace, constraints) diff --git a/lit_tests/kernel/wave/minimize_global_loads.py b/lit_tests/kernel/wave/minimize_global_loads.py index a6a61a17..f74a8764 100644 --- a/lit_tests/kernel/wave/minimize_global_loads.py +++ b/lit_tests/kernel/wave/minimize_global_loads.py @@ -10,6 +10,7 @@ from iree.turbine.kernel.wave.hoisting import hoist_allocs from iree.turbine.kernel.wave.barriers import add_shared_memory_barriers from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel.wave.type_inference import infer_types from iree.turbine.kernel.lang.global_symbols import * from iree.turbine.kernel._support.tracing import CapturedTrace from iree.turbine.kernel._support.indexing import IndexingContext @@ -87,6 +88,7 @@ def test_gemm(): trace: CapturedTrace = gemm() visualize = False IndexingContext.current().finalize() + infer_types(trace) promote_placeholders(trace, constraints) hoist_allocs(trace) set_node_indices(trace, constraints) diff --git a/lit_tests/kernel/wave/promotion.py b/lit_tests/kernel/wave/promotion.py index 
c3836f4f..f1f348a7 100644 --- a/lit_tests/kernel/wave/promotion.py +++ b/lit_tests/kernel/wave/promotion.py @@ -7,6 +7,7 @@ import iree.turbine.kernel.wave as tkw from iree.turbine.kernel.wave.promotion import promote_node, promote_placeholders from iree.turbine.kernel.wave.hoisting import hoist_allocs +from iree.turbine.kernel.wave.type_inference import infer_types from iree.turbine.kernel.lang.global_symbols import * from iree.turbine.kernel._support.tracing import CapturedTrace from iree.turbine.kernel._support.indexing import IndexingContext @@ -67,6 +68,7 @@ def test_read_write_equal_sizes(): graph: fx.Graph = trace.get_root_graph() read_node = get_read_nodes(graph)[0] IndexingContext.current().finalize() + infer_types(trace) promote_node(read_node, SHARED_ADDRESS_SPACE, constraints) print_trace(trace, False) # CHECK: %a @@ -116,6 +118,7 @@ def test_read_write_equal_sizes_different_address_spaces(): ): trace: CapturedTrace = read_write_same_size_different_address_spaces() IndexingContext.current().finalize() + infer_types(trace) promote_placeholders(trace, constraints) print_trace(trace, False) # CHECK: %a @@ -170,10 +173,11 @@ def test_gemm(): trace: CapturedTrace = gemm() graph: fx.Graph = trace.get_subgraph("region_0") read_nodes = get_read_nodes(graph) + IndexingContext.current().finalize() + infer_types(trace) for read_node in read_nodes: promote_node(read_node, SHARED_ADDRESS_SPACE, constraints) hoist_allocs(trace) - IndexingContext.current().finalize() print_trace(trace, False) # Root graph: # CHECK: %a diff --git a/lit_tests/kernel/wave/scheduling.py b/lit_tests/kernel/wave/scheduling.py index afa6065b..2f7780bc 100644 --- a/lit_tests/kernel/wave/scheduling.py +++ b/lit_tests/kernel/wave/scheduling.py @@ -8,6 +8,7 @@ from iree.turbine.kernel.wave.promotion import promote_placeholders from iree.turbine.kernel.wave.hoisting import hoist_allocs from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel.wave.type_inference import infer_types from iree.turbine.kernel.lang.global_symbols import * from iree.turbine.kernel._support.tracing import CapturedTrace from iree.turbine.kernel._support.indexing import IndexingContext @@ -92,6 +93,7 @@ def test_gemm_pipelined(): ): trace: CapturedTrace = gemm_pipelined() IndexingContext.current().finalize() + infer_types(trace) promote_placeholders(trace, constraints) hoist_allocs(trace) set_node_indices(trace, constraints) diff --git a/tests/kernel/wave/scheduling_test.py b/tests/kernel/wave/scheduling_test.py index d8728e3d..80f01963 100644 --- a/tests/kernel/wave/scheduling_test.py +++ b/tests/kernel/wave/scheduling_test.py @@ -29,6 +29,7 @@ from iree.turbine.kernel.wave.promotion import promote_placeholders from iree.turbine.kernel.wave.hoisting import hoist_allocs from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel.wave.type_inference import infer_types from iree.turbine.kernel.wave.minimize_global_loads import minimize_global_loads from iree.turbine.kernel.wave.scheduling.schedule import schedule_graph from iree.turbine.kernel.ops.wave_ops import get_custom @@ -277,6 +278,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: with tk.gen.TestLaunchContext(hyperparams, canonicalize=True, schedule=True): trace: CapturedTrace = gemm() IndexingContext.current().finalize() + infer_types(trace) promote_placeholders(trace, constraints) hoist_allocs(trace) set_node_indices(trace, constraints) diff --git a/tests/kernel/wave/type_inference_test.py 
b/tests/kernel/wave/type_inference_test.py new file mode 100644 index 00000000..6ce7efa2 --- /dev/null +++ b/tests/kernel/wave/type_inference_test.py @@ -0,0 +1,199 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import unittest +import logging +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel._support.tracing import CapturedTrace +from iree.turbine.kernel._support.indexing import IndexingContext +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel.wave.utils import ( + get_mfma_load_elems_per_thread, + get_mfma_store_elems_per_thread, +) +from iree.turbine.kernel.wave.constraints import MMAType +from iree.turbine.kernel.wave.type_inference import infer_types +from iree.turbine.kernel.ops.wave_ops import get_custom + + +class TypeInferenceTest(unittest.TestCase): + def testAttentionInference(self): + shape = (8, 128, 128, 64, 256) + # Input sizes + B = tkl.sym.B + M = tkl.sym.M + N = tkl.sym.N + K1 = tkl.sym.K1 + K2 = tkl.sym.K2 + # Workgroup tile sizes + BLOCK_B = tkl.sym.BLOCK_B + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_K2 = tkl.sym.BLOCK_K2 + # Address space (for GPU, shared(1) or global(0)) + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + # Other hyperparameters + LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD + STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD + + # Expose user-constraints + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] + constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] + + mfma_variant = MMAType.F32_16x16x16_F16 + if mfma_variant == MMAType.F32_16x16x16_F16: + Mvec = 16 + Nvec = 16 + if mfma_variant == MMAType.F32_32x32x8_F16: + Mvec = 32 + Nvec = 32 + + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(2, 2, 1), + mma_type=mfma_variant, + vector_shapes={B: 0, M: Mvec, N: Nvec}, + ) + ] + + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + k = tkw.IndexMapping.iterator(2) + mapping = tkw.IndexMapping( + num_iterators=3, inputs={B: i, N: j, M: k}, outputs={B: i, M: k, N: j} + ) + + @tkw.wave_trace_only(constraints) + def base_attention( + q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16], + k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16], + v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], + ): + c_reg = tkl.Register[B, N, M, tkl.f32](0.0) + init_sum = tkl.Register[B, M, tkl.f32](0.0) + init_max = tkl.Register[B, M, tkl.f32](-1e6) + + # This microkernel encodes the fact that if the reduction + # dimension were tiled, then we would need to materialize a loop. 
+ @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg]) + def repeat( + partial_max: tkl.Register[B, M, tkl.f32], + partial_sum: tkl.Register[B, M, tkl.f32], + acc: tkl.Register[B, N, M, tkl.f32], + ) -> ( + tkl.Register[B, M, tkl.f32], + tkl.Register[B, M, tkl.f32], + tkl.Register[B, N, M, tkl.f32], + ): + imm_reg = tkl.Register[B, K2, M, tkl.f32](0.0) + q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD) + # b_reg: tkw.Register[B, N, K, tkl.f16] + k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD) + # acc: tkw.Register[B, N, M, tkl.f32] + inner_acc = tkw.mma(k_reg, q_reg, imm_reg) + x_j = tkw.permute(inner_acc, target_shape=[B, M, K2]) + m_j = tkw.max(x_j, partial_max, dim=K2) + e_delta_max = tkw.exp2(partial_max - m_j) + e_delta = tkw.exp2(x_j - m_j) + e_init = partial_sum * e_delta_max + d_j = tkw.sum(e_delta, e_init, dim=K2) + imm_f16 = tkw.cast(e_delta, tkl.f16) + v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD) + new_acc = acc * e_delta_max + acc = tkw.mma(v_reg, imm_f16, new_acc) + return m_j, d_j, acc + + # repeat represents the results of the loop + res_max, res_sum, res_mm = repeat + res = res_mm / res_sum + tkw.write( + res, c, mapping=mapping, elements_per_thread=STORE_ELEMS_PER_THREAD + ) + + hyperparams = { + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant), + STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant), + BLOCK_B: 1, + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_K2: 32, + B: shape[0], + M: shape[1], + N: shape[2], + K1: shape[3], + K2: shape[4], + READ_SHARED_DELAY: 1, + WRITE_SHARED_DELAY: 1, + READ_GLOBAL_DELAY: 2, + WRITE_GLOBAL_DELAY: 2, + MMA_DELAY: 1, + SHARED_MEMORY_UNITS: 4, + GLOBAL_MEMORY_UNITS: 4, + MMA_UNITS: 4, + } + with tk.gen.TestLaunchContext( + hyperparams, + canonicalize=True, + run=False, + run_bench=False, + schedule=False, + use_scheduling_barriers=False, + ): + trace: CapturedTrace = base_attention() + IndexingContext.current().finalize() + infer_types(trace) + expected_type = { + "partial_sum": "Register[B, M].of(f32)", + "partial_max": "Register[B, M].of(f32)", + "acc": "Register[B, N, M].of(f32)", + "q": "Memory[B, M, K1].of(f16)", + "read": "Register[B, M, K1].of(f16)", + "k": "Memory[B, K2, K1].of(f16)", + "read_1": "Register[B, K2, K1].of(f16)", + "mma": "Register[B, K2, M].of(f32)", + "permute": "Register[B, M, K2].of(f32)", + "max_1": "Register[B, M].of(f32)", + "sub": "Register[B, M].of(f32)", + "exp2": "Register[B, M].of(f32)", + "sub_1": "Register[B, M, K2].of(f32)", + "exp2_1": "Register[B, M, K2].of(f32)", + "mul": "Register[B, M].of(f32)", + "sum_1": "Register[B, M].of(f32)", + "cast": "Register[B, M, K2].of(f16)", + "v": "Memory[B, N, K2].of(f16)", + "read_2": "Register[B, N, K2].of(f16)", + "mul_1": "Register[B, N, M].of(f32)", + "mma_1": "Register[B, N, M].of(f32)", + "c": "Memory[B, M, N].of(f32)", + "register_1": "Register[B, M].of(f32)", + "register_2": "Register[B, M].of(f32)", + "reduction": "[Register[B, M].of(f32), Register[B, M].of(f32), Register[B, N, M].of(f32)]", + "getitem": "Register[B, M].of(f32)", + "getitem_1": "Register[B, M].of(f32)", + "getitem_2": "Register[B, N, M].of(f32)", + "truediv": "Register[B, N, M].of(f32)", + "write": "Memory[B, N, M].of(f32)", + } + for subgraph in trace.region_graph.subgraphs.values(): + for node in subgraph.nodes: + custom = get_custom(node) + if custom.fx_node.name in expected_type: + assert str(custom.type) == expected_type[custom.fx_node.name] + + +if __name__ == 
"__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/tests/kernel/wave/visualization_test.py b/tests/kernel/wave/visualization_test.py index a7e84b0d..04a1959e 100644 --- a/tests/kernel/wave/visualization_test.py +++ b/tests/kernel/wave/visualization_test.py @@ -13,6 +13,7 @@ import iree.turbine.kernel.lang as tkl import iree.turbine.kernel.wave as tkw from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel.wave.type_inference import infer_types from iree.turbine.kernel._support.tracing import CapturedTrace from iree.turbine.kernel._support.indexing import IndexingContext from iree.turbine.kernel.ops.wave_ops import get_custom @@ -93,6 +94,7 @@ def test_gemm(): ): graph = gemm() IndexingContext.current().finalize() + infer_types(graph) set_node_indices(graph, constraints) expand_graph(graph, constraints) set_post_expansion_indices(graph, constraints)