Skip to content

Commit

Permalink
Update run configurations for gemm test
Browse files Browse the repository at this point in the history
  • Loading branch information
harsh-nod committed Oct 8, 2024
1 parent f207ca5 commit 10cd2de
Showing 1 changed file with 50 additions and 28 deletions.
78 changes: 50 additions & 28 deletions tests/kernel/wave/wave_gemm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,18 @@
# Whether to use scheduling group barriers (needs LLVM fix).
enable_scheduling_barriers = int(os.environ.get("WAVE_USE_SCHED_BARRIERS", 0))

default_test_shapes = [(1024, 5120, 640), (2048, 10240, 1280), (4096, 20480, 2560)]

default_test_shapes = [
(2048, 10240, 1280, 128, 320, 32, 2, 2, 2, 2, 2, 2, 1, 1, 2),
(2048, 1280, 1280, 64, 64, 64, 2, 2, 1, 2, 1, 1, 1, 1, 2),
(2048, 1280, 5120, 128, 80, 128, 4, 1, 1, 4, 2, 2, 1, 1, 2),
(128, 1280, 2048, 64, 64, 128, 2, 2, 1, 8, 2, 2, 1, 1, 2),
(8192, 5120, 640, 128, 128, 32, 2, 2, 1, 4, 2, 2, 1, 1, 2),
]

perf_test = lambda *a: pytest.param(*a, marks=pytest.mark.perf_only)

default_test_shapes += [
perf_test((1024, 5120, 640)),
perf_test((2048, 10240, 1280)),
perf_test((4096, 20480, 2560)),
]
default_test_shapes += [perf_test(x) for x in default_test_shapes]

user_specified_test_shapes = ""

Expand All @@ -52,7 +55,23 @@ def get_test_shapes(test_name: str) -> list[tuple[int]]:
@require_e2e
@pytest.mark.parametrize("shape", get_test_shapes("test_gemm"))
@pytest.mark.parametrize("enable_scheduling", [False, True])
def testGemm(shape: tuple[int], enable_scheduling: bool, request):
def testGemm(params: tuple[int], enable_scheduling: bool, request):
(
m,
n,
k,
block_m,
block_n,
block_k,
ratio_m,
ratio_n,
mma_units,
shared_units,
global_units,
delay_mma,
delay_shared,
delay_global,
) = params
run_bench = request.config.getoption("--runperf")
dump_perf = request.config.getoption("--dump-perf-files-path")
# Input sizes
Expand All @@ -73,11 +92,13 @@ def testGemm(shape: tuple[int], enable_scheduling: bool, request):
constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
constraints += [tkw.TilingConstraint(K, BLOCK_K)]
constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)]
constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)]
constraints += [tkw.WaveConstraint(M, BLOCK_M / ratio_m)]
constraints += [tkw.WaveConstraint(N, BLOCK_N / ratio_n)]

constraints += [
tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(2, 2, 1))
tkw.HardwareConstraint(
threads_per_wave=64, waves_per_block=(ratio_m, ratio_n, 1)
)
]

# Wave-level micro-kernel.
Expand Down Expand Up @@ -113,20 +134,20 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
LOAD_ELEMS_PER_THREAD: 4,
STORE_ELEMS_PER_THREAD: 4,
BLOCK_M: 64,
BLOCK_N: 64,
BLOCK_K: 32,
M: shape[0],
N: shape[1],
K: shape[2],
READ_SHARED_DELAY: 1,
WRITE_SHARED_DELAY: 1,
READ_GLOBAL_DELAY: 2,
WRITE_GLOBAL_DELAY: 2,
MMA_DELAY: 1,
SHARED_MEMORY_UNITS: 4,
GLOBAL_MEMORY_UNITS: 4,
MMA_UNITS: 4,
BLOCK_M: block_m,
BLOCK_N: block_n,
BLOCK_K: block_k,
M: m,
N: n,
K: k,
READ_SHARED_DELAY: delay_shared,
WRITE_SHARED_DELAY: delay_shared,
READ_GLOBAL_DELAY: delay_global,
WRITE_GLOBAL_DELAY: delay_global,
MMA_DELAY: delay_mma,
SHARED_MEMORY_UNITS: shared_units,
GLOBAL_MEMORY_UNITS: global_units,
MMA_UNITS: mma_units,
}
config = {"backend": "rocm", "device": "hip", "target": "gfx942"}
if run_bench:
Expand All @@ -147,12 +168,13 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
schedule=enable_scheduling,
use_scheduling_barriers=enable_scheduling_barriers,
):
a = torch.randn(shape[0], shape[2], dtype=torch.float16)
b = torch.randn(shape[1], shape[2], dtype=torch.float16)
c = torch.zeros(shape[0], shape[1], dtype=torch.float32)
a = torch.randn(params.m, params.k, dtype=torch.float16)
b = torch.randn(params.n, params.k, dtype=torch.float16)
c = torch.zeros(params.m, params.n, dtype=torch.float32)
mb = gemm(a, b, c)

if test_dump_generated_mlir:
shape = [params.m, params.n, params.k]
filename = f"wave_gemm_{'x'.join(map(str, shape))}.mlir"
with open(filename, "w") as f:
f.write(mb.module_op.get_asm())
Expand All @@ -162,6 +184,6 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
config["benchmark_results_file"] = os.path.join(
dump_perf, "iree_" + perf_filename
)
iree_ref = torch.zeros(shape[0], shape[1], dtype=torch.float32)
iree_ref = torch.zeros(params.m, params.n, dtype=torch.float32)
generate_iree_ref("mmt", [a, b], [iree_ref], config, run_bench=run_bench)
assert_close(c, iree_ref)

0 comments on commit 10cd2de

Please sign in to comment.