diff --git a/iree/turbine/kernel/wave/codegen.py b/iree/turbine/kernel/wave/codegen.py
index 0fcb32d7..3ded53e2 100644
--- a/iree/turbine/kernel/wave/codegen.py
+++ b/iree/turbine/kernel/wave/codegen.py
@@ -1321,7 +1321,6 @@ def handle_reshape(emitter: WaveEmitter, node: fx.Node):
         raise ValidationError("Malformed arguments") from e
     custom = get_custom(node)
     innermost_dim = custom.type.symbolic_shape[-1]
-    offset = custom.expanded_dims[innermost_dim]
 
     # Determine whether to extract or combine.
     if len(args) > 1:
@@ -1349,6 +1348,7 @@ def handle_reshape(emitter: WaveEmitter, node: fx.Node):
     # actual offset, we need to multiply by the size. The size is obtained by
     # computing the number of partitions using the source and target vector shapes
     # and dividing the incoming vector shape by the number of partitions.
+    offset = custom.expanded_dims[innermost_dim]
     num_partitions = (
         target_vector_shapes[innermost_dim] // custom.vector_shapes[innermost_dim]
     )
diff --git a/iree/turbine/kernel/wave/index_sequence_analysis.py b/iree/turbine/kernel/wave/index_sequence_analysis.py
index 7cffb414..6530dac3 100644
--- a/iree/turbine/kernel/wave/index_sequence_analysis.py
+++ b/iree/turbine/kernel/wave/index_sequence_analysis.py
@@ -84,7 +84,6 @@ def has_strided_access(node: fx.Node) -> bool:
             dim: simplify_index(custom.index.get(dim, custom.index[dim]))
             for dim in custom.index
         }
-
         shape = get_vector_shape(
             custom.vector_shapes, custom.register_type.symbolic_shape
         )
@@ -93,6 +92,7 @@ def has_strided_access(node: fx.Node) -> bool:
             [(dim, seq.stride) for dim, seq in simplified_index.items()],
             key=lambda item: item[1],
         )
+        ops_to_combine = []
        with custom.graph.inserting_before(operator):
            for i in range(elements_per_thread):
                # Non-contiguous access patterns can have varying offsets. We
@@ -113,7 +113,10 @@ def has_strided_access(node: fx.Node) -> bool:
                    )
                    for j, dim in enumerate(custom.register_type.symbolic_shape)
                }
+                ops_to_combine.append(write)
+        # Useful to handle write/read dependency
+        custom.replace_all_uses_with(ops_to_combine)
        custom.graph.erase_node(operator)
@@ -148,7 +151,7 @@ def has_gpr_offsets(node: fx.Node) -> bool:
    for operator in strided_operators:
        custom = get_custom(operator)
        simplified_index = {
-            dim: simplify_index(custom.register_index.get(dim, custom.index[dim]))
+            dim: simplify_index(custom.index.get(dim, custom.index[dim]))
            for dim in custom.index
        }
        elements_per_thread = subs_idxc(custom.elements_per_thread)
@@ -160,6 +163,7 @@ def has_gpr_offsets(node: fx.Node) -> bool:
        gpr_cur_base_offset = gpr_offset_expr.subs({GPR_NUM: 0})
        cur_elem_id = 0
        with custom.graph.inserting_before(operator):
+            ops_to_combine = []
            for i in range(elements_per_thread):
                # Break apart Reads/Writes that has non-contiguous GPR Read/Writes.
                next_gpr_offset = gpr_offset_expr.subs({GPR_NUM: i + 1})
@@ -182,30 +186,62 @@ def has_gpr_offsets(node: fx.Node) -> bool:
                    gpr_size = int(gpr_size)

                    # Generate new Read/Write that has contiguous VGPR elements.
-                    extract = ExtractSlice(
-                        custom.register_, [cur_elem_id], [gpr_size], [1]
-                    ).add_to_graph(custom.graph)
-                    write = Write(
-                        extract,
-                        custom.memory,
-                        mapping=custom.mapping,
-                        elements_per_thread=gpr_size,
-                    ).add_to_graph(custom.graph)
-                    write.index = {
-                        dim: IndexSequence(
-                            simplified_index[dim].start.subs({GPR_NUM: cur_elem_id}),
-                            gpr_size,
-                            simplified_index[dim].stride,
-                        )
-                        for dim in simplified_index
-                    }
-                    write.vector_shapes = custom.vector_shapes
+                    if isinstance(custom, Write):
+                        extract = ExtractSlice(
+                            custom.register_, [cur_elem_id], [gpr_size], [1]
+                        ).add_to_graph(custom.graph)
+                        write = Write(
+                            extract,
+                            custom.memory,
+                            mapping=custom.mapping,
+                            elements_per_thread=gpr_size,
+                        ).add_to_graph(custom.graph)
+                        write.index = {
+                            dim: IndexSequence(
+                                simplified_index[dim].start.subs(
+                                    {GPR_NUM: cur_elem_id}
+                                ),
+                                gpr_size,
+                                simplified_index[dim].stride,
+                            )
+                            for dim in simplified_index
+                        }
+                        write.vector_shapes = custom.vector_shapes
+                        ops_to_combine.append(write)
+                    elif isinstance(custom, Read):
+                        # TODO: Add support on how to handle strided reads.
+                        read = Read(
+                            custom.memory,
+                            elements_per_thread=gpr_size,
+                            mapping=custom.mapping,
+                            _write_dependency=custom._write_dependency,
+                        ).add_to_graph(custom.graph)
+                        read.index = {
+                            dim: IndexSequence(
+                                simplified_index[dim].start.subs(
+                                    {GPR_NUM: cur_elem_id}
+                                ),
+                                gpr_size,
+                                simplified_index[dim].stride,
+                            )
+                            for dim in simplified_index
+                        }
+                        read.vector_shapes = custom.vector_shapes
+                        ops_to_combine.append(read)

                    # Set new current base GPR offset
                    gpr_cur_base_offset = next_gpr_offset
                    cur_elem_id = i + 1
-        if isinstance(custom, Write):
-            custom.graph.erase_node(operator)
+        if isinstance(custom, Write):
+            # Useful to handle write/read dependency
+            custom.replace_all_uses_with(ops_to_combine)
+            custom.graph.erase_node(operator)
+        elif isinstance(custom, Read):
+            reshape = Reshape(ops_to_combine, custom.vector_shapes).add_to_graph(
+                custom.graph
+            )
+            custom.replace_all_uses_with(reshape)
+            custom.graph.erase_node(custom.fx_node)


 def preprocess_nodes(
diff --git a/iree/turbine/kernel/wave/minimize_global_loads.py b/iree/turbine/kernel/wave/minimize_global_loads.py
index 1092b2c5..62b623a2 100644
--- a/iree/turbine/kernel/wave/minimize_global_loads.py
+++ b/iree/turbine/kernel/wave/minimize_global_loads.py
@@ -50,12 +50,13 @@ def construct_min_global_access_pattern(
     It takes a 1-D global offset and delinearizes it to a multi-dimensional offset
     and updates the access pattern accordingly.
     """
-    thread_ids = [THREAD_0, THREAD_1, THREAD_2]
+    thread_ids = [THREAD_0, THREAD_1, THREAD_2, GPR_NUM]
     new_index = {key: index[key].subs({t: 0 for t in thread_ids}) for key in index}
     nd_index = delinearize_index(thread_id, shape)
     for i, key in enumerate(index.keys()):
         new_index[key].start += nd_index[i]
         new_index[key].size = load_elems_per_thread if i == len(index.keys()) - 1 else 1
+        new_index[key].stride = 1
     return new_index

@@ -150,6 +151,7 @@ def add_optimized_nodes(
                ).add_to_graph(custom.graph)
                write.index = read.index
                optimized_writes[custom_user.memory].append(write)
+                write.vector_shapes = custom.vector_shapes
                break
    return optimized_writes
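Note on the chunking logic in the index_sequence_analysis.py hunks above: the GPR-offset partitioning substitutes successive GPR_NUM values into the per-element offset expression and closes a chunk whenever the next offset is no longer contiguous with the current base; each chunk then becomes its own Read/Write of gpr_size elements, collected in ops_to_combine and recombined via Reshape (reads) or replace_all_uses_with (writes). A minimal standalone sketch of just that chunking step, under the assumption of a sympy offset expression and with hypothetical helper and variable names that do not appear in the patch, might look like:

# Illustrative sketch only; names here are hypothetical and not part of the patch.
import sympy

GPR_NUM = sympy.Symbol("GPR_NUM")


def split_into_contiguous_chunks(gpr_offset_expr, elements_per_thread):
    """Return (start_element, size) chunks whose per-element offsets are contiguous."""
    chunks = []
    cur_base_offset = gpr_offset_expr.subs({GPR_NUM: 0})
    cur_elem_id = 0
    for i in range(elements_per_thread):
        next_offset = gpr_offset_expr.subs({GPR_NUM: i + 1})
        size = i + 1 - cur_elem_id
        # Close the current chunk on a contiguity break or at the last element.
        if next_offset != cur_base_offset + size or i == elements_per_thread - 1:
            chunks.append((cur_elem_id, size))
            cur_base_offset = next_offset
            cur_elem_id = i + 1
    return chunks


# Elements 0..3 land at offsets 0, 1, 16, 17 -> two contiguous chunks of size 2.
offset_expr = (GPR_NUM % 2) + 16 * (GPR_NUM // 2)
print(split_into_contiguous_chunks(offset_expr, 4))  # [(0, 2), (2, 2)]

In the actual pass, each (start, size) pair corresponds to one of the new Read/Write nodes created in the diff, with the start substituted for GPR_NUM in the IndexSequence and size used as elements_per_thread.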