[TKW] Add support for multiple/local reduceOp #234

Merged
merged 3 commits into from
Oct 24, 2024
28 changes: 25 additions & 3 deletions iree/turbine/kernel/ops/wave_ops.py
@@ -1236,19 +1236,41 @@ class ReduceOp(CustomOp, ABC):
dim: which symbolic dim to reduce.
"""

arg: fx.Node | list[fx.Node]
init: fx.Node = None
dim: Optional[Any] = None

@property
def indexing_dims(self) -> list[IndexSymbol]:
# Local import to break circular dep.
from ..wave.utils import all_equal

if isinstance(self.arg, Sequence):
src_indexings = [get_custom(arg).indexing_dims for arg in self.arg]
if not all_equal(src_indexings):
raise NotImplementedError(
"NYI: Only support case where all inputs to ReduceOp to have same indexing dim."
)
src_indexing = src_indexings[0]
else:
src_indexing = get_custom(self.arg).indexing_dims
dst_indexing = [dim for dim in src_indexing if dim != self.dim]
return dst_indexing

@property
def type(self) -> Memory:
if isinstance(self.arg, Sequence):
# Local import to break circular dep.
from ..wave.utils import all_equal

src_types = [get_custom(arg).type for arg in self.arg]
if not all_equal(src_types):
raise NotImplementedError(
"NYI: Only support case where all inputs to ReduceOp to have same type."
)
src_type = src_types[0]
else:
src_type = get_custom(self.arg).type
reduced_dims = [dims for dims in src_type.symbolic_shape if dims != self.dim]
dst_type = Register[*reduced_dims, src_type.dtype]
return dst_type
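With a list-valued arg, both properties require every source to agree before dropping the reduced dimension. A minimal standalone sketch of that shape rule (reduced_dims is a hypothetical helper for illustration, independent of the fx graph machinery):

def reduced_dims(src_dims_per_arg: list[list[str]], reduce_dim: str) -> list[str]:
    # Mirror the NYI check above: all sources must share the same indexing dims.
    first = src_dims_per_arg[0]
    assert all(dims == first for dims in src_dims_per_arg), "sources must agree"
    # The result keeps every dim except the one being reduced.
    return [d for d in first if d != reduce_dim]

print(reduced_dims([["M", "N"], ["M", "N"]], "N"))  # ['M']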
43 changes: 36 additions & 7 deletions iree/turbine/kernel/wave/decompose_reduce_ops.py
@@ -23,7 +23,7 @@
Reduction,
)

from .utils import DCE, subs_idxc, all_equal
import torch.fx as fx
import math
from typing import Callable
@@ -37,6 +37,16 @@ def get_graph_node(custom: CustomOp, graph: fx.Graph):
return custom


def emit_sources_reduction(
binary_fn: Callable, src: list[fx.Node], graph: fx.Graph
) -> fx.Node:
init = src[0]
for i in range(1, len(src)):
init = get_graph_node(binary_fn(init, src[i]), graph)
init.index = src[0].index
return init


def emit_local_reduction(
binary_fn: Callable, src: fx.Node, graph: fx.Graph, local_reduction_size: int
) -> fx.Node:
@@ -67,11 +77,12 @@ def decompose_reduce_ops(
):
"""
The lowering for multi_reduction is done in four steps:
1. Source Reduce: Each thread locally reduces all of its sources.
2. Local Reduce: Each thread reduces all elements carried by it along
the reduction dimensions.
3. Thread Reduce: Each thread reduces the result of step 2 across threads
by doing a butterfly shuffle.
4. Accumulator Reduce: Each thread reduces its intermediate reduced
results with the accumulator it holds.
"""
# Get reduce nodes.
@@ -98,19 +109,37 @@
raise ValueError(
"No reduction dim specified, please specify a reduction dim."
)
if not isinstance(reduction_src, (list, tuple)):
reduction_src = [reduction_src]

# Local Reduce
src_fastest_dims = [
get_custom(arg).type.symbolic_shape[-1] for arg in reduction_src
]
if not all_equal(src_fastest_dims):
raise NotImplementedError(
"NYI: Expect all reduce_src to have same fastest dim."
)
if reduction_dim is not src_fastest_dims[0]:
raise NotImplementedError(
"Only implemented reduction on fastest dimension."
)

get_thread_shape = lambda index: max(
subs_idxc(x.size) for x in index.values()
)
local_reduce_sizes = [
get_thread_shape(get_custom(arg).index) for arg in reduction_src
]
if not all_equal(local_reduce_sizes):
raise NotImplementedError(
"NYI: Expect all reduce_src to have same local reduce size."
)
src_reduction = emit_sources_reduction(
binary_fn, reduction_src, custom.graph
)
local_reduction = emit_local_reduction(
binary_fn, src_reduction, custom.graph, local_reduce_sizes[0]
)

# Global Reduce
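To make the new lowering order concrete, here is a rough standalone Python sketch of steps 1, 2, and 4, assuming an elementwise binary_fn and equally sized per-thread source vectors; step 3 (the butterfly shuffle across the wave) is only noted as a comment, and decompose_reduce is a hypothetical name rather than part of this pass:

import operator

def decompose_reduce(binary_fn, sources, accumulator):
    # 1. Source Reduce: fold all per-thread source vectors into one vector.
    vec = sources[0]
    for src in sources[1:]:
        vec = [binary_fn(a, b) for a, b in zip(vec, src)]
    # 2. Local Reduce: fold the per-thread vector into a single value.
    acc = vec[0]
    for elem in vec[1:]:
        acc = binary_fn(acc, elem)
    # 3. Thread Reduce would happen here via a butterfly shuffle (not modeled).
    # 4. Accumulator Reduce: combine with the carried accumulator.
    return binary_fn(acc, accumulator)

print(decompose_reduce(operator.add, [[1.0, 2.0], [3.0, 4.0]], 10.0))  # 20.0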
6 changes: 6 additions & 0 deletions iree/turbine/kernel/wave/utils.py
@@ -739,3 +739,9 @@ def get_mfma_store_elems_per_thread(mfma_variant: MMAType) -> int:
return 4
case MMAType.F32_32x32x16_F8:
return 16


def all_equal(input_list: list[Any]) -> bool:
if len(input_list) == 0:
return True
return all(elem == input_list[0] for elem in input_list)
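
A quick illustration of the helper's behavior (ad-hoc checks for this writeup, not part of the PR's tests):

assert all_equal([])              # an empty list counts as all-equal
assert all_equal(["N", "N"])      # identical entries
assert not all_equal(["M", "N"])  # mismatch detected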
63 changes: 63 additions & 0 deletions lit_tests/kernel/wave/codegen.py
@@ -1211,6 +1211,69 @@ def test(
# CHECK: arith.addf {{.*}} : vector<1xf16>


# Tests reduction over multiple sources, checking that we emit the source reduction and then iteratively slice and reduce over multiple variables correctly.
@run_test
def test_mutliple_local_reduce_sum():
M = tkl.sym.M
N = tkl.sym.N
BLOCK_M = tkl.sym.BLOCK_M
BLOCK_N = tkl.sym.BLOCK_N
ELEMS_PER_THREAD = tkl.sym.ELEMS_PER_THREAD
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE

constraints: list[tkw.Constraint] = [
tkw.HardwareConstraint(
threads_per_wave=64,
waves_per_block=(1, 1, 1),
vector_shapes={M: 1, N: BLOCK_N},
)
]
constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 0)]
constraints += [tkw.WaveConstraint(M, BLOCK_M)]
constraints += [tkw.WaveConstraint(N, BLOCK_N)]

@tkw.wave(constraints)
def test(
a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
c: tkl.Memory[M, ADDRESS_SPACE, tkl.f16],
):
lhs = tkw.read(a, elements_per_thread=ELEMS_PER_THREAD)
rhs = tkw.read(b, elements_per_thread=ELEMS_PER_THREAD)
res = tkw.sum([lhs, rhs], dim=N)
tkw.write(res, c, elements_per_thread=1)

config = {"backend": "rocm", "device": "hip", "target": "gfx942"}

shape = (256, 128)
a = torch.randn(shape, dtype=torch.float16)
b = torch.randn(shape, dtype=torch.float16)
c = torch.zeros((shape[0],), dtype=torch.float16)
with tk.gen.TestLaunchContext(
{
M: shape[0],
N: shape[1],
BLOCK_M: 1,
BLOCK_N: 128,
ELEMS_PER_THREAD: 2,
ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value,
},
canonicalize=True,
):
print(test(a, b, c).module_op)
# CHECK: %[[LHS:.+]] = vector.load {{.*}} : memref<256x128xf16
# CHECK: %[[RHS:.+]] = vector.load {{.*}} : memref<256x128xf16
# Reduce all sources locally.
# CHECK: %[[SRC_REDUC:.+]] = arith.addf %[[LHS]], %[[RHS]] : vector<2xf16>
# Do Local Reductions.
# CHECK: %[[LOCAL_REDUC0:.+]] = vector.extract_strided_slice %[[SRC_REDUC]] {offsets = [0], sizes = [1], strides = [1]}
# CHECK: %[[LOCAL_REDUC1:.+]] = vector.extract_strided_slice %[[SRC_REDUC]] {offsets = [1], sizes = [1], strides = [1]}
# CHECK: %[[REDUC_0:.+]] = arith.addf %[[LOCAL_REDUC0]], %[[LOCAL_REDUC1]] : vector<1xf16>
# Expanded Global Sum Reduction
# CHECK-COUNT-6: gpu.shuffle xor


# This test is to ensure that the propagation of indexing_dims between reduction and operations
# outside the reduction is working properly.
@run_test