Skip to content

Commit

Permalink
Address comments iree-org#2
Browse files Browse the repository at this point in the history
Signed-off-by: Harsh Menon <[email protected]>
  • Loading branch information
harsh-nod committed Oct 30, 2024
1 parent 8863678 commit 9d6bb0f
Showing 1 changed file with 35 additions and 22 deletions.
57 changes: 35 additions & 22 deletions iree/turbine/kernel/wave/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,28 +261,41 @@ def is_mma(node):

# Determine if any reshapes are required. Reshapes are added for
# chained matmuls when the vector shapes of the operands in one matmul
# differ from those in another matmul.
for src in mma_nodes:
custom_src = get_custom(src)
for dst in mma_nodes:
if src == dst:
continue
custom_dst = get_custom(dst)
lhs_slice = capture_backward_slice(custom_dst.lhs)
rhs_slice = capture_backward_slice(custom_dst.rhs)
if src in lhs_slice or src in rhs_slice:
with custom_dst.graph.inserting_before(dst):
for i, arg in custom_dst.node_args.items():
if is_reshape_needed(
arg, custom_dst.vector_shapes, custom_src.vector_shapes
):
reshape = Reshape(
arg.fx_node, custom_src.vector_shapes
).add_to_graph(custom.graph)
custom_reshape = get_custom(reshape)
custom_reshape.vector_shapes = custom.vector_shapes
custom_reshape.anchor = custom
custom.update_arg(i, reshape)
# differ from those in another matmul. The mma_slices contain all the ops
# in the backward slice of the lhs and rhs upto a previous mma (if one exists).
# So we check for the previous node of the first operator in the slice to see
# if it is an MMA and if so check if a reshape is required.
def add_reshape_if_needed(mma: MMA, prev_mma: MMA):
with mma.graph.inserting_before(mma.fx_node):
for i, arg in mma.node_args.items():
if is_reshape_needed(arg, mma.vector_shapes, prev_mma.vector_shapes):
reshape = Reshape(arg.fx_node, prev_mma.vector_shapes).add_to_graph(
custom.graph
)
custom_reshape = get_custom(reshape)
custom_reshape.vector_shapes = custom.vector_shapes
custom_reshape.anchor = custom
custom.update_arg(i, reshape)

def find_mma_in_slice(node: CustomOp) -> Optional[MMA]:
"""
Find the closest mma by iterating through the backward slice of a node
in reverse.
"""
slice = list(capture_backward_slice(node))
for arg in reversed(slice):
prev_mma = get_custom(arg)
if isinstance(prev_mma, MMA):
return prev_mma
return None

for mma in mma_nodes:
custom_mma = get_custom(mma)
prev_mma = find_mma_in_slice(custom_mma.lhs) or find_mma_in_slice(
custom_mma.rhs
)
if prev_mma:
add_reshape_if_needed(custom_mma, prev_mma)

return mapping, mma_slices

Expand Down

0 comments on commit 9d6bb0f

Please sign in to comment.