diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py index 89b96ccf..6f6057ce 100644 --- a/iree/turbine/kernel/ops/wave_ops.py +++ b/iree/turbine/kernel/ops/wave_ops.py @@ -556,6 +556,16 @@ def vector_shapes(self) -> dict[IndexSymbol, int]: def vector_shapes(self, value: dict[IndexSymbol, int]): self.fx_node.vector_shapes = value + @property + def type(self) -> Any: + if hasattr(self.fx_node, "type"): + return self.fx_node.type + return None + + @type.setter + def type(self, value: Any): + self.fx_node.type = value + def align_index(self, constraints: list["Constraint"]) -> None: """ Align index to WG/Tile sizes. @@ -602,10 +612,7 @@ def indexing_dims(self) -> list[IndexSymbol]: def py_operator(self) -> str: return self.tkw_op_name - @property - def type(self) -> Memory: - lhs_type = get_custom(self.lhs).type - rhs_type = get_custom(self.rhs).type + def infer_type(self, lhs_type: Register, rhs_type: Register) -> Register: has_same_type = has_same_custom_type(lhs_type, rhs_type) if has_same_type: return lhs_type @@ -637,9 +644,7 @@ def indexing_dims(self) -> list[IndexSymbol]: def py_operator(self) -> str: return self.tkw_op_name - @property - def type(self) -> Memory: - src_type = get_custom(self.arg).type + def infer_type(self, src_type: Register) -> Register: return src_type @@ -868,10 +873,6 @@ def rhs_type(self) -> Memory: def acc_type(self) -> Memory: return get_custom(self.acc).type - @property - def type(self) -> Memory: - return self.acc_type - def operand_index( self, operand_map: dict[IndexSymbol, int], shape: list[IndexExpr] ) -> dict[IndexSymbol, IndexSequence]: @@ -925,6 +926,7 @@ def reduction_dim(self, value: IndexSymbol): @define_op("read") @dataclass class Read(CustomOp): + memory: fx.Proxy elements_per_thread: Optional[Any] = None mapping: Optional[IndexMapping] = None @@ -937,10 +939,12 @@ def indexing_dims(self) -> list[IndexSymbol]: # TODO: This could contain ints. 
return list(self.memory_type.symbolic_shape) - @property - def type(self) -> "Register": - dtype = self.memory_type.dtype - return Register[*self.indexing_dims, dtype] + def infer_type(self, memory_type: Memory) -> "Register": + dtype = memory_type.dtype + shape = memory_type.symbolic_shape + if self.mapping is not None: + shape = self.mapping.output_shape + return Register[*shape, dtype] @property def memory_type(self) -> "Memory": @@ -1052,9 +1056,7 @@ def captured_vars(self, graph: fx.Graph) -> list[fx.Node]: captured_vars.append(nested_node) return captured_vars - @property - def type(self) -> Memory | Register | list[Memory | Register]: - res_types = [get_custom(x).type for x in self.init_args] + def infer_type(self, res_types: list[Register]) -> Register | list[Register]: if len(res_types) == 1: res_types = res_types[0] return res_types @@ -1112,9 +1114,13 @@ def indexing_dims(self) -> list[IndexSymbol]: # TODO: This could contain ints. return list(self.type.symbolic_shape) - @property - def type(self) -> "Memory": - return get_custom(self.memory).type + def infer_type(self, memory_type: Memory) -> "Memory": + dtype = memory_type.dtype + shape = memory_type.symbolic_shape + address_space = memory_type.address_space + if self.mapping is not None: + shape = self.mapping.input_shape + return Memory[*shape, address_space, dtype] @property def memory_type(self) -> "Memory": @@ -1144,9 +1150,7 @@ class GetResult(CustomOp): value: fx.Node res_idx: int - @property - def type(self) -> "Memory": - src_type = get_custom(self.value).type + def infer_type(self, src_type: Register, idx: int) -> "Memory": if isinstance(src_type, list): return src_type[self.res_idx] else: @@ -1200,8 +1204,7 @@ class Extract(CustomOp): register_: fx.Proxy offset: IndexExpr | int - @property - def type(self) -> "Register": + def infer_type(self, src_type) -> "Register": # Intuition here is we are trying to extract an element # from fastest dim => we reduce the fastest dim. 
src_type = get_custom(self.register_).type @@ -1297,13 +1300,8 @@ def indexing_dims(self) -> list[IndexSymbol]: dst_indexing = [dim for dim in src_indexing if dim != self.dim] return dst_indexing - @property - def type(self) -> Memory: - if isinstance(self.arg, Sequence): - # Local import to break circular dep. - from ..wave.utils import all_equal - - src_types = [get_custom(arg).type for arg in self.arg] + def infer_type(self, src_types: list[Register] | Register) -> Register: + if isinstance(src_types, Sequence): ref_shape = src_types[0].symbolic_shape ref_dtype = src_types[0].dtype if not all( @@ -1315,7 +1313,7 @@ def type(self) -> Memory: ) src_type = src_types[0] else: - src_type = get_custom(self.arg).type + src_type = src_types reduced_dims = [dims for dims in src_type.symbolic_shape if dims != self.dim] dst_type = Register[*reduced_dims, src_type.dtype] return dst_type @@ -1376,9 +1374,8 @@ class CastOp(CustomOp, ABC): def indexing_dims(self) -> list[IndexSymbol]: return get_custom(self.arg).indexing_dims - @property - def type(self) -> Memory: - src_shape = get_custom(self.arg).type.symbolic_shape + def infer_type(self, src_type: Register) -> Register: + src_shape = src_type.symbolic_shape return Register[*src_shape, self.dtype] @@ -1397,9 +1394,7 @@ class Permute(CustomOp, ABC): def indexing_dims(self) -> list[IndexExpr]: return self.target_shape - @property - def type(self) -> Register: - src_type = get_custom(self.arg).type + def infer_type(self, src_type: Register) -> Register: assert set(src_type.symbolic_shape) == set( self.target_shape ), f"Target shape {self.target_shape} must be a permutation of source shape {src_type.symbolic_shape}" diff --git a/iree/turbine/kernel/wave/decompose_reduce_ops.py b/iree/turbine/kernel/wave/decompose_reduce_ops.py index 39ba9b43..e3181137 100644 --- a/iree/turbine/kernel/wave/decompose_reduce_ops.py +++ b/iree/turbine/kernel/wave/decompose_reduce_ops.py @@ -87,8 +87,10 @@ def emit_sources_reduction( binary_fn: 
Callable, src: list[fx.Node], graph: fx.Graph ) -> fx.Node: init = src[0] + op_type = init.type for i in range(1, len(src)): init = get_graph_node(binary_fn(init, src[i]), graph) + init.type = op_type init.index = src[0].index return init @@ -97,9 +99,12 @@ def emit_local_reduction( binary_fn: Callable, src: fx.Node, graph: fx.Graph, local_reduction_size: int ) -> fx.Node: init = get_graph_node(Extract(src, [0]), graph) + init.type = get_custom(init).infer_type(src.type) for i in range(1, local_reduction_size): cur_slice = get_graph_node(Extract(src, [i]), graph) + cur_slice.type = get_custom(cur_slice).infer_type(src.type) init = get_graph_node(binary_fn(init, cur_slice), graph) + init.type = cur_slice.type return init @@ -117,6 +122,7 @@ def emit_global_reduction( shuffle_val = ShuffleOp(init, cluster_stride, subgroup_size) shuffle_node = get_graph_node(shuffle_val, graph) init = get_graph_node(binary_fn(init, shuffle_node), graph) + init.type = src.type cluster_stride <<= 1 return init diff --git a/iree/turbine/kernel/wave/expansion.py b/iree/turbine/kernel/wave/expansion.py index 55c1db85..0c4299e2 100644 --- a/iree/turbine/kernel/wave/expansion.py +++ b/iree/turbine/kernel/wave/expansion.py @@ -338,6 +338,7 @@ def _expand_reduction( new_node = GetResult(reduction.fx_node, len(new_output_args)) new_node.add_to_graph(reduction.graph) new_node.fx_node.name = get_expanded_name(new_node, dims) + new_node.type = arg.type context[ (reduction, get_indexed_dims(dims, expand_dims), arg_idx) ] = new_node diff --git a/iree/turbine/kernel/wave/index_sequence_analysis.py b/iree/turbine/kernel/wave/index_sequence_analysis.py index 0948432d..8f3e9712 100644 --- a/iree/turbine/kernel/wave/index_sequence_analysis.py +++ b/iree/turbine/kernel/wave/index_sequence_analysis.py @@ -125,6 +125,7 @@ def has_strided_access(node: fx.Node) -> bool: ) for j, dim in enumerate(custom.register_type.symbolic_shape) } + write.type = custom.memory.type custom.graph.erase_node(operator) diff 
--git a/iree/turbine/kernel/wave/minimize_global_loads.py b/iree/turbine/kernel/wave/minimize_global_loads.py index 1092b2c5..f1d180a2 100644 --- a/iree/turbine/kernel/wave/minimize_global_loads.py +++ b/iree/turbine/kernel/wave/minimize_global_loads.py @@ -14,7 +14,7 @@ from .._support.indexing import IndexingContext, IndexSequence, IndexSymbol, IndexExpr from ..ops.wave_ops import Read, Write, get_custom from ..lang.global_symbols import * -from .utils import delinearize_index, DCE, subs_idxc, ceildiv +from .utils import delinearize_index, DCE, subs_idxc, ceildiv, memory_to_register from math import prod import torch.fx as fx from collections import defaultdict @@ -140,6 +140,7 @@ def add_optimized_nodes( load_elems_per_thread, materialized_shape, ) + read.type = memory_to_register(memory.type) for custom_user in custom.users: if ( isinstance(custom_user, Write) @@ -149,6 +150,7 @@ def add_optimized_nodes( read, custom_user.memory, load_elems_per_thread ).add_to_graph(custom.graph) write.index = read.index + write.type = custom_user.type optimized_writes[custom_user.memory].append(write) break return optimized_writes diff --git a/iree/turbine/kernel/wave/promotion.py b/iree/turbine/kernel/wave/promotion.py index 3711436f..56b0c9cd 100644 --- a/iree/turbine/kernel/wave/promotion.py +++ b/iree/turbine/kernel/wave/promotion.py @@ -49,6 +49,9 @@ def apply_promotion_pattern(custom_node: Read | Write, allocate_node: Allocate): ).add_to_graph(custom_node.graph) custom_read = get_custom(promoted_read) custom_read.write_dependency = [promoted_write] + custom_read.type = custom_node.type + custom_write = get_custom(promoted_write) + custom_write.type = allocate_node.type custom_node.memory_type.address_space = GLOBAL_ADDRESS_SPACE diff --git a/iree/turbine/kernel/wave/type_inference.py b/iree/turbine/kernel/wave/type_inference.py new file mode 100644 index 00000000..c276052f --- /dev/null +++ b/iree/turbine/kernel/wave/type_inference.py @@ -0,0 +1,92 @@ +# Copyright 2024 
The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ..ops.wave_ops import * +from .._support.tracing import CapturedTrace +import torch.fx as fx +from typing import Sequence +from ...support.logging import get_logger + +logger = get_logger("turbine.wave.type_inference") + + +class TypeInferer: + def __init__(self): + self.index = 0 + self.symbol_table: dict[CustomOp, Memory | Register | list[Register]] = {} + + def get_type(self, op: fx.Node) -> Memory | Register | list[Register]: + custom = get_custom(op) + custom_type = self.symbol_table.get(custom, None) + if custom_type is None and custom not in self.symbol_table: + raise ValueError(f"No type found for {op}") + return custom_type + + def infer_types(self, op: CustomOp): + match op: + case BinaryPyOp(): + s = self.get_type(op.lhs) + t = self.get_type(op.rhs) + self.symbol_table[op] = op.infer_type(s, t) + case GetResult(): + s = self.get_type(op.value) + self.symbol_table[op] = op.infer_type(s, op.res_idx) + if self.symbol_table[op] is None: + raise ValueError(f"Could not infer type for {op}") + case Read() | Write(): + s = self.get_type(op.memory) + self.symbol_table[op] = op.infer_type(s) + case MMA(): + s = self.get_type(op.lhs) + t = self.get_type(op.rhs) + u = self.get_type(op.acc) + self.symbol_table[op] = u + case Placeholder() | NewRegister(): + self.symbol_table[op] = op.type + case Reduction(): + s = [] + for init_arg in op.init_args: + s.append(self.get_type(init_arg)) + self.symbol_table[op] = op.infer_type(s) + case ReduceOp(): + args = op.arg + if not isinstance(op.arg, Sequence): + args = [op.arg] + s = [] + for arg in args: + s.append(self.get_type(arg)) + self.symbol_table[op] = op.infer_type(s) + case CastOp() | Permute() | UnaryPyOp(): + s = self.get_type(op.arg) + self.symbol_table[op] = op.infer_type(s) + case Output(): + s = [] + for ret_vals in op.return_vals: + if
ret_vals is None: + s = None + break + if not isinstance(ret_vals, Sequence): + ret_vals = [ret_vals] + for ret_val in ret_vals: + s.append(self.get_type(ret_val)) + self.symbol_table[op] = s + return + + +def infer_types(trace: CapturedTrace | fx.Graph): + inferer = TypeInferer() + # First, infer the types for all the nodes. + for subgraph in trace.region_graph.subgraphs.values(): + for node in subgraph.nodes: + custom = get_custom(node) + inferer.infer_types(custom) + # Then, set the types. + for subgraph in trace.region_graph.subgraphs.values(): + for node in subgraph.nodes: + custom = get_custom(node) + if not isinstance(custom, (Placeholder, NewRegister)): + custom.type = inferer.get_type(custom.fx_node) + logger.debug(f"Setting type for {custom.fx_node} = {custom.type}") diff --git a/iree/turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py index bcaa772f..6bd35fa3 100644 --- a/iree/turbine/kernel/wave/utils.py +++ b/iree/turbine/kernel/wave/utils.py @@ -26,6 +26,8 @@ GetResult, IterArg, Reshape, + Memory, + Register, ) from .constraints import ( Constraint, @@ -136,8 +138,8 @@ def is_removable_operator(node: fx.Node) -> bool: custom = get_custom(node) idxc = IndexingContext.current() is_global_write = isinstance(custom, Write) and ( - custom.type.address_space.subs(idxc.subs) == GLOBAL_ADDRESS_SPACE - or custom.type.address_space.subs(idxc.subs) + custom.memory_type.address_space.subs(idxc.subs) == GLOBAL_ADDRESS_SPACE + or custom.memory_type.address_space.subs(idxc.subs) == tkl.AddressSpace.GLOBAL_MEMORY.value ) @@ -824,3 +826,9 @@ def all_equal(input_list: list[Any]) -> bool: if len(input_list) == 0: return True return all(elem == input_list[0] for elem in input_list) + + +def memory_to_register(memory_type: Memory) -> Register: + dtype = memory_type.dtype + shape = memory_type.symbolic_shape + return Register[*shape, dtype] diff --git a/iree/turbine/kernel/wave/wave.py b/iree/turbine/kernel/wave/wave.py index 07ca9ab1..7574f032 100644 --- 
a/iree/turbine/kernel/wave/wave.py +++ b/iree/turbine/kernel/wave/wave.py @@ -49,6 +49,7 @@ from .thread_shape_analysis import determine_thread_shapes from .scheduling.schedule import schedule_graph from .._support.indexing import IndexingContext, IndexExpr +from .type_inference import infer_types import iree.turbine.kernel.lang as tkl from .._support.tracing import ( CapturedTrace, @@ -224,6 +225,9 @@ def _trace_and_get_kernel_signature( # Initialize Vector shapes self.hardware_constraints[0].subs_vector_shapes(idxc.subs) + # Do type inference. + infer_types(graph) + # Promote the placeholders to the appropriate address space. promote_placeholders(graph, self.constraints) hoist_allocs(graph) diff --git a/lit_tests/kernel/wave/barriers.py b/lit_tests/kernel/wave/barriers.py index c4c02ccd..6a67cfb9 100644 --- a/lit_tests/kernel/wave/barriers.py +++ b/lit_tests/kernel/wave/barriers.py @@ -14,6 +14,7 @@ from iree.turbine.kernel.wave.barriers import add_shared_memory_barriers from iree.turbine.kernel.wave.hoisting import hoist_allocs from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel.wave.type_inference import infer_types from iree.turbine.kernel.lang.global_symbols import * from iree.turbine.kernel._support.tracing import CapturedTrace from iree.turbine.kernel._support.indexing import IndexingContext @@ -86,6 +87,7 @@ def test_read_write_equal_sizes(): graph: fx.Graph = trace.get_root_graph() read_node = get_read_nodes(graph)[0] IndexingContext.current().finalize() + infer_types(trace) promote_node(read_node, SHARED_ADDRESS_SPACE, constraints) set_node_indices(trace, constraints) expand_graph(trace, constraints) @@ -171,6 +173,7 @@ def test_gemm(): trace: CapturedTrace = gemm() graph: fx.Graph = trace.get_subgraph("region_0") IndexingContext.current().finalize() + infer_types(trace) read_nodes = get_read_nodes(graph) for read_node in read_nodes: promote_node(read_node, SHARED_ADDRESS_SPACE, constraints) diff --git 
a/lit_tests/kernel/wave/expansion.py b/lit_tests/kernel/wave/expansion.py index 4f00c54e..1c86de37 100644 --- a/lit_tests/kernel/wave/expansion.py +++ b/lit_tests/kernel/wave/expansion.py @@ -6,6 +6,7 @@ import iree.turbine.kernel.lang as tkl import iree.turbine.kernel.wave as tkw from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel.wave.type_inference import infer_types from iree.turbine.kernel.wave.index_sequence_analysis import ( set_node_indices, set_post_expansion_indices, @@ -69,6 +70,7 @@ def test_read_write_equal_sizes(): ): graph = read_write_same_size() IndexingContext.current().finalize() + infer_types(graph) set_node_indices(graph, constraints) expand_graph(graph, constraints) set_post_expansion_indices(graph, constraints) @@ -150,6 +152,7 @@ def test_read_write(): ): graph = read_write_different_dims() IndexingContext.current().finalize() + infer_types(graph) set_node_indices(graph, constraints) expand_graph(graph, constraints) set_post_expansion_indices(graph, constraints) @@ -227,6 +230,7 @@ def test_gemm(): ): graph = gemm() IndexingContext.current().finalize() + infer_types(graph) set_node_indices(graph, constraints) expand_graph(graph, constraints) set_post_expansion_indices(graph, constraints) @@ -413,6 +417,7 @@ def test_batched_gemm(): ): graph = batched_gemm() IndexingContext.current().finalize() + infer_types(graph) set_node_indices(graph, constraints) expand_graph(graph, constraints) set_post_expansion_indices(graph, constraints) @@ -591,6 +596,7 @@ def test_gemm_non_direct_acc(): ): graph = gemm_non_direct_acc() IndexingContext.current().finalize() + infer_types(graph) set_node_indices(graph, constraints) expand_graph(graph, constraints) set_post_expansion_indices(graph, constraints) @@ -657,6 +663,7 @@ def test_tiled_max(): ): graph = tiled_max() IndexingContext.current().finalize() + infer_types(graph) set_node_indices(graph, constraints) expand_graph(graph, constraints) set_post_expansion_indices(graph, 
constraints) @@ -688,6 +695,7 @@ def test_gemm_reduction_expansion_only(): ): graph = gemm() IndexingContext.current().finalize() + infer_types(graph) set_node_indices(graph, constraints) expand_graph(graph, constraints) set_post_expansion_indices(graph, constraints) @@ -791,6 +799,7 @@ def py_arithmetic_different_dims(): ): graph = py_arithmetic_different_dims() IndexingContext.current().finalize() + infer_types(graph) set_node_indices(graph, constraints) expand_graph(graph, constraints) set_post_expansion_indices(graph, constraints) @@ -896,6 +905,7 @@ def test_chained_gemm_32x32x8(): ): graph = chained_gemm_32x32x8() IndexingContext.current().finalize() + infer_types(graph) set_node_indices(graph, constraints) expand_graph(graph, constraints) set_post_expansion_indices(graph, constraints) diff --git a/lit_tests/kernel/wave/index_sequence_analysis.py b/lit_tests/kernel/wave/index_sequence_analysis.py index 36594f22..812edd6f 100644 --- a/lit_tests/kernel/wave/index_sequence_analysis.py +++ b/lit_tests/kernel/wave/index_sequence_analysis.py @@ -8,6 +8,7 @@ from iree.turbine.kernel.wave.promotion import promote_placeholders from iree.turbine.kernel.wave.hoisting import hoist_allocs from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel.wave.type_inference import infer_types from iree.turbine.kernel.lang.global_symbols import * from iree.turbine.kernel._support.tracing import CapturedTrace from iree.turbine.kernel._support.indexing import IndexingContext @@ -84,6 +85,7 @@ def test_gemm(): ): trace: CapturedTrace = gemm() IndexingContext.current().finalize() + infer_types(trace) promote_placeholders(trace, constraints) hoist_allocs(trace) set_node_indices(trace, constraints) diff --git a/lit_tests/kernel/wave/minimize_global_loads.py b/lit_tests/kernel/wave/minimize_global_loads.py index a6a61a17..3de0fade 100644 --- a/lit_tests/kernel/wave/minimize_global_loads.py +++ b/lit_tests/kernel/wave/minimize_global_loads.py @@ -10,11 +10,12 @@ 
from iree.turbine.kernel.wave.hoisting import hoist_allocs from iree.turbine.kernel.wave.barriers import add_shared_memory_barriers from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel.wave.type_inference import infer_types from iree.turbine.kernel.lang.global_symbols import * from iree.turbine.kernel._support.tracing import CapturedTrace from iree.turbine.kernel._support.indexing import IndexingContext from iree.turbine.kernel.ops.wave_ops import * -from iree.turbine.kernel.wave.utils import run_test, print_trace +from iree.turbine.kernel.wave.utils import run_test, print_trace, memory_to_register from iree.turbine.kernel.wave.minimize_global_loads import minimize_global_loads from iree.turbine.kernel.wave.visualization import visualize_graph from iree.turbine.kernel.wave.shared_memory_indexing import ( @@ -87,6 +88,7 @@ def test_gemm(): trace: CapturedTrace = gemm() visualize = False IndexingContext.current().finalize() + infer_types(trace) promote_placeholders(trace, constraints) hoist_allocs(trace) set_node_indices(trace, constraints) diff --git a/lit_tests/kernel/wave/promotion.py b/lit_tests/kernel/wave/promotion.py index c3836f4f..f1f348a7 100644 --- a/lit_tests/kernel/wave/promotion.py +++ b/lit_tests/kernel/wave/promotion.py @@ -7,6 +7,7 @@ import iree.turbine.kernel.wave as tkw from iree.turbine.kernel.wave.promotion import promote_node, promote_placeholders from iree.turbine.kernel.wave.hoisting import hoist_allocs +from iree.turbine.kernel.wave.type_inference import infer_types from iree.turbine.kernel.lang.global_symbols import * from iree.turbine.kernel._support.tracing import CapturedTrace from iree.turbine.kernel._support.indexing import IndexingContext @@ -67,6 +68,7 @@ def test_read_write_equal_sizes(): graph: fx.Graph = trace.get_root_graph() read_node = get_read_nodes(graph)[0] IndexingContext.current().finalize() + infer_types(trace) promote_node(read_node, SHARED_ADDRESS_SPACE, constraints) print_trace(trace, 
False) # CHECK: %a @@ -116,6 +118,7 @@ def test_read_write_equal_sizes_different_address_spaces(): ): trace: CapturedTrace = read_write_same_size_different_address_spaces() IndexingContext.current().finalize() + infer_types(trace) promote_placeholders(trace, constraints) print_trace(trace, False) # CHECK: %a @@ -170,10 +173,11 @@ def test_gemm(): trace: CapturedTrace = gemm() graph: fx.Graph = trace.get_subgraph("region_0") read_nodes = get_read_nodes(graph) + IndexingContext.current().finalize() + infer_types(trace) for read_node in read_nodes: promote_node(read_node, SHARED_ADDRESS_SPACE, constraints) hoist_allocs(trace) - IndexingContext.current().finalize() print_trace(trace, False) # Root graph: # CHECK: %a diff --git a/lit_tests/kernel/wave/scheduling.py b/lit_tests/kernel/wave/scheduling.py index afa6065b..2f7780bc 100644 --- a/lit_tests/kernel/wave/scheduling.py +++ b/lit_tests/kernel/wave/scheduling.py @@ -8,6 +8,7 @@ from iree.turbine.kernel.wave.promotion import promote_placeholders from iree.turbine.kernel.wave.hoisting import hoist_allocs from iree.turbine.kernel.wave.expansion import expand_graph +from iree.turbine.kernel.wave.type_inference import infer_types from iree.turbine.kernel.lang.global_symbols import * from iree.turbine.kernel._support.tracing import CapturedTrace from iree.turbine.kernel._support.indexing import IndexingContext @@ -92,6 +93,7 @@ def test_gemm_pipelined(): ): trace: CapturedTrace = gemm_pipelined() IndexingContext.current().finalize() + infer_types(trace) promote_placeholders(trace, constraints) hoist_allocs(trace) set_node_indices(trace, constraints) diff --git a/tests/kernel/wave/type_inference_test.py b/tests/kernel/wave/type_inference_test.py new file mode 100644 index 00000000..6ce7efa2 --- /dev/null +++ b/tests/kernel/wave/type_inference_test.py @@ -0,0 +1,199 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. 
+# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import unittest +import logging +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel._support.tracing import CapturedTrace +from iree.turbine.kernel._support.indexing import IndexingContext +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel.wave.utils import ( + get_mfma_load_elems_per_thread, + get_mfma_store_elems_per_thread, +) +from iree.turbine.kernel.wave.constraints import MMAType +from iree.turbine.kernel.wave.type_inference import infer_types +from iree.turbine.kernel.ops.wave_ops import get_custom + + +class TypeInferenceTest(unittest.TestCase): + def testAttentionInference(self): + shape = (8, 128, 128, 64, 256) + # Input sizes + B = tkl.sym.B + M = tkl.sym.M + N = tkl.sym.N + K1 = tkl.sym.K1 + K2 = tkl.sym.K2 + # Workgroup tile sizes + BLOCK_B = tkl.sym.BLOCK_B + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_K2 = tkl.sym.BLOCK_K2 + # Address space (for GPU, shared(1) or global(0)) + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + # Other hyperparameters + LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD + STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD + + # Expose user-constraints + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] + constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] + + mfma_variant = MMAType.F32_16x16x16_F16 + if mfma_variant == MMAType.F32_16x16x16_F16: + Mvec = 16 + Nvec = 16 + if mfma_variant == MMAType.F32_32x32x8_F16: + Mvec = 32 + Nvec = 32 + + constraints += [ + 
tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(2, 2, 1), + mma_type=mfma_variant, + vector_shapes={B: 0, M: Mvec, N: Nvec}, + ) + ] + + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + k = tkw.IndexMapping.iterator(2) + mapping = tkw.IndexMapping( + num_iterators=3, inputs={B: i, N: j, M: k}, outputs={B: i, M: k, N: j} + ) + + @tkw.wave_trace_only(constraints) + def base_attention( + q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16], + k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16], + v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], + ): + c_reg = tkl.Register[B, N, M, tkl.f32](0.0) + init_sum = tkl.Register[B, M, tkl.f32](0.0) + init_max = tkl.Register[B, M, tkl.f32](-1e6) + + # This microkernel encodes the fact that if the reduction + # dimension were tiled, then we would need to materialize a loop. + @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg]) + def repeat( + partial_max: tkl.Register[B, M, tkl.f32], + partial_sum: tkl.Register[B, M, tkl.f32], + acc: tkl.Register[B, N, M, tkl.f32], + ) -> ( + tkl.Register[B, M, tkl.f32], + tkl.Register[B, M, tkl.f32], + tkl.Register[B, N, M, tkl.f32], + ): + imm_reg = tkl.Register[B, K2, M, tkl.f32](0.0) + q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD) + # b_reg: tkw.Register[B, N, K, tkl.f16] + k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD) + # acc: tkw.Register[B, N, M, tkl.f32] + inner_acc = tkw.mma(k_reg, q_reg, imm_reg) + x_j = tkw.permute(inner_acc, target_shape=[B, M, K2]) + m_j = tkw.max(x_j, partial_max, dim=K2) + e_delta_max = tkw.exp2(partial_max - m_j) + e_delta = tkw.exp2(x_j - m_j) + e_init = partial_sum * e_delta_max + d_j = tkw.sum(e_delta, e_init, dim=K2) + imm_f16 = tkw.cast(e_delta, tkl.f16) + v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD) + new_acc = acc * e_delta_max + acc = tkw.mma(v_reg, imm_f16, new_acc) + return m_j, d_j, acc + + # repeat 
represents the results of the loop + res_max, res_sum, res_mm = repeat + res = res_mm / res_sum + tkw.write( + res, c, mapping=mapping, elements_per_thread=STORE_ELEMS_PER_THREAD + ) + + hyperparams = { + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant), + STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant), + BLOCK_B: 1, + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_K2: 32, + B: shape[0], + M: shape[1], + N: shape[2], + K1: shape[3], + K2: shape[4], + READ_SHARED_DELAY: 1, + WRITE_SHARED_DELAY: 1, + READ_GLOBAL_DELAY: 2, + WRITE_GLOBAL_DELAY: 2, + MMA_DELAY: 1, + SHARED_MEMORY_UNITS: 4, + GLOBAL_MEMORY_UNITS: 4, + MMA_UNITS: 4, + } + with tk.gen.TestLaunchContext( + hyperparams, + canonicalize=True, + run=False, + run_bench=False, + schedule=False, + use_scheduling_barriers=False, + ): + trace: CapturedTrace = base_attention() + IndexingContext.current().finalize() + infer_types(trace) + expected_type = { + "partial_sum": "Register[B, M].of(f32)", + "partial_max": "Register[B, M].of(f32)", + "acc": "Register[B, N, M].of(f32)", + "q": "Memory[B, M, K1].of(f16)", + "read": "Register[B, M, K1].of(f16)", + "k": "Memory[B, K2, K1].of(f16)", + "read_1": "Register[B, K2, K1].of(f16)", + "mma": "Register[B, K2, M].of(f32)", + "permute": "Register[B, M, K2].of(f32)", + "max_1": "Register[B, M].of(f32)", + "sub": "Register[B, M].of(f32)", + "exp2": "Register[B, M].of(f32)", + "sub_1": "Register[B, M, K2].of(f32)", + "exp2_1": "Register[B, M, K2].of(f32)", + "mul": "Register[B, M].of(f32)", + "sum_1": "Register[B, M].of(f32)", + "cast": "Register[B, M, K2].of(f16)", + "v": "Memory[B, N, K2].of(f16)", + "read_2": "Register[B, N, K2].of(f16)", + "mul_1": "Register[B, N, M].of(f32)", + "mma_1": "Register[B, N, M].of(f32)", + "c": "Memory[B, M, N].of(f32)", + "register_1": "Register[B, M].of(f32)", + "register_2": "Register[B, M].of(f32)", + "reduction": "[Register[B, M].of(f32), Register[B, M].of(f32), 
Register[B, N, M].of(f32)]", + "getitem": "Register[B, M].of(f32)", + "getitem_1": "Register[B, M].of(f32)", + "getitem_2": "Register[B, N, M].of(f32)", + "truediv": "Register[B, N, M].of(f32)", + "write": "Memory[B, N, M].of(f32)", + } + for subgraph in trace.region_graph.subgraphs.values(): + for node in subgraph.nodes: + custom = get_custom(node) + if custom.fx_node.name in expected_type: + assert str(custom.type) == expected_type[custom.fx_node.name] + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main()