Add support for dynamic parallel dims in GEMMs
This PR adds tests for dynamic M and N dims in
GEMMs. This works out of the box for the most part
and just requires moving the align index pass to
after scheduling.

Signed-off-by: Harsh Menon <[email protected]>
harsh-nod committed Nov 15, 2024
1 parent d8ce8d2 commit 322edaa
Showing 3 changed files with 128 additions and 6 deletions.
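Before the file diffs, a minimal sketch (not part of the commit) of what making a GEMM dim dynamic means on the test side. It mirrors the dynamic_symbols / dynamic_symbols_map arguments the new tests pass to tk.gen.TestLaunchContext; the import path and the tkl.sym symbol declarations are assumed to match the existing wave tests, and the concrete sizes are only example values.

# Sketch: moving M and N from compile-time hyperparameters to runtime symbols.
# Assumes the tkl.sym symbol declarations used throughout the wave tests.
import iree.turbine.kernel.lang as tkl

M, N, K = tkl.sym.M, tkl.sym.N, tkl.sym.K

# Fully static case: every size is baked in as a compile-time hyperparameter.
hyperparams = {M: 64, N: 128, K: 128}

# Dynamic M and N: drop them from the hyperparameters and pass them through
# dynamic_symbols (compile with M and N unknown) plus dynamic_symbols_map
# (the concrete values used when the test launches the kernel).
dynamic_symbols = [M, N]
dynamic_symbols_map = {M: hyperparams.pop(M), N: hyperparams.pop(N)}

# Both are then forwarded to the launch context exactly as in the tests below:
#   tk.gen.TestLaunchContext(hyperparams, ...,
#                            dynamic_symbols=dynamic_symbols,
#                            dynamic_symbols_map=dynamic_symbols_map)
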
10 changes: 5 additions & 5 deletions iree/turbine/kernel/wave/wave.py
@@ -256,11 +256,6 @@ def _trace_and_get_kernel_signature(
# Partition strided operators.
partition_strided_operators(graph, self.constraints)

# Align sizes to WG/Tile sizes
# This pass changes indexing keys, which can interfere with other passes,
# so it should be called close to the end of the pipeline.
align_index_sizes(graph, self.constraints)

# Decompose reduce Ops.
decompose_reduce_ops(graph, self.constraints, idxc.subs)

@@ -278,6 +273,11 @@ def _trace_and_get_kernel_signature(
use_scheduling_barriers = kwargs.get("use_scheduling_barriers", False)
schedule_graph(graph, self.constraints, use_scheduling_barriers)

# Align sizes to WG/Tile sizes
# This pass changes indexing keys, which can interfere with other passes,
# so it should be called close to the end of the pipeline.
align_index_sizes(graph, self.constraints)

# Add shared memory barriers.
add_shared_memory_barriers(graph)

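A condensed view (added for readability, not part of the diff) of the resulting pass order in _trace_and_get_kernel_signature after this change; only the passes visible in the two hunks above are listed, everything between them is elided.

# Pipeline tail after this commit (passes not shown in the hunks are elided):
#   partition_strided_operators(graph, self.constraints)
#   decompose_reduce_ops(graph, self.constraints, idxc.subs)
#   ...
#   schedule_graph(graph, self.constraints, use_scheduling_barriers)
#   align_index_sizes(graph, self.constraints)   # moved here: runs after scheduling
#   add_shared_memory_barriers(graph)
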
96 changes: 96 additions & 0 deletions lit_tests/kernel/wave/codegen.py
@@ -1184,6 +1184,102 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
# CHECK-COUNT-8: amdgpu.mfma


@run_test
def test_dynamic_gemm_pipelined():
constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
constraints += [tkw.TilingConstraint(K, BLOCK_K)]
constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)]
constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)]

constraints += [
tkw.HardwareConstraint(
threads_per_wave=64,
waves_per_block=(2, 2, 1),
mma_type=tkw.MMAType.F32_16x16x16_F16,
)
]

@tkw.wave(constraints)
def dynamic_gemm_pipelined(
a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32],
):
c_reg = tkl.Register[M, N, tkl.f32](0.0)

@tkw.reduction(K, init_args=[c_reg])
def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD)
b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD)
acc = tkw.mma(a_reg, b_reg, acc)
return acc

tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD)

with tk.gen.TestLaunchContext(
{
K: 128,
BLOCK_M: 64,
BLOCK_N: 64,
BLOCK_K: 32,
LOAD_ELEMS_PER_THREAD: 4,
STORE_ELEMS_PER_THREAD: 4,
ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE,
READ_SHARED_DELAY: 1,
WRITE_SHARED_DELAY: 1,
READ_GLOBAL_DELAY: 2,
WRITE_GLOBAL_DELAY: 2,
MMA_DELAY: 1,
VALU_DELAY: 1,
SHUFFLE_DELAY: 1,
SHARED_MEMORY_UNITS: 4,
GLOBAL_MEMORY_UNITS: 4,
MMA_UNITS: 4,
VALU_UNITS: 8,
SHUFFLE_UNITS: 8,
},
canonicalize=True,
schedule=True,
use_scheduling_barriers=True,
dynamic_symbols=(M, N),
dynamic_symbols_map={M: 64, N: 128},
):
a = torch.randn(64, 32, dtype=torch.float16)
b = torch.randn(128, 32, dtype=torch.float16)
c = torch.zeros(64, 128, dtype=torch.float32)
print(dynamic_gemm_pipelined(a, b, c).module_op)

# CHECK: func.func @dynamic_gemm_pipelined
# CHECK-COUNT-2: vector.maskedload
# CHECK-COUNT-2: vector.store
# CHECK-COUNT-1: amdgpu.lds_barrier
# CHECK-COUNT-4: vector.load
# CHECK-COUNT-2: vector.maskedload
# CHECK-COUNT-4: vector.load
# CHECK-COUNT-4: amdgpu.mfma
# CHECK-COUNT-1: amdgpu.lds_barrier
# CHECK-COUNT-2: vector.store
# CHECK-COUNT-1: scf.for
# CHECK-COUNT-4: amdgpu.mfma
# CHECK-COUNT-1: amdgpu.lds_barrier
# CHECK-COUNT-4: vector.load
# CHECK-COUNT-2: vector.maskedload
# CHECK-COUNT-3: llvm.call_intrinsic "llvm.amdgcn.sched.group.barrier"
# CHECK-COUNT-4: vector.load
# CHECK-COUNT-1: llvm.call_intrinsic "llvm.amdgcn.sched.group.barrier"
# CHECK-COUNT-4: amdgpu.mfma
# CHECK-COUNT-1: amdgpu.lds_barrier
# CHECK-COUNT-2: vector.store
# CHECK-COUNT-2: llvm.call_intrinsic "llvm.amdgcn.sched.group.barrier"
# CHECK-COUNT-1: scf.yield
# CHECK-COUNT-4: amdgpu.mfma
# CHECK-COUNT-1: amdgpu.lds_barrier
# CHECK-COUNT-8: vector.load
# CHECK-COUNT-8: amdgpu.mfma


# This test is used to check three things
# 1. Reduction with multiple different types(MMA, ReduceOp) of iterArg works
# 2. ReduceOp lowering works using constraints from MMA (not just vector_shape).
28 changes: 27 additions & 1 deletion tests/kernel/wave/wave_gemm_test.py
@@ -22,6 +22,7 @@
import os
import json
from torch.testing import assert_close
from enum import Enum

_run_e2e = int(os.environ.get("WAVE_RUN_E2E_TESTS", 0))
require_e2e = pytest.mark.skipif(not _run_e2e, reason="e2e tests are disabled")
@@ -57,9 +58,16 @@ def get_test_shapes(test_name: str) -> list[tuple[int]]:
return default_test_shapes[test_name]


class Dims(Enum):
M = 0
N = 1
MN = 2


@require_e2e
@pytest.mark.parametrize("shape", get_test_shapes("test_gemm"))
@pytest.mark.parametrize("enable_scheduling", [False, True])
@pytest.mark.parametrize("dynamic_dims", [None, Dims.M, Dims.N, Dims.MN])
@pytest.mark.parametrize(
"mfma_variant",
[
@@ -68,7 +76,11 @@ def get_test_shapes(test_name: str) -> list[tuple[int]]:
],
)
def testGemm(
shape: tuple[int], enable_scheduling: bool, mfma_variant: MMAType, request
shape: tuple[int],
enable_scheduling: bool,
dynamic_dims: Dims,
mfma_variant: MMAType,
request,
):
run_bench = request.config.getoption("--runperf")
dump_perf = request.config.getoption("--dump-perf-files-path")
@@ -161,6 +173,18 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
dump_perf, "tk_" + perf_filename
)

dynamic_symbols = []
dynamic_symbols_map = {}
# Handle M and N independently so that Dims.MN makes both dims dynamic
# (a match statement would run only the first matching case).
if dynamic_dims in (Dims.M, Dims.MN):
    dynamic_symbols_map[M] = hyperparams[M]
    dynamic_symbols.append(M)
    del hyperparams[M]
if dynamic_dims in (Dims.N, Dims.MN):
    dynamic_symbols_map[N] = hyperparams[N]
    dynamic_symbols.append(N)
    del hyperparams[N]

with tk.gen.TestLaunchContext(
hyperparams,
canonicalize=True,
@@ -169,6 +193,8 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
run_config=config,
schedule=enable_scheduling,
use_scheduling_barriers=enable_scheduling_barriers,
dynamic_symbols=dynamic_symbols,
dynamic_symbols_map=dynamic_symbols_map,
):
a = torch.randn(shape[0], shape[2], dtype=torch.float16)
b = torch.randn(shape[1], shape[2], dtype=torch.float16)
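For illustration (not part of the diff), the dynamic_dims handling added above produces the following configurations, assuming hyperparams initially maps M and N to the static sizes m and n taken from the test shape.

# dynamic_dims   dynamic_symbols   dynamic_symbols_map   hyperparams afterwards
# None           []                {}                    M and N still present (fully static)
# Dims.M         [M]               {M: m}                M removed
# Dims.N         [N]               {N: n}                N removed
# Dims.MN        [M, N]            {M: m, N: n}          M and N removed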
