From a2bb49da8401ac627964e5a3edf8ae699f30e156 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 12 Nov 2024 17:52:12 +0100 Subject: [PATCH] refac default device Signed-off-by: Ivan Butygin --- iree/turbine/kernel/wave/utils.py | 5 +++++ lit_tests/kernel/wave/codegen.py | 15 ++++++++------- tests/kernel/wave/wave_attention_test.py | 7 ++++--- tests/kernel/wave/wave_e2e_test.py | 21 +++++++++++---------- tests/kernel/wave/wave_gemm_test.py | 7 ++++--- 5 files changed, 32 insertions(+), 23 deletions(-) diff --git a/iree/turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py index bcaa772f..4be34495 100644 --- a/iree/turbine/kernel/wave/utils.py +++ b/iree/turbine/kernel/wave/utils.py @@ -96,6 +96,11 @@ def run_test(func: Callable[[], None]) -> Callable[[], None]: return func +def get_default_run_config() -> dict[Any, Any]: + """Return default config for testing.""" + return {"backend": "rocm", "device": "hip", "target": "gfx942"} + + def print_trace(trace: CapturedTrace, custom_print: bool = True): """ Prints all subgraphs of a trace starting with the root graph. diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index fc5e482c..1aaab53e 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -8,6 +8,7 @@ from iree.turbine.kernel.lang.global_symbols import * from iree.turbine.kernel.wave.utils import ( run_test, + get_default_run_config, get_mfma_load_elems_per_thread, get_mfma_store_elems_per_thread, ) @@ -1508,7 +1509,7 @@ def test( res = tkw.sum(res, dim=N) tkw.write(res, c, elements_per_thread=1) - config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + config = get_default_run_config() shape = (256, 128) a = torch.randn(shape, dtype=torch.float16) @@ -1584,7 +1585,7 @@ def test( res = tkw.sum([lhs, rhs], dim=N) tkw.write(res, c, elements_per_thread=1) - config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + config = get_default_run_config() shape = (256, 128) a = torch.randn(shape, dtype=torch.float16) @@ -1656,7 +1657,7 @@ def repeat( result = repeat + repeat tkw.write(result, c, elements_per_thread=1) - config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + config = get_default_run_config() shape = (256, 512) a = torch.randn(shape, dtype=torch.float16) @@ -1744,7 +1745,7 @@ def repeat( tkw.write(repeat, c, elements_per_thread=1) - config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + config = get_default_run_config() shape = (256, 512) a = torch.randn(shape, dtype=torch.float16) @@ -1843,7 +1844,7 @@ def repeat( tkw.write(res_max, c, elements_per_thread=1) tkw.write(res_sum, d, elements_per_thread=1) - config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + config = get_default_run_config() shape = (256, 512) a = torch.randn(shape, dtype=torch.float16) @@ -1942,7 +1943,7 @@ def repeat( res_max, res_sum = repeat tkw.write(res_sum, c, elements_per_thread=1) - config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + config = get_default_run_config() shape = (256, 1024) a = torch.randn(shape, dtype=torch.float32) @@ -2003,7 +2004,7 @@ def test( res = lhs + rhs tkw.write(res, c, elements_per_thread=STORE_ELEMS_PER_THREAD) - config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + config = get_default_run_config() shape = (256, 128) a = torch.ones(shape, dtype=torch.float16) diff --git a/tests/kernel/wave/wave_attention_test.py b/tests/kernel/wave/wave_attention_test.py index 792d9cff..9123bf4c 100644 --- a/tests/kernel/wave/wave_attention_test.py +++ b/tests/kernel/wave/wave_attention_test.py @@ -15,6 +15,7 @@ from iree.turbine.kernel.lang.global_symbols import * from iree.turbine.kernel.wave.iree_utils import generate_iree_ref from iree.turbine.kernel.wave.utils import ( + get_default_run_config, get_mfma_load_elems_per_thread, get_mfma_store_elems_per_thread, ) @@ -161,7 +162,7 @@ def repeat( GLOBAL_MEMORY_UNITS: 4, MMA_UNITS: 4, } - config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + config = get_default_run_config() if run_bench: config["benchmark_batch_size"] = 10 config["benchmark_repetitions"] = 3 @@ -310,7 +311,7 @@ def repeat( GLOBAL_MEMORY_UNITS: 4, MMA_UNITS: 4, } - config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + config = get_default_run_config() if run_bench: config["benchmark_batch_size"] = 10 config["benchmark_repetitions"] = 3 @@ -478,7 +479,7 @@ def repeat( GLOBAL_MEMORY_UNITS: 4, MMA_UNITS: 4, } - config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + config = get_default_run_config() if run_bench: config["benchmark_batch_size"] = 10 config["benchmark_repetitions"] = 3 diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py index 40c0a6f3..fea4117e 100644 --- a/tests/kernel/wave/wave_e2e_test.py +++ b/tests/kernel/wave/wave_e2e_test.py @@ -10,6 +10,7 @@ from iree.turbine.kernel.wave.wave_sim import wave_sim from iree.turbine.kernel.lang.global_symbols import * from iree.turbine.kernel.wave.iree_utils import generate_iree_ref +from iree.turbine.kernel.wave.utils import get_default_run_config import torch from numpy.testing import assert_allclose, assert_equal import pytest @@ -92,7 +93,7 @@ def test( res = tkw.read(a, elements_per_thread=ELEMS_PER_THREAD) tkw.write(res, b, elements_per_thread=ELEMS_PER_THREAD) - config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + config = get_default_run_config() a = torch.randn(shape, dtype=torch.float16) b = torch.zeros(shape, dtype=torch.float16) @@ -149,7 +150,7 @@ def test( res = tkw.read(a, elements_per_thread=ELEMS_PER_THREAD) tkw.write(res, b, elements_per_thread=ELEMS_PER_THREAD) - config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + config = get_default_run_config() a = torch.randn(shape, dtype=torch.float16) b = torch.zeros(shape, dtype=torch.float16) @@ -208,7 +209,7 @@ def test( res = tkw.read(a, mapping=mapping, elements_per_thread=ELEMS_PER_THREAD) tkw.write(res, b, elements_per_thread=ELEMS_PER_THREAD) - config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + config = get_default_run_config() a = torch.randn(shape, dtype=torch.float16) b = torch.zeros(shape[::-1], dtype=torch.float16) @@ -266,7 +267,7 @@ def test( res = tkw.read(a, elements_per_thread=ELEMS_PER_THREAD) tkw.write(res, b, mapping=mapping, elements_per_thread=ELEMS_PER_THREAD) - config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + config = get_default_run_config() a = torch.randn(shape, dtype=torch.float16) b = torch.zeros(shape[::-1], dtype=torch.float16) @@ -321,7 +322,7 @@ def test( res = tkw.sum(res, dim=N) tkw.write(res, c, elements_per_thread=1) - config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + config = get_default_run_config() torch.manual_seed(1) a = torch.randn(shape, dtype=torch.float16) @@ -393,7 +394,7 @@ def repeat( result = res_max / res_sum tkw.write(result, c, elements_per_thread=1) - config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + config = get_default_run_config() torch.manual_seed(1) a = torch.randn(shape, dtype=torch.float32) @@ -492,7 +493,7 @@ def test( res = tkw.read(a, mapping=mapping, elements_per_thread=ELEMS_PER_THREAD) tkw.write(res, b, elements_per_thread=ELEMS_PER_THREAD) - config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + config = get_default_run_config() h_out = (h + 2 * padding - hf) // stride + 1 w_out = (w + 2 * padding - wf) // stride + 1 @@ -631,7 +632,7 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]: out = torch.zeros_like(out_ref) - config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + config = get_default_run_config() with tk.gen.TestLaunchContext( { @@ -843,7 +844,7 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]: repeat, out, mapping=out_mapping, elements_per_thread=ELEMS_PER_THREAD ) - config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + config = get_default_run_config() run_bench = request.config.getoption("--runperf") dump_perf = request.config.getoption("--dump-perf-files-path") @@ -939,7 +940,7 @@ def test( res = tkw.cast(res, tkl.f16) tkw.write(res, b, elements_per_thread=ELEMS_PER_THREAD) - config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + config = get_default_run_config() a = torch.randn(shape, dtype=torch.float32) b = torch.zeros(shape, dtype=torch.float16) diff --git a/tests/kernel/wave/wave_gemm_test.py b/tests/kernel/wave/wave_gemm_test.py index 7c512b24..35d989ab 100644 --- a/tests/kernel/wave/wave_gemm_test.py +++ b/tests/kernel/wave/wave_gemm_test.py @@ -14,6 +14,7 @@ from iree.turbine.kernel.lang.global_symbols import * from iree.turbine.kernel.wave.iree_utils import generate_iree_ref from iree.turbine.kernel.wave.utils import ( + get_default_run_config, get_mfma_load_elems_per_thread, get_mfma_store_elems_per_thread, ) @@ -146,7 +147,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: GLOBAL_MEMORY_UNITS: 4, MMA_UNITS: 4, } - config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + config = get_default_run_config() if run_bench: config["benchmark_batch_size"] = 10 config["benchmark_repetitions"] = 3 @@ -265,7 +266,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: GLOBAL_MEMORY_UNITS: 4, MMA_UNITS: 4, } - config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + config = get_default_run_config() if run_bench: config["benchmark_batch_size"] = 10 config["benchmark_repetitions"] = 3 @@ -380,7 +381,7 @@ def repeat( GLOBAL_MEMORY_UNITS: 4, MMA_UNITS: 4, } - config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + config = get_default_run_config() if run_bench: config["benchmark_batch_size"] = 10 config["benchmark_repetitions"] = 3