diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py index 05963a4b..1dad0160 100644 --- a/iree/turbine/kernel/ops/wave_ops.py +++ b/iree/turbine/kernel/ops/wave_ops.py @@ -1400,18 +1400,3 @@ def type(self) -> Register: self.target_shape ), f"Target shape {self.target_shape} must be a permutation of source shape {src_type.symbolic_shape}" return Register[*self.target_shape, src_type.dtype] - - @property - def index(self) -> Optional[dict[IndexSymbol, IndexSequence]]: - """ - Computes the permuted index based on the target shape. - """ - src_type = get_custom(self.arg).type - dim_map = { - tgt: src for src, tgt in zip(src_type.symbolic_shape, self.target_shape) - } - return {tgt: get_custom(self.arg).index[src] for tgt, src in dim_map.items()} - - @index.setter - def index(self, value: Any): - CustomOp.index.fset(self, value) diff --git a/tests/kernel/wave/wave_attention_test.py b/tests/kernel/wave/wave_attention_test.py index 124e0948..1ecfa7b2 100644 --- a/tests/kernel/wave/wave_attention_test.py +++ b/tests/kernel/wave/wave_attention_test.py @@ -7,6 +7,7 @@ import logging import pytest import torch +import math import unittest import iree.turbine.kernel as tk import iree.turbine.kernel.lang as tkl @@ -20,7 +21,7 @@ from iree.turbine.kernel.wave.constraints import MMAType import os import json -from torch.testing import assert_close +from torch.testing import assert_close, assert_allclose _run_e2e = int(os.environ.get("WAVE_RUN_E2E_TESTS", 0)) require_e2e = pytest.mark.skipif(not _run_e2e, reason="e2e tests are disabled") @@ -194,3 +195,167 @@ def repeat( "chain_mmt", [q, k, v], [iree_ref], config, run_bench=run_bench ) assert_close(output, iree_ref) + + +@require_e2e +@pytest.mark.parametrize("shape", get_test_shapes("test_attention")) +@pytest.mark.parametrize("enable_scheduling", [False]) +@pytest.mark.parametrize( + "mfma_variant", + [ + MMAType.F32_16x16x16_F16, + ], +) +def testAttention( + shape: tuple[int], enable_scheduling: bool, mfma_variant: MMAType, request +): + run_bench = request.config.getoption("--runperf") + dump_perf = request.config.getoption("--dump-perf-files-path") + # Input sizes + B = tkl.sym.B + M = tkl.sym.M + N = tkl.sym.N + K1 = tkl.sym.K1 + K2 = tkl.sym.K2 + # Workgroup tile sizes + BLOCK_B = tkl.sym.BLOCK_B + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_K2 = tkl.sym.BLOCK_K2 + # Address space (for GPU, shared(1) or global(0)) + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + # Other hyperparameters + LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD + STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD + + # Expose user-constraints + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] + constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] + + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(2, 2, 1), + mma_type=mfma_variant, + vector_shapes={B: 0, M: 16, N: 16}, + ) + ] + + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + k = tkw.IndexMapping.iterator(2) + mapping = tkw.IndexMapping( + num_iterators=3, inputs={B: i, M: j, N: k}, outputs={B: i, N: k, M: j} + ) + + @tkw.wave(constraints) + def base_attention( + q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16], + k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16], + v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], + ): + c_reg = tkl.Register[B, N, M, tkl.f32](0.0) + init_sum = tkl.Register[B, M, tkl.f32](0.0) + init_max = tkl.Register[B, M, tkl.f32](-1e6) + # This microkernel encodes the fact that if the reduction + # dimension were tiled, then we would need to materialize a loop. + @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg]) + def repeat( + partial_max: tkl.Register[B, M, tkl.f32], + partial_sum: tkl.Register[B, M, tkl.f32], + acc: tkl.Register[B, N, M, tkl.f32], + ) -> ( + tkl.Register[B, M, tkl.f32], + tkl.Register[B, M, tkl.f32], + tkl.Register[B, N, M, tkl.f32], + ): + imm_reg = tkl.Register[B, K2, M, tkl.f32](0.0) + q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD) + # b_reg: tkw.Register[B, N, K, tkl.f16] + k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD) + # acc: tkw.Register[B, N, M, tkl.f32] + inner_acc = tkw.mma(k_reg, q_reg, imm_reg) + x_j = tkw.permute(inner_acc, target_shape=[B, M, K2]) + m_j = tkw.max(x_j, partial_max, dim=K2) + e_delta_max = tkw.exp2(partial_max - m_j) + e_delta = tkw.exp2(x_j - m_j) + e_init = partial_sum * e_delta_max + d_j = tkw.sum(e_delta, e_init, dim=K2) + imm_f16 = tkw.cast(e_delta, tkl.f16) + v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD) + new_acc = acc * e_delta_max + acc = tkw.mma(v_reg, imm_f16, new_acc) + return m_j, d_j, acc + + # repeat represents the results of the loop + res_max, res_sum, res_mm = repeat + res = res_mm / res_sum + tkw.write(res, c, elements_per_thread=STORE_ELEMS_PER_THREAD) + + hyperparams = { + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant), + STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant), + BLOCK_B: 1, + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_K2: 32, + B: shape[0], + M: shape[1], + N: shape[2], + K1: shape[3], + K2: shape[4], + READ_SHARED_DELAY: 1, + WRITE_SHARED_DELAY: 1, + READ_GLOBAL_DELAY: 2, + WRITE_GLOBAL_DELAY: 2, + MMA_DELAY: 1, + SHARED_MEMORY_UNITS: 4, + GLOBAL_MEMORY_UNITS: 4, + MMA_UNITS: 4, + } + config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + if run_bench: + config["benchmark_batch_size"] = 10 + config["benchmark_repetitions"] = 3 + if dump_perf is not None: + perf_filename = request.node.name + ".json" + config["benchmark_results_file"] = os.path.join( + dump_perf, "tk_" + perf_filename + ) + + with tk.gen.TestLaunchContext( + hyperparams, + canonicalize=True, + run=True, + run_bench=run_bench, + run_config=config, + schedule=enable_scheduling, + use_scheduling_barriers=enable_scheduling_barriers, + ): + torch.manual_seed(0) + q = torch.randn(shape[0], shape[1], shape[3], dtype=torch.float16) + k = torch.randn(shape[0], shape[4], shape[3], dtype=torch.float16) + v = torch.randn(shape[0], shape[4], shape[2], dtype=torch.float16) + output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32) + log2e = 1.44269504089 + dk_sqrt = math.sqrt(1.0 / shape[3]) + # TODO: Add scaling of QK as part of kernel. + # TODO: Add variant of non-transposed V attention kernel. + mb = base_attention(q * dk_sqrt * log2e, k, v.permute([0, 2, 1]), output) + torch_ref = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=None + ) + + if test_dump_generated_mlir: + filename = f"wave_attention_{'x'.join(map(str, shape))}.mlir" + with open(filename, "w") as f: + f.write(mb.module_op.get_asm()) + + # TODO: Fix transposed writes to output. + assert_allclose(output.permute([0, 2, 1]), torch_ref)