Skip to content

Commit

Permalink
Handle partitioned read
Browse files Browse the repository at this point in the history
Signed-off-by: Stanley Winata <[email protected]>
  • Loading branch information
raikonenfnu committed Nov 14, 2024
1 parent e68ad7d commit ac73aa4
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 24 deletions.
2 changes: 1 addition & 1 deletion iree/turbine/kernel/wave/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1321,7 +1321,6 @@ def handle_reshape(emitter: WaveEmitter, node: fx.Node):
raise ValidationError("Malformed arguments") from e
custom = get_custom(node)
innermost_dim = custom.type.symbolic_shape[-1]
offset = custom.expanded_dims[innermost_dim]

# Determine whether to extract or combine.
if len(args) > 1:
Expand Down Expand Up @@ -1349,6 +1348,7 @@ def handle_reshape(emitter: WaveEmitter, node: fx.Node):
# actual offset, we need to multiply by the size. The size is obtained by
# computing the number of partitions using the source and target vector shapes
# and dividing the incoming vector shape by the number of partitions.
offset = custom.expanded_dims[innermost_dim]
num_partitions = (
target_vector_shapes[innermost_dim] // custom.vector_shapes[innermost_dim]
)
Expand Down
80 changes: 58 additions & 22 deletions iree/turbine/kernel/wave/index_sequence_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def has_strided_access(node: fx.Node) -> bool:
dim: simplify_index(custom.index.get(dim, custom.index[dim]))
for dim in custom.index
}

shape = get_vector_shape(
custom.vector_shapes, custom.register_type.symbolic_shape
)
Expand All @@ -93,6 +92,7 @@ def has_strided_access(node: fx.Node) -> bool:
[(dim, seq.stride) for dim, seq in simplified_index.items()],
key=lambda item: item[1],
)
ops_to_combine = []
with custom.graph.inserting_before(operator):
for i in range(elements_per_thread):
# Non-contiguous access patterns can have varying offsets. We
Expand All @@ -113,7 +113,10 @@ def has_strided_access(node: fx.Node) -> bool:
)
for j, dim in enumerate(custom.register_type.symbolic_shape)
}
ops_to_combine.append(write)

# Useful for handling write/read dependencies
custom.replace_all_uses_with(ops_to_combine)
custom.graph.erase_node(operator)


Expand Down Expand Up @@ -148,7 +151,7 @@ def has_gpr_offsets(node: fx.Node) -> bool:
for operator in strided_operators:
custom = get_custom(operator)
simplified_index = {
dim: simplify_index(custom.register_index.get(dim, custom.index[dim]))
dim: simplify_index(custom.index.get(dim, custom.index[dim]))
for dim in custom.index
}
elements_per_thread = subs_idxc(custom.elements_per_thread)
Expand All @@ -160,6 +163,7 @@ def has_gpr_offsets(node: fx.Node) -> bool:
gpr_cur_base_offset = gpr_offset_expr.subs({GPR_NUM: 0})
cur_elem_id = 0
with custom.graph.inserting_before(operator):
ops_to_combine = []
for i in range(elements_per_thread):
# Break apart Reads/Writes that have non-contiguous GPR Reads/Writes.
next_gpr_offset = gpr_offset_expr.subs({GPR_NUM: i + 1})
Expand All @@ -182,30 +186,62 @@ def has_gpr_offsets(node: fx.Node) -> bool:
gpr_size = int(gpr_size)

# Generate new Read/Write that has contiguous VGPR elements.
extract = ExtractSlice(
custom.register_, [cur_elem_id], [gpr_size], [1]
).add_to_graph(custom.graph)
write = Write(
extract,
custom.memory,
mapping=custom.mapping,
elements_per_thread=gpr_size,
).add_to_graph(custom.graph)
write.index = {
dim: IndexSequence(
simplified_index[dim].start.subs({GPR_NUM: cur_elem_id}),
gpr_size,
simplified_index[dim].stride,
)
for dim in simplified_index
}
write.vector_shapes = custom.vector_shapes
if isinstance(custom, Write):
extract = ExtractSlice(
custom.register_, [cur_elem_id], [gpr_size], [1]
).add_to_graph(custom.graph)
write = Write(
extract,
custom.memory,
mapping=custom.mapping,
elements_per_thread=gpr_size,
).add_to_graph(custom.graph)
write.index = {
dim: IndexSequence(
simplified_index[dim].start.subs(
{GPR_NUM: cur_elem_id}
),
gpr_size,
simplified_index[dim].stride,
)
for dim in simplified_index
}
write.vector_shapes = custom.vector_shapes
ops_to_combine.append(write)
elif isinstance(custom, Read):
# TODO: Add support for handling strided reads.
read = Read(
custom.memory,
elements_per_thread=gpr_size,
mapping=custom.mapping,
_write_dependency=custom._write_dependency,
).add_to_graph(custom.graph)
read.index = {
dim: IndexSequence(
simplified_index[dim].start.subs(
{GPR_NUM: cur_elem_id}
),
gpr_size,
simplified_index[dim].stride,
)
for dim in simplified_index
}
read.vector_shapes = custom.vector_shapes
ops_to_combine.append(read)

# Set new current base GPR offset
gpr_cur_base_offset = next_gpr_offset
cur_elem_id = i + 1
if isinstance(custom, Write):
custom.graph.erase_node(operator)
if isinstance(custom, Write):
# Useful for handling write/read dependencies
custom.replace_all_uses_with(ops_to_combine)
custom.graph.erase_node(operator)
elif isinstance(custom, Read):
reshape = Reshape(ops_to_combine, custom.vector_shapes).add_to_graph(
custom.graph
)
custom.replace_all_uses_with(reshape)
custom.graph.erase_node(custom.fx_node)


def preprocess_nodes(
Expand Down
4 changes: 3 additions & 1 deletion iree/turbine/kernel/wave/minimize_global_loads.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,13 @@ def construct_min_global_access_pattern(
It takes a 1-D global offset and delinearizes it to a multi-dimensional offset
and updates the access pattern accordingly.
"""
thread_ids = [THREAD_0, THREAD_1, THREAD_2]
thread_ids = [THREAD_0, THREAD_1, THREAD_2, GPR_NUM]
new_index = {key: index[key].subs({t: 0 for t in thread_ids}) for key in index}
nd_index = delinearize_index(thread_id, shape)
for i, key in enumerate(index.keys()):
new_index[key].start += nd_index[i]
new_index[key].size = load_elems_per_thread if i == len(index.keys()) - 1 else 1
new_index[key].stride = 1
return new_index


Expand Down Expand Up @@ -150,6 +151,7 @@ def add_optimized_nodes(
).add_to_graph(custom.graph)
write.index = read.index
optimized_writes[custom_user.memory].append(write)
write.vector_shapes = custom.vector_shapes
break
return optimized_writes

Expand Down

0 comments on commit ac73aa4

Please sign in to comment.