diff --git a/iree-requirements-ci.txt b/iree-requirements-ci.txt
index 7f165f1f..1df055dc 100644
--- a/iree-requirements-ci.txt
+++ b/iree-requirements-ci.txt
@@ -3,5 +3,5 @@
 # more forgiving on the exact version.
 
 --find-links https://iree.dev/pip-release-links.html
-iree-compiler==20240808.979
-iree-runtime==20240808.979
+iree-compiler==20240913.1015
+iree-runtime==20240913.1015
diff --git a/shark_turbine/kernel/wave/utils.py b/shark_turbine/kernel/wave/utils.py
index a1883691..f6be531b 100644
--- a/shark_turbine/kernel/wave/utils.py
+++ b/shark_turbine/kernel/wave/utils.py
@@ -250,7 +250,7 @@ def compile_and_invoke(
     # TODO: More targets/backends support.
     if backend == "rocm":
         target = config["target"]
-        flags.append(f"--iree-rocm-target-chip={target}")
+        flags.append(f"--iree-hip-target={target}")
 
     if config.get("print_ir_after_all", False):
         flags.append("--mlir-print-ir-after-all")
diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py
index 80db8cc7..270d5e19 100644
--- a/tests/kernel/wave/wave_e2e_test.py
+++ b/tests/kernel/wave/wave_e2e_test.py
@@ -1,6 +1,8 @@
 import shark_turbine.kernel as tk
 import shark_turbine.kernel.lang as tkl
 import shark_turbine.kernel.wave as tkw
+from shark_turbine.kernel.wave.wave_sim import wave_sim
+from shark_turbine.kernel.lang.global_symbols import *
 import torch
 from numpy.testing import assert_allclose, assert_equal
 import pytest
@@ -9,7 +11,7 @@
 import torch
 import json
 
-_run_e2e = int(os.environ.get("WAVE_RUN_E2E_TESTS", 0))
+_run_e2e = int(os.environ.get("WAVE_RUN_E2E_TESTS", 1))
 require_e2e = pytest.mark.skipif(not _run_e2e, reason="e2e tests are disabled")
 default_test_shapes = [(1, 128), (256, 64), (256, 128), (256, 256), (256, 1024)]
 
@@ -408,3 +410,263 @@ def test(
     ):
         test(a, b)
     assert_allclose(b, expected)
+
+
+@require_e2e
+def test_im2col_mma():
+    # igemm without final col2im
+    n, c, h, w = 1, 4, 9, 9  # Image.
+    nf, cf, hf, wf = 64, c, 2, 2  # Filters.
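+    # im2col turns the conv into a GEMM: row i indexes an output pixel
+    # (M = SZ_OUT * N below), column j indexes a filter tap (K = HF * WF * C).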
+    padding = 0  # TODO: only pad=0 is supported for now
+    stride = 1
+
+    x = torch.randn(n, c, h, w, dtype=torch.float16)
+    we = torch.randn(nf, cf, hf, wf, dtype=torch.float16)
+
+    sym = tkl.sym
+    N, C, H, W = sym.N, sym.C, sym.H, sym.W
+    NF, HF, WF = sym.NF, sym.HF, sym.WF
+
+    H_OUT = (H + 2 * padding - HF) // stride + 1
+    W_OUT = (W + 2 * padding - WF) // stride + 1
+    SZ_OUT = H_OUT * W_OUT
+
+    K = HF * WF * C
+    M = SZ_OUT * N
+
+    i = tkw.IndexMapping.iterator(0)
+    j = tkw.IndexMapping.iterator(1)
+
+    x_mapping = tkw.IndexMapping(
+        num_iterators=2,
+        inputs={
+            N: i // SZ_OUT,
+            C: j // (HF * WF),
+            H: (i % SZ_OUT) % W_OUT * stride + (j % (HF * WF)) % WF,
+            W: (i % SZ_OUT) // W_OUT * stride + (j % (HF * WF)) // WF,
+        },
+        outputs={M: i, K: j},
+    )
+    w_mapping = tkw.IndexMapping(
+        num_iterators=2,
+        inputs={NF: i % NF, C: j // (HF * WF), HF: j % WF, WF: (j % (HF * WF)) // WF},
+        outputs={NF: i, K: j},
+    )
+
+    # Workgroup tile sizes
+    BLOCK_M = tkl.sym.BLOCK_M
+    BLOCK_N = tkl.sym.BLOCK_N
+    # BLOCK_K = tkl.sym.BLOCK_K
+    BLOCK_K = K
+    # Address space (for GPU, shared(1) or global(0))
+    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
+    # Other hyperparameters
+    ELEMS_PER_THREAD = tkl.sym.ELEMS_PER_THREAD
+
+    # Expose user-constraints
+    constraints: list[tkw.Constraint] = []
+    constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
+    constraints += [tkw.WorkgroupConstraint(NF, BLOCK_N, 1)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M)]
+    constraints += [tkw.WaveConstraint(NF, BLOCK_N)]
+    constraints += [tkw.TilingConstraint(K, BLOCK_K)]
+
+    constraints += [
+        tkw.HardwareConstraint(
+            threads_per_wave=64,
+            waves_per_block=(1, 1, 1),
+            # vector_shapes={NF: 1, M: BLOCK_M, K: ELEMS_PER_THREAD},
+        )
+    ]
+
+    def func(
+        x: tkl.Memory[N, C, H, W, ADDRESS_SPACE, tkl.f16],
+        we: tkl.Memory[NF, C, HF, WF, ADDRESS_SPACE, tkl.f16],
+        out: tkl.Memory[M, NF, GLOBAL_ADDRESS_SPACE, tkl.f32],
+    ):
+        c_reg = tkl.Register[M, NF, tkl.f32](0.0)
+
+        @tkw.reduction(K, init_args=[c_reg])
+        def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]:
+            a_reg = tkw.read(
+                x,
+                mapping=x_mapping,
+                elements_per_thread=ELEMS_PER_THREAD,
+            )
+            b_reg = tkw.read(
+                we,
+                mapping=w_mapping,
+                elements_per_thread=ELEMS_PER_THREAD,
+            )
+            acc = tkw.mma(a_reg, b_reg, acc)
+            return acc
+
+        tkw.write(repeat, out, elements_per_thread=ELEMS_PER_THREAD)
+
+    sim_func = wave_sim(constraints)(func)
+    gpu_func = tkw.wave(constraints)(func)
+
+    h_out = (h + 2 * padding - hf) // stride + 1
+    w_out = (w + 2 * padding - wf) // stride + 1
+    res_shape = (h_out * w_out * n, nf)
+    out_ref = torch.zeros(res_shape, dtype=torch.float32)
+    sim_func(x, we, out_ref)
+
+    out = torch.zeros_like(out_ref)
+
+    config = {"backend": "rocm", "device": "hip", "target": "gfx942"}
+
+    with tk.gen.TestLaunchContext(
+        {
+            N: n,
+            C: c,
+            W: w,
+            H: h,
+            NF: nf,
+            WF: wf,
+            HF: hf,
+            BLOCK_M: 64,
+            BLOCK_N: 64,
+            ELEMS_PER_THREAD: 4,
+            ADDRESS_SPACE: GLOBAL_ADDRESS_SPACE,
+        },
+        canonicalize=True,
+        run=True,
+        run_config=config,
+    ):
+        gpu_func(x, we, out)
+        assert_allclose(out, out_ref, rtol=1e-05, atol=1e-05)
+
+
+@require_e2e
+def test_igemm_conv():
+    n, c, h, w = 1, 4, 5, 5  # Image.
+    nf, cf, hf, wf = 16, c, 2, 2  # Filters.
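+    # Same im2col GEMM as above, plus an out_mapping that scatters the
+    # (M, NF) result back into the (N, NF, H_OUT, W_OUT) conv output layout.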
+    padding = 0  # TODO: only pad=0 is supported for now
+    stride = 1
+
+    torch.manual_seed(1)
+    x = torch.randn(n, c, h, w, dtype=torch.float16)
+    we = torch.randn(nf, cf, hf, wf, dtype=torch.float16)
+
+    convRef = torch.nn.Conv2d(c, nf, hf, stride=stride, padding=padding, bias=False)
+    convRef.weight = torch.nn.Parameter(we)
+    out_ref = convRef(x).detach()
+
+    sym = tkl.sym
+    N, C, H, W = sym.N, sym.C, sym.H, sym.W
+    NF, HF, WF = sym.NF, sym.HF, sym.WF
+
+    H_OUT = (H + 2 * padding - HF) // stride + 1
+    W_OUT = (W + 2 * padding - WF) // stride + 1
+    SZ_OUT = H_OUT * W_OUT
+
+    K = HF * WF * C
+    M = SZ_OUT * N
+
+    i = tkw.IndexMapping.iterator(0)
+    j = tkw.IndexMapping.iterator(1)
+
+    x_mapping = tkw.IndexMapping(
+        num_iterators=2,
+        inputs={
+            N: i // SZ_OUT,
+            C: j // (HF * WF),
+            H: (i % SZ_OUT) % W_OUT * stride + (j % (HF * WF)) % WF,
+            W: (i % SZ_OUT) // W_OUT * stride + (j % (HF * WF)) // WF,
+        },
+        outputs={M: i, K: j},
+    )
+    w_mapping = tkw.IndexMapping(
+        num_iterators=2,
+        inputs={NF: i % NF, C: j // (HF * WF), HF: j % WF, WF: (j % (HF * WF)) // WF},
+        outputs={NF: i, K: j},
+    )
+    out_mapping = tkw.IndexMapping(
+        num_iterators=2,
+        inputs={M: i, NF: j},
+        outputs={
+            N: i // SZ_OUT,
+            NF: j,
+            H_OUT: (i % SZ_OUT) % W_OUT,
+            W_OUT: (i % SZ_OUT) // W_OUT,
+        },
+    )
+
+    # Workgroup tile sizes
+    BLOCK_M = tkl.sym.BLOCK_M
+    BLOCK_N = tkl.sym.BLOCK_N
+    BLOCK_K = K
+    # Address space (for GPU, shared(1) or global(0))
+    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
+    # Other hyperparameters
+    ELEMS_PER_THREAD = tkl.sym.ELEMS_PER_THREAD
+
+    # Expose user-constraints
+    constraints: list[tkw.Constraint] = []
+    constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
+    constraints += [tkw.WorkgroupConstraint(NF, BLOCK_N, 1)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M)]
+    constraints += [tkw.WaveConstraint(NF, BLOCK_N)]
+    constraints += [tkw.TilingConstraint(K, BLOCK_K)]
+
+    constraints += [
+        tkw.HardwareConstraint(
+            threads_per_wave=64,
+            waves_per_block=(1, 1, 1),
+        )
+    ]
+
+    @tkw.wave(constraints)
+    def conv(
+        x: tkl.Memory[N, C, H, W, ADDRESS_SPACE, tkl.f16],
+        we: tkl.Memory[NF, C, HF, WF, ADDRESS_SPACE, tkl.f16],
+        out: tkl.Memory[N, NF, H_OUT, W_OUT, GLOBAL_ADDRESS_SPACE, tkl.f32],
+    ):
+        c_reg = tkl.Register[M, NF, tkl.f32](0.0)
+
+        @tkw.reduction(K, init_args=[c_reg])
+        def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]:
+            a_reg = tkw.read(
+                x,
+                mapping=x_mapping,
+                elements_per_thread=ELEMS_PER_THREAD,
+            )
+            b_reg = tkw.read(
+                we,
+                mapping=w_mapping,
+                elements_per_thread=ELEMS_PER_THREAD,
+            )
+            acc = tkw.mma(a_reg, b_reg, acc)
+            return acc
+
+        tkw.write(
+            repeat, out, mapping=out_mapping, elements_per_thread=ELEMS_PER_THREAD
+        )
+
+    out = torch.zeros_like(out_ref)
+
+    config = {"backend": "rocm", "device": "hip", "target": "gfx942"}
+
+    with tk.gen.TestLaunchContext(
+        {
+            N: n,
+            C: c,
+            W: w,
+            H: h,
+            NF: nf,
+            WF: wf,
+            HF: hf,
+            BLOCK_M: 16,
+            BLOCK_N: 16,
+            ELEMS_PER_THREAD: 4,
+            ADDRESS_SPACE: GLOBAL_ADDRESS_SPACE,
+        },
+        canonicalize=True,
+        run=True,
+        run_config=config,
+    ):
+        conv(x, we, out)
+        assert_allclose(out, out_ref, rtol=1e-05, atol=1e-05)
\ No newline at end of file
diff --git a/tests/kernel/wave/wave_gemm_test.py b/tests/kernel/wave/wave_gemm_test.py
index 2c66130c..d2d3a5dd 100644
--- a/tests/kernel/wave/wave_gemm_test.py
+++ b/tests/kernel/wave/wave_gemm_test.py
@@ -10,7 +10,7 @@
 import os
 import json
 
-_run_e2e = int(os.environ.get("WAVE_RUN_E2E_TESTS", 0))
+_run_e2e = int(os.environ.get("WAVE_RUN_E2E_TESTS", 1))
 require_e2e = pytest.mark.skipif(not _run_e2e, reason="e2e tests are disabled")
 default_test_shapes = [(1024, 5120, 640), (2048, 10240, 1280), (4096, 20480, 2560)]
 