
Update IREE Requirements
Signed-off-by: erman-gurses <[email protected]>
erman-gurses committed Sep 18, 2024
1 parent 7300a8d commit 69ea361
Showing 4 changed files with 267 additions and 5 deletions.
4 changes: 2 additions & 2 deletions iree-requirements-ci.txt
@@ -3,5 +3,5 @@
# more forgiving on the exact version.

--find-links https://iree.dev/pip-release-links.html
iree-compiler==20240808.979
iree-runtime==20240808.979
iree-compiler==20240913.1015
iree-runtime==20240913.1015
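For reference, a minimal environment check (not part of this commit) that the pinned nightly builds are actually installed; it assumes only the distribution names above and the standard library:

from importlib.metadata import PackageNotFoundError, version

EXPECTED = "20240913.1015"
for pkg in ("iree-compiler", "iree-runtime"):
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        installed = None
    # Both packages come from https://iree.dev/pip-release-links.html via --find-links.
    assert installed == EXPECTED, f"{pkg}: expected {EXPECTED}, got {installed}"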
2 changes: 1 addition & 1 deletion shark_turbine/kernel/wave/utils.py
@@ -250,7 +250,7 @@ def compile_and_invoke(
    # TODO: More targets/backends support.
    if backend == "rocm":
        target = config["target"]
        flags.append(f"--iree-rocm-target-chip={target}")
        flags.append(f"--iree-hip-target={target}")

    if config.get("print_ir_after_all", False):
        flags.append("--mlir-print-ir-after-all")
264 changes: 263 additions & 1 deletion tests/kernel/wave/wave_e2e_test.py
@@ -1,6 +1,8 @@
import shark_turbine.kernel as tk
import shark_turbine.kernel.lang as tkl
import shark_turbine.kernel.wave as tkw
from shark_turbine.kernel.wave.wave_sim import wave_sim
from shark_turbine.kernel.lang.global_symbols import *
import torch
from numpy.testing import assert_allclose, assert_equal
import pytest
@@ -9,7 +11,7 @@
import torch
import json

_run_e2e = int(os.environ.get("WAVE_RUN_E2E_TESTS", 0))
_run_e2e = int(os.environ.get("WAVE_RUN_E2E_TESTS", 1))
require_e2e = pytest.mark.skipif(not _run_e2e, reason="e2e tests are disabled")
default_test_shapes = [(1, 128), (256, 64), (256, 128), (256, 256), (256, 1024)]
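The WAVE_RUN_E2E_TESTS variable gates these tests; with the default now 1 they run unless the environment opts out explicitly, for example:

WAVE_RUN_E2E_TESTS=0 pytest tests/kernel/wave/wave_e2e_test.py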

@@ -408,3 +410,263 @@ def test(
    ):
        test(a, b)
    assert_allclose(b, expected)


@require_e2e
def test_im2col_mma():
    # igemm without final col2im
    n, c, h, w = 1, 4, 9, 9  # Image.
    nf, cf, hf, wf = 64, c, 2, 2  # Filters.
    padding = 0  # TODO: only pad=0 is supported for now
    stride = 1

    x = torch.randn(n, c, h, w, dtype=torch.float16)
    we = torch.randn(nf, cf, hf, wf, dtype=torch.float16)

    convRef = torch.nn.Conv2d(c, nf, hf, stride=stride, padding=padding, bias=False)
    convRef.weight = torch.nn.Parameter(we)
    out_ref = convRef(x).detach()

    sym = tkl.sym
    N, C, H, W = sym.N, sym.C, sym.H, sym.W
    NF, HF, WF = sym.NF, sym.HF, sym.WF

    H_OUT = (H + 2 * padding - HF) // stride + 1
    W_OUT = (W + 2 * padding - WF) // stride + 1
    SZ_OUT = H_OUT * W_OUT

    K = HF * WF * C
    M = SZ_OUT * N

    i = tkw.IndexMapping.iterator(0)
    j = tkw.IndexMapping.iterator(1)

    x_mapping = tkw.IndexMapping(
        num_iterators=2,
        inputs={
            N: i // SZ_OUT,
            C: j // (HF * WF),
            H: (i % SZ_OUT) % W_OUT * stride + (j % (HF * WF)) % WF,
            W: (i % SZ_OUT) // W_OUT * stride + (j % (HF * WF)) // WF,
        },
        outputs={M: i, K: j},
    )
    w_mapping = tkw.IndexMapping(
        num_iterators=2,
        inputs={NF: i % NF, C: j // (HF * WF), HF: j % WF, WF: (j % (HF * WF)) // WF},
        outputs={NF: i, K: j},
    )

    # Workgroup tile sizes
    BLOCK_M = tkl.sym.BLOCK_M
    BLOCK_N = tkl.sym.BLOCK_N
    # BLOCK_K = tkl.sym.BLOCK_K
    BLOCK_K = K
    # Address space (for GPU, shared(1) or global(0))
    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
    # Other hyperparameters
    ELEMS_PER_THREAD = tkl.sym.ELEMS_PER_THREAD

    # Expose user-constraints
    constraints: list[tkw.Constraint] = []
    constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
    constraints += [tkw.WorkgroupConstraint(NF, BLOCK_N, 1)]
    constraints += [tkw.WaveConstraint(M, BLOCK_M)]
    constraints += [tkw.WaveConstraint(NF, BLOCK_N)]
    constraints += [tkw.TilingConstraint(K, BLOCK_K)]

    constraints += [
        tkw.HardwareConstraint(
            threads_per_wave=64,
            waves_per_block=(1, 1, 1),
            # vector_shapes={NF: 1, M: BLOCK_M, K: ELEMS_PER_THREAD},
        )
    ]

    def func(
        x: tkl.Memory[N, C, H, W, ADDRESS_SPACE, tkl.f16],
        we: tkl.Memory[NF, C, HF, WF, ADDRESS_SPACE, tkl.f16],
        out: tkl.Memory[M, NF, GLOBAL_ADDRESS_SPACE, tkl.f32],
    ):
        c_reg = tkl.Register[M, NF, tkl.f32](0.0)

        @tkw.reduction(K, init_args=[c_reg])
        def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]:
            a_reg = tkw.read(
                x,
                mapping=x_mapping,
                elements_per_thread=ELEMS_PER_THREAD,
            )
            b_reg = tkw.read(
                we,
                mapping=w_mapping,
                elements_per_thread=ELEMS_PER_THREAD,
            )
            acc = tkw.mma(a_reg, b_reg, acc)
            return acc

        tkw.write(repeat, out, elements_per_thread=ELEMS_PER_THREAD)

    sim_func = wave_sim(constraints)(func)
    gpu_func = tkw.wave(constraints)(func)

    h_out = (h + 2 * padding - hf) // stride + 1
    w_out = (w + 2 * padding - wf) // stride + 1
    res_shape = (h_out * w_out * n, nf)
    out_ref = torch.zeros(res_shape, dtype=torch.float32)
    sim_func(x, we, out_ref)

    out = torch.zeros_like(out_ref)

    config = {"backend": "rocm", "device": "hip", "target": "gfx942"}

    with tk.gen.TestLaunchContext(
        {
            N: n,
            C: c,
            W: w,
            H: h,
            NF: nf,
            WF: wf,
            HF: hf,
            BLOCK_M: 64,
            BLOCK_N: 64,
            ELEMS_PER_THREAD: 4,
            ADDRESS_SPACE: GLOBAL_ADDRESS_SPACE,
        },
        canonicalize=True,
        run=True,
        run_config=config,
    ):
        gpu_func(x, we, out)
    assert_allclose(out, out_ref, rtol=1e-05, atol=1e-05)
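As a plain-PyTorch cross-check (not part of this diff) of the im2col-as-GEMM formulation this test exercises: the wave IndexMappings above use their own flattened layout, so this only illustrates the underlying identity for stride=1 and padding=0.

import torch

n, c, h, w = 1, 4, 9, 9
nf, hf, wf = 64, 2, 2
x = torch.randn(n, c, h, w)
we = torch.randn(nf, c, hf, wf)

cols = torch.nn.functional.unfold(x, (hf, wf))  # (n, c*hf*wf, h_out*w_out)
gemm = we.reshape(nf, -1) @ cols                # (n, nf, h_out*w_out)
h_out, w_out = h - hf + 1, w - wf + 1
out = gemm.reshape(n, nf, h_out, w_out)         # col2im is just a reshape here

torch.testing.assert_close(out, torch.nn.functional.conv2d(x, we))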


@require_e2e
def test_igemm_conv():
    n, c, h, w = 1, 4, 5, 5  # Image.
    nf, cf, hf, wf = 16, c, 2, 2  # Filters.
    padding = 0  # TODO: only pad=0 is supported for now
    stride = 1

    torch.manual_seed(1)
    x = torch.randn(n, c, h, w, dtype=torch.float16)
    we = torch.randn(nf, cf, hf, wf, dtype=torch.float16)

    convRef = torch.nn.Conv2d(c, nf, hf, stride=stride, padding=padding, bias=False)
    convRef.weight = torch.nn.Parameter(we)
    out_ref = convRef(x).detach()

    sym = tkl.sym
    N, C, H, W = sym.N, sym.C, sym.H, sym.W
    NF, HF, WF = sym.NF, sym.HF, sym.WF

    H_OUT = (H + 2 * padding - HF) // stride + 1
    W_OUT = (W + 2 * padding - WF) // stride + 1
    SZ_OUT = H_OUT * W_OUT

    K = HF * WF * C
    M = SZ_OUT * N

    i = tkw.IndexMapping.iterator(0)
    j = tkw.IndexMapping.iterator(1)

    x_mapping = tkw.IndexMapping(
        num_iterators=2,
        inputs={
            N: i // SZ_OUT,
            C: j // (HF * WF),
            H: (i % SZ_OUT) % W_OUT * stride + (j % (HF * WF)) % WF,
            W: (i % SZ_OUT) // W_OUT * stride + (j % (HF * WF)) // WF,
        },
        outputs={M: i, K: j},
    )
    w_mapping = tkw.IndexMapping(
        num_iterators=2,
        inputs={NF: i % NF, C: j // (HF * WF), HF: j % WF, WF: (j % (HF * WF)) // WF},
        outputs={NF: i, K: j},
    )
    out_mapping = tkw.IndexMapping(
        num_iterators=2,
        inputs={M: i, NF: j},
        outputs={
            N: i // SZ_OUT,
            NF: j,
            H_OUT: (i % SZ_OUT) % W_OUT,
            W_OUT: (i % SZ_OUT) // W_OUT,
        },
    )

    # Workgroup tile sizes
    BLOCK_M = tkl.sym.BLOCK_M
    BLOCK_N = tkl.sym.BLOCK_N
    BLOCK_K = K
    # Address space (for GPU, shared(1) or global(0))
    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
    # Other hyperparameters
    ELEMS_PER_THREAD = tkl.sym.ELEMS_PER_THREAD

    # Expose user-constraints
    constraints: list[tkw.Constraint] = []
    constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
    constraints += [tkw.WorkgroupConstraint(NF, BLOCK_N, 1)]
    constraints += [tkw.WaveConstraint(M, BLOCK_M)]
    constraints += [tkw.WaveConstraint(NF, BLOCK_N)]
    constraints += [tkw.TilingConstraint(K, BLOCK_K)]

    constraints += [
        tkw.HardwareConstraint(
            threads_per_wave=64,
            waves_per_block=(1, 1, 1),
        )
    ]

    @tkw.wave(constraints)
    def conv(
        x: tkl.Memory[N, C, H, W, ADDRESS_SPACE, tkl.f16],
        we: tkl.Memory[NF, C, HF, WF, ADDRESS_SPACE, tkl.f16],
        out: tkl.Memory[N, NF, H_OUT, W_OUT, GLOBAL_ADDRESS_SPACE, tkl.f32],
    ):
        c_reg = tkl.Register[M, NF, tkl.f32](0.0)

        @tkw.reduction(K, init_args=[c_reg])
        def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]:
            a_reg = tkw.read(
                x,
                mapping=x_mapping,
                elements_per_thread=ELEMS_PER_THREAD,
            )
            b_reg = tkw.read(
                we,
                mapping=w_mapping,
                elements_per_thread=ELEMS_PER_THREAD,
            )
            acc = tkw.mma(a_reg, b_reg, acc)
            return acc

        tkw.write(
            repeat, out, mapping=out_mapping, elements_per_thread=ELEMS_PER_THREAD
        )

    out = torch.zeros_like(out_ref)

    config = {"backend": "rocm", "device": "hip", "target": "gfx942"}

    with tk.gen.TestLaunchContext(
        {
            N: n,
            C: c,
            W: w,
            H: h,
            NF: nf,
            WF: wf,
            HF: hf,
            BLOCK_M: 16,
            BLOCK_N: 16,
            ELEMS_PER_THREAD: 4,
            ADDRESS_SPACE: GLOBAL_ADDRESS_SPACE,
        },
        canonicalize=True,
        run=True,
        run_config=config,
    ):
        conv(x, we, out)
    assert_allclose(out, out_ref, rtol=1e-05, atol=1e-05)
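And a tiny standalone check (not part of this diff) of the flat-index arithmetic shared by x_mapping and out_mapping: each GEMM row index covers SZ_OUT = H_OUT * W_OUT output positions per image, and the two coordinates recovered with % W_OUT and // W_OUT round-trip back to the flat position.

H_OUT, W_OUT = 4, 4  # the 5x5 image with a 2x2 filter above
SZ_OUT = H_OUT * W_OUT
for i in range(SZ_OUT):
    a = (i % SZ_OUT) % W_OUT   # first spatial coordinate, as in the mappings
    b = (i % SZ_OUT) // W_OUT  # second spatial coordinate
    assert b * W_OUT + a == i % SZ_OUT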
2 changes: 1 addition & 1 deletion tests/kernel/wave/wave_gemm_test.py
@@ -10,7 +10,7 @@
import os
import json

_run_e2e = int(os.environ.get("WAVE_RUN_E2E_TESTS", 0))
_run_e2e = int(os.environ.get("WAVE_RUN_E2E_TESTS", 1))
require_e2e = pytest.mark.skipif(not _run_e2e, reason="e2e tests are disabled")
default_test_shapes = [(1024, 5120, 640), (2048, 10240, 1280), (4096, 20480, 2560)]

