diff --git a/iree/turbine/kernel/lang/global_symbols.py b/iree/turbine/kernel/lang/global_symbols.py index e1966138..5fa542be 100644 --- a/iree/turbine/kernel/lang/global_symbols.py +++ b/iree/turbine/kernel/lang/global_symbols.py @@ -27,6 +27,10 @@ READ_GLOBAL_DELAY = index_symbol("$READ_GLOBAL_DELAY") WRITE_GLOBAL_DELAY = index_symbol("$WRITE_GLOBAL_DELAY") MMA_DELAY = index_symbol("$MMA_DELAY") +VALU_DELAY = index_symbol("$VALU_DELAY") +SHUFFLE_DELAY = index_symbol("$SHUFFLE_DELAY") SHARED_MEMORY_UNITS = index_symbol("$SHARED_MEMORY_UNITS") GLOBAL_MEMORY_UNITS = index_symbol("$GLOBAL_MEMORY_UNITS") MMA_UNITS = index_symbol("$MMA_UNITS") +VALU_UNITS = index_symbol("$VALU_UNITS") +SHUFFLE_UNITS = index_symbol("$SHUFFLE_UNITS") diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py index 6e38ea66..6a0d9cea 100644 --- a/iree/turbine/kernel/ops/wave_ops.py +++ b/iree/turbine/kernel/ops/wave_ops.py @@ -751,9 +751,8 @@ def custom_string(self, value_map: dict[str, str]) -> str: def indexing_dims(self) -> list[IndexSymbol]: return list(self._type.symbolic_shape) if self._type else [] - @property - def type(self) -> "Memory": - return self._type + def infer_type(self): + self.fx_node.type = self._type @dataclass @@ -854,9 +853,8 @@ class NewRegister(CustomOp): def indexing_dims(self) -> list[IndexSymbol]: return list(self.shape) - @property - def type(self) -> "Register": - return Register[*self.shape, self.dtype] + def infer_type(self): + self.type = Register[*self.shape, self.dtype] @define_op("mma") @@ -1275,11 +1273,9 @@ def target_shape(self): def indexing_dims(self) -> list[IndexSymbol]: return self.target_shape - @property - def type(self) -> Memory: + def infer_type(self): src_dtype = get_custom(self.arg).type.dtype - dst_type = Register[*self.target_shape, src_dtype] - return dst_type + self.type = Register[*self.target_shape, src_dtype] @define_interface_op("max") @@ -1370,10 +1366,8 @@ class ShuffleOp(CustomOp): def indexing_dims(self) -> list[IndexSymbol]: return get_custom(self.arg).indexing_dims - @property - def type(self) -> Register: - src_type = get_custom(self.arg).type - return src_type + def infer_type(self): + self.type = get_custom(self.arg).type @define_op("cast") @@ -1438,6 +1432,5 @@ class Reshape(CustomOp, ABC): def indexing_dims(self) -> list[IndexExpr]: return get_custom(_to_sequence(self.args)[0]).indexing_dims - @property - def type(self) -> Register: - return get_custom(_to_sequence(self.args)[0]).type + def infer_type(self): + self.type = get_custom(_to_sequence(self.args)[0]).type diff --git a/iree/turbine/kernel/wave/codegen.py b/iree/turbine/kernel/wave/codegen.py index 2e670824..0fcb32d7 100644 --- a/iree/turbine/kernel/wave/codegen.py +++ b/iree/turbine/kernel/wave/codegen.py @@ -1122,7 +1122,7 @@ def handle_scheduling_barrier(emitter: WaveEmitter, node: fx.Node): mask |= get_scheduling_mask(operation) mask = arith_d.constant(IntegerType.get_signless(32), mask) - llvm_d.call_intrinsic(None, "llvm.amdgcn.sched.barrier", [mask]) + llvm_d.call_intrinsic(None, "llvm.amdgcn.sched.barrier", [mask], [], []) @handle_op(scheduling_group_barrier) diff --git a/iree/turbine/kernel/wave/scheduling/graph_utils.py b/iree/turbine/kernel/wave/scheduling/graph_utils.py index 34a68356..f5b9d741 100644 --- a/iree/turbine/kernel/wave/scheduling/graph_utils.py +++ b/iree/turbine/kernel/wave/scheduling/graph_utils.py @@ -12,6 +12,8 @@ from dataclasses import dataclass import sympy import math +from functools import partial +import multiprocessing as 
mp T = index_symbol("$INITIATION_INTERVAL") @@ -157,7 +159,25 @@ def find_cycles_in_scc(scc: dict[fx.Node, list[fx.Node]]) -> list[list[fx.Node]] return circuits -def all_pairs_longest_paths( +def all_pairs_longest_paths_helper( + graph: fx.Graph, u: fx.Node, dist: dict[tuple[fx.Node, fx.Node], IndexExpr], i: int +): + v = list(graph.nodes)[i] + for w in graph.nodes: + dist[(v, w)] = sympy.Max(dist[(v, w)], dist[(v, u)] + dist[(u, w)]) + return v, dist + + +def all_pairs_longest_path_parallel(N: int, D: np.array, k: int, i: int): + """ + This function is called once for a different value of i. + """ + for j in range(N): + D[i, j] = np.maximum(D[i, j], D[i, k] + D[k, j]) + return i, D[i] + + +def all_pairs_longest_paths_symbolic( graph: fx.Graph, edges: list[Edge], ) -> dict[tuple[fx.Node, fx.Node], IndexExpr]: @@ -181,6 +201,53 @@ def all_pairs_longest_paths( return D +def all_pairs_longest_paths( + graph: fx.Graph, + edges: list[Edge], + T: int, +) -> dict[tuple[fx.Node, fx.Node], IndexExpr]: + """ + For each node in the graph, compute the longest path to all other nodes. + Uses the Floyd-Warshall algorithm and assumes that the cycles don't + have positive weights. This function computes the distances in parallel + by parallelizing across the start nodes. + + T is the initiation interval that is computed during modulo scheduling. + """ + N = len(graph.nodes) + D = np.zeros((N, N), dtype=np.float32) + negative_inf = -np.inf + for i in range(N): + for j in range(N): + D[i, j] = negative_inf + + all_nodes = list(graph.nodes) + for edge in edges: + i = all_nodes.index(edge._from) + j = all_nodes.index(edge._to) + D[i, j] = edge.weight.delay - edge.weight.iteration_difference * T + + # Parallel implementation + pool = mp.get_context("fork").Pool(processes=mp.cpu_count()) + for k in range(N): + func = partial(all_pairs_longest_path_parallel, N, D, k) + results = pool.map(func, range(N)) + for result in results: + D[result[0]] = result[1] + pool.close() + pool.join() + + # Convert from index to node based representation. + G: dict[tuple[fx.Node, fx.Node], int] = {} + for i, from_node in enumerate(graph.nodes): + for j, to_node in enumerate(graph.nodes): + if np.isinf(D[i, j]) or i == j: + continue + G[(from_node, to_node)] = int(D[i, j]) + + return G + + def evaluate_all_pairs_longest_paths( D: dict[tuple[fx.Node, fx.Node], IndexExpr], initiation_interval: int ) -> dict[tuple[fx.Node, fx.Node], int]: @@ -190,7 +257,8 @@ def evaluate_all_pairs_longest_paths( """ D_static = dict(D) for key in D_static: - D_static[key] = D_static[key].subs(T, initiation_interval) + if isinstance(D_static[key], sympy.Expr): + D_static[key] = D_static[key].subs(T, initiation_interval) # Remove the negative infinity values and edges to self. 
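For reference, the (max, +) Floyd-Warshall recurrence that all_pairs_longest_paths parallelizes across rows can be written sequentially as below. This is an illustrative sketch, not part of the patch: integer node indices stand in for fx nodes, and the edge weights stand in for delay - iteration_difference * T.

import numpy as np

def reference_longest_paths(edge_weights: dict[tuple[int, int], float], n: int) -> np.ndarray:
    # D[i, j] starts at -inf (no path) and at the edge weight where an edge exists.
    D = np.full((n, n), -np.inf, dtype=np.float32)
    for (i, j), w in edge_weights.items():
        D[i, j] = w
    # Floyd-Warshall with (max, +) in place of (min, +) computes longest paths.
    # The row update for a fixed i is what the pool.map call above distributes,
    # one call to all_pairs_longest_path_parallel per value of i.
    for k in range(n):
        for i in range(n):
            D[i, :] = np.maximum(D[i, :], D[i, k] + D[k, :])
    return D

# Toy check: edges a->b (2), b->c (1), a->c (1); the longest a..c path has weight 3.
D = reference_longest_paths({(0, 1): 2.0, (1, 2): 1.0, (0, 2): 1.0}, 3)
assert D[0, 2] == 3.0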
for k in list(D_static.keys()): if math.isinf(D_static[k]) or k[0] == k[1]: @@ -244,8 +312,14 @@ def get_scheduling_weight(node: fx.Node) -> EdgeWeight: weight = EdgeWeight(0, delay_table[Operation.MMA]) case IterArg(): weight = EdgeWeight(1, 0) - case CastOp(): + case CastOp() | Extract() | Permute() | Broadcast() | Reshape(): weight = EdgeWeight(0, delay_table[Operation.NOOP]) + case UnaryPyOp(): + weight = EdgeWeight(0, delay_table[Operation.VALU]) + case BinaryPyOp(): + weight = EdgeWeight(0, delay_table[Operation.VALU]) + case ShuffleOp(): + weight = EdgeWeight(0, delay_table[Operation.SHUFFLE]) case _: raise ValueError(f"Unsupported node type: {custom_node}") weight.delay = subs_idxc(weight.delay) diff --git a/iree/turbine/kernel/wave/scheduling/loop_reconstruction.py b/iree/turbine/kernel/wave/scheduling/loop_reconstruction.py index 8955fdea..0f5e9877 100644 --- a/iree/turbine/kernel/wave/scheduling/loop_reconstruction.py +++ b/iree/turbine/kernel/wave/scheduling/loop_reconstruction.py @@ -9,6 +9,8 @@ GetResult, get_custom, SchedulingGroupBarrier, + MMA, + NewRegister, ) from .modulo_scheduling import ModuloScheduler from ..utils import ( @@ -76,16 +78,19 @@ def add_nodes_by_schedule( logger.debug(f"Node: {node}, Stage: {stage}, Iteration: {iteration}") custom_node = get_custom(node) logger.debug(f"Node args: {node.args}") + preferred_stage = ( + stage if pipelining_stage == PipelineStage.KERNEL else None + ) for arg in node.args: if arg_context.contains_in_iteration(iteration, arg): logger.debug( - f"Found arg: {arg} at iteration {iteration} in partitioned argument map. Using {arg_context.get_from_iteration(iteration, arg)}." + f"Found arg: {arg} at iteration {iteration} in partitioned argument map. Using {arg_context.get_from_iteration(iteration, arg, preferred_stage)}." ) continue new_node = custom_node.copy( new_graph=reduction_graph, arg_transform=lambda x: ( - arg_context.get_from_iteration(iteration, x) + arg_context.get_from_iteration(iteration, x, preferred_stage) if arg_context.contains_in_iteration(iteration, x) else x ), @@ -99,39 +104,71 @@ def add_nodes_by_schedule( # Set the index for the new node by substituting the induction variable # for the current iteration. new_node.index = node.index - for dim in new_node.index: - new_node.index[dim] = new_node.index[dim].subs( - {induction_variable: current_induction_variables[iteration]} - ) + if new_node.index: + for dim in new_node.index: + new_node.index[dim] = new_node.index[dim].subs( + {induction_variable: current_induction_variables[iteration]} + ) + if custom_node.expanded_dims: + new_node.expanded_dims = custom_node.expanded_dims # Add scheduling parameters for debugging. new_node.scheduling_parameters = node.scheduling_parameters # Update the rotating registers and argument context for the current node (if applicable). + old_node = None if node in rotating_registers: rotating_registers[node].append(new_node.fx_node) - rotating_registers[node].popleft() + old_node = rotating_registers[node].popleft() # If draining, then override the rotating registers and update the argument context. if fill_or_drain: for next_stage in range(stage + 1, len(stages)): arg_context[(iteration, next_stage, node)] = new_node.fx_node - # Update the init args in the argument context whenever a result is computed. + # Update the iter and init args in the argument context whenever a result is computed. 
if node in arg_context.results: + iter_arg = arg_context.result_to_iter_arg[node] + logger.debug( + f"Updating result: {node} -> {iter_arg} to {new_node.fx_node}." + ) + arg_context.map_arg_all_after_iteration( + iter_arg, + new_node.fx_node, + iteration, + ) + # In situations where we have an iter_arg as a rotating register, + # we also have the output as a rotating register. So when we + # are updating the output, we update the iter_arg as well with the + # old value of the output rotating register. Consider this example: + # Say we have the following: + # + # Stage 0: + # iter_arg0 + # + # + # output = compute(...) -> here we update iter_arg0 to have the output value + # for the next stage, so that it gets picked up in stage1. + # + # Stage 1: + # b = use(iter_arg0) if ( pipelining_stage == PipelineStage.KERNEL - or pipelining_stage == PipelineStage.EPILOGUE + or pipelining_stage == PipelineStage.PROLOGUE ): - logger.debug( - f"Updating result: {node} -> {arg_context.result_to_iter_arg[node]} to {new_node.fx_node}." - ) - arg_context.map_arg_all( - arg_context.result_to_iter_arg[node], new_node.fx_node - ) + if iter_arg in rotating_registers and old_node: + logger.debug( + f"Updating rotating register iter arg {iter_arg} -> {old_node}." + ) + rotating_registers[iter_arg].append(old_node) + rotating_registers[iter_arg].popleft() + for next_stage in range(stage + 1, len(stages)): + arg_context[(iteration, next_stage, iter_arg)] = old_node if pipelining_stage == PipelineStage.PROLOGUE: logger.debug( f"Updating result: {node} -> {arg_context.result_to_init_arg[node]} to {new_node.fx_node}." ) - arg_context.map_arg_all( - arg_context.result_to_init_arg[node], new_node.fx_node + arg_context.map_arg_all_after_iteration( + arg_context.result_to_init_arg[node], + new_node.fx_node, + iteration, ) if pipelining_stage == PipelineStage.KERNEL and use_scheduling_barriers: @@ -154,6 +191,29 @@ def push_placeholders( arg_context.map_arg_all(node, root_node) +def add_missing_registers(graph: fx.Graph): + """ + This function goes through the graph and finds MMA operators whose accumulator + is undefined (not in the current graph). For those operators, it replaces the accumulator + with a new register of the same shape and type and with the same index. + + This is necessary in situations where the register is defined inside the loop + but when we are trying to insert MMA operators outside the loop (such as for the + prologue and epilogue). + """ + for node in graph.nodes: + custom = get_custom(node) + if isinstance(custom, MMA): + acc = get_custom(custom.acc) + if acc.graph != custom.graph: + with custom.graph.inserting_before(node): + register = NewRegister( + acc.shape, acc.dtype, acc.value + ).add_to_graph(custom.graph) + register.index = acc.index + custom.update_arg("acc", register) + + def construct_prologue( reduction_subgraph: fx.Graph, reduction: Reduction, @@ -183,6 +243,7 @@ def construct_prologue( ) # Map iter args to init args in the prologue. + original_init_args = list(reduction.init_args) for iter_arg, init_arg in zip( reduction.iter_args(reduction_subgraph), reduction.init_args ): @@ -210,9 +271,24 @@ def construct_prologue( mapped_init_arg = arg_context.lookup(init_arg) if mapped_init_arg is None: mapped_init_arg = init_arg + logger.debug(f"Mapping init_arg {init_arg} -> {mapped_init_arg}.") new_init_args.append(mapped_init_arg) reduction.init_args = new_init_args + # We may also have some iter_args as rotating registers. 
These will need + # to be initialized to the original init args which we do here. + iter_args = reduction.iter_args(reduction_subgraph) + for node, registers in rotating_registers.items(): + if node in iter_args: + if all(x is None for x in registers) and len(registers) == 1: + registers[0] = original_init_args[iter_args.index(node)] + + # Add missing registers. Since registers are not present + # in the scheduling code, we could end up with a situation where + # we move mma ops outside the reduction that do not have a corresponding + # register. We remedy this in the function below. + add_missing_registers(reduction.graph) + def flatten_dict_values( rotating_registers: dict[fx.Node, list[fx.Node]] @@ -269,7 +345,8 @@ def push_rotating_registers( custom = get_custom(node) stage = custom.scheduling_parameters["stage"] iteration = arg_context.get_kernel_iteration(stage) - arg_context[(iteration, stage, node)] = registers[-1] + if node not in arg_context.iter_args: + arg_context[(iteration, stage, node)] = registers[-1] for i, register in enumerate(registers): if create_new_nodes: mapped_stage = stage + len(registers) - i @@ -277,18 +354,16 @@ def push_rotating_registers( iter_arg = IterArg(f"rotating_reg_{count}").add_to_graph(graph) iter_arg.type = get_custom(node).type iter_arg.index = get_custom(node).index - arg_context[(mapped_iteration, mapped_stage, node)] = iter_arg new_registers.append(iter_arg) - logger.debug( - f"Mapped orig: {node_map[node]} / mapped: {iter_arg} to stage {mapped_stage}." - ) + mapped_value = iter_arg else: mapped_stage = stage + len(registers) - i - 1 mapped_iteration = arg_context.get_kernel_iteration(mapped_stage) - arg_context[(mapped_iteration, mapped_stage, node)] = register - logger.debug( - f"Mapped orig: {node_map[node]} / mapped: {register} to stage {mapped_stage} at iteration {mapped_iteration}." - ) + mapped_value = register + arg_context[(mapped_iteration, mapped_stage, node)] = mapped_value + logger.debug( + f"Mapped orig: {node_map[node]} / mapped: {mapped_value} to stage {mapped_stage} at iteration {mapped_iteration}." + ) count += 1 if new_registers: new_rotating_registers[node] = new_registers @@ -325,7 +400,7 @@ def construct_kernel( init_args=reduction.init_args + flatten_dict_values(rotating_registers), subgraph_name="pipelined_reduction", implicit_captures=reduction.implicit_captures, - ).add_to_graph(reduction.graph) + ).add_to_graph(reduction.graph, type=reduction.type) pipelined_reduction.index = reduction.index pipelined_reduction_graph = fx.Graph() reduction.graph.subgraphs["pipelined_reduction"] = pipelined_reduction_graph @@ -386,6 +461,12 @@ def construct_kernel( "kernel.png", ) + # Add missing registers. Since registers are not present + # in the scheduling code, we could end up with a situation where + # we move mma ops outside the reduction that do not have a corresponding + # register. We remedy this in the function below. 
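Rotating registers are modeled as fixed-length deques: each time the producing node is emitted, the new value is appended and the oldest copy is retired, which is the append/popleft pattern used in add_nodes_by_schedule above. A minimal standalone sketch of that rotation:

from collections import deque

def rotate(registers: deque, new_value):
    # Push the value produced this iteration and retire the oldest live copy.
    registers.append(new_value)
    return registers.popleft()

# A value whose consumer runs two stages after its producer needs two slots:
# the copy produced at iteration i is retired (and consumed) at iteration i + 2.
registers = deque([None, None])
retired = [rotate(registers, f"v{i}") for i in range(4)]
assert retired == [None, None, "v0", "v1"]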
+ add_missing_registers(pipelined_reduction_graph) + return pipelined_reduction, pipelined_reduction_graph @@ -422,16 +503,34 @@ def construct_epilogue( scheduler.num_stages, ) - existing_get_results: list[GetResult] = sorted( - [x for x in pipelined_reduction.users if isinstance(x, GetResult)], - key=lambda x: x.res_idx, - ) - existing_users = {x: x.users for x in existing_get_results} + existing_get_results: list[GetResult] = [ + x for x in pipelined_reduction.users if isinstance(x, GetResult) + ] + existing_indices = [x.res_idx for x in existing_get_results] # Map the results from the kernel to the init args (for stages). - for iter_arg, get_result in zip( - reduction.iter_args(reduction_subgraph), existing_get_results - ): + # The number of iter args may not be the same as the number of get results + # and so we have to add additional get results for the missing iter args. + # This happens if some of the iter args have no uses outside the reduction + # (such as the max value in flash attention). While they may not have any + # uses in the original reduction, they will have uses in the pipelined + # reduction outside the reduction and so need to be added in the correct order. + iter_args = reduction.iter_args(reduction_subgraph) + for i in range(len(iter_args)): + if i in existing_indices: + continue + with pipelined_reduction.graph.inserting_before( + existing_get_results[0].fx_node.next + ): + result = GetResult(pipelined_reduction.fx_node, i).add_to_graph( + pipelined_reduction.graph, type=iter_args[i].type + ) + existing_get_results.append(get_custom(result)) + + existing_get_results = sorted(existing_get_results, key=lambda x: x.res_idx) + existing_users = {x: x.users for x in existing_get_results} + + for iter_arg, get_result in zip(iter_args, existing_get_results): arg_context.map_arg_all(iter_arg, get_result.fx_node) with pipelined_reduction.graph.inserting_before( @@ -441,10 +540,12 @@ def construct_epilogue( # argument map with them. rotating_registers_get_results = [] offset = len(existing_get_results) - for i in range(len(flatten_dict_values(rotating_registers))): + flattened_rotating_registers = flatten_dict_values(rotating_registers) + for i in range(len(flattened_rotating_registers)): rotating_registers_get_results.append( GetResult(pipelined_reduction.fx_node, i + offset).add_to_graph( - pipelined_reduction.graph + pipelined_reduction.graph, + type=flattened_rotating_registers[i].type, ) ) rotating_registers = unflatten_dict_values( @@ -474,6 +575,12 @@ def construct_epilogue( for i, get_result in enumerate(existing_get_results): replace_uses_in(existing_users, get_result, new_results[i]) + # Add missing registers. Since registers are not present + # in the scheduling code, we could end up with a situation where + # we move mma ops outside the reduction that do not have a corresponding + # register. We remedy this in the function below. + add_missing_registers(pipelined_reduction.graph) + if visualize: visualize_mapped_graphs( pipelined_reduction.graph, @@ -499,7 +606,7 @@ def construct_pipelined_loop( with a prologue, kernel and epilogue. 
""" induction_variable = get_induction_variable(reduction, constraints) - num_rotating_registers = liveness_analysis(graph, constraints, scheduler) + num_rotating_registers = liveness_analysis(graph, reduction) rotating_registers: dict[fx.Node, deque[fx.Node]] = { k: deque([None for _ in range(v)]) for k, v in num_rotating_registers.items() } diff --git a/iree/turbine/kernel/wave/scheduling/loop_reconstruction_utils.py b/iree/turbine/kernel/wave/scheduling/loop_reconstruction_utils.py index b6993a21..f5adf0ca 100644 --- a/iree/turbine/kernel/wave/scheduling/loop_reconstruction_utils.py +++ b/iree/turbine/kernel/wave/scheduling/loop_reconstruction_utils.py @@ -1,7 +1,15 @@ from ..constraints import Constraint, TilingConstraint from ..._support.indexing import IndexSymbol from ..._support.tracing import CapturedTrace -from ...ops.wave_ops import Reduction, IterArg, Output, Write, GetResult, get_custom +from ...ops.wave_ops import ( + Reduction, + IterArg, + Output, + Write, + GetResult, + get_custom, + Placeholder, +) from .modulo_scheduling import ModuloScheduler from ..utils import graph_copy, erase_graph from ..utils import subs_idxc @@ -55,6 +63,17 @@ def map_arg_all(self, from_: fx.Node, to_: fx.Node) -> None: for stage in range(self.num_stages): self.argument_map[iteration][stage][from_] = to_ + def map_arg_all_after_iteration( + self, from_: fx.Node, to_: fx.Node, iteration: int + ) -> None: + """ + Maps the given argument from one to another into the argument context for all stages + after the specified iteration. + """ + for iteration in range(iteration + 1, self.num_iterations): + for stage in range(self.num_stages): + self.argument_map[iteration][stage][from_] = to_ + def map_arg_all_iterations(self, stage: int, from_: fx.Node, to_: fx.Node) -> None: """ Maps the given argument from one to another into the argument context for all stages @@ -139,7 +158,7 @@ def lookup(self, key: fx.Node) -> Optional[fx.Node]: """ Looks up the argument mapping for the given node. """ - for iteration in range(self.num_iterations): + for iteration in range(self.num_iterations - 1, -1, -1): for stage in range(self.num_stages): if key in self.argument_map[iteration][stage]: return self.argument_map[iteration][stage][key] @@ -155,10 +174,15 @@ def contains_in_iteration(self, iteration: int, key: fx.Node) -> bool: for stage in range(self.num_stages) ) - def get_from_iteration(self, iteration: int, key: fx.Node) -> fx.Node: + def get_from_iteration(self, iteration: int, key: fx.Node, stage: int) -> fx.Node: """ - Gets the argument mapping for the given iteration. + Gets the argument mapping for the given iteration with + preference to the given stage. """ + + if stage and key in self.argument_map[iteration][stage]: + return self.argument_map[iteration][stage][key] + for stage in range(self.num_stages): if key in self.argument_map[iteration][stage]: return self.argument_map[iteration][stage][key] @@ -207,9 +231,7 @@ def create_drain_stage_schedule(n: int) -> list[list[int]]: return schedule -def liveness_analysis( - graph: fx.Graph, constraints: list[Constraint], scheduler: ModuloScheduler -) -> dict[fx.Node, int]: +def liveness_analysis(graph: fx.Graph, reduction: Reduction) -> dict[fx.Node, int]: """ Perform liveness analysis on the graph to determine the live ranges of variables and use that to deduce how many rotating registers we need. 
@@ -227,11 +249,11 @@ def liveness_analysis( logger.debug( f"Node: {node}, User: {user.fx_node}, lifetime: {user.scheduling_parameters['stage'] - custom.scheduling_parameters['stage']}" ) - lifetime[node] = max( + user_lifetime = ( user.scheduling_parameters["stage"] - - custom.scheduling_parameters["stage"], - lifetime[node], + - custom.scheduling_parameters["stage"] ) + lifetime[node] = max(user_lifetime, lifetime[node]) # Determine how many copies we need for each node. If the lifetime of a node # is l clocks and the initiation interval is T, then only ceil(l/T) values diff --git a/iree/turbine/kernel/wave/scheduling/modulo_scheduling.py b/iree/turbine/kernel/wave/scheduling/modulo_scheduling.py index 82940113..00c6cd78 100644 --- a/iree/turbine/kernel/wave/scheduling/modulo_scheduling.py +++ b/iree/turbine/kernel/wave/scheduling/modulo_scheduling.py @@ -106,9 +106,6 @@ def schedule_graph(self) -> tuple[dict[fx.Node, int], bool]: # Initialize initiation interval. T0 = int(max(self.compute_resource_ii(), self.compute_recurrence_ii(sccs))) - # Compute symbolic all pairs longest path. - e_star_symbolic = all_pairs_longest_paths(self.graph, self.edges) - # Generate the schedule. # TODO: Come up with a better heuristic on an upper bound for the initiation interval. T_max_range = 3 * T0 @@ -116,7 +113,7 @@ def schedule_graph(self) -> tuple[dict[fx.Node, int], bool]: for T in range(T0, T0 + T_max_range): logger.debug(f"Trying initiation interval: {T}.") self.RT = np.zeros((T, len(self.resources))) - self.e_star = evaluate_all_pairs_longest_paths(e_star_symbolic, T) + self.e_star = all_pairs_longest_paths(self.graph, self.edges, T) logger.debug(f"All Pairs Longest Paths: {self.e_star}.") self.schedule: dict[fx.Node, int] = {} for _, scc in topological_sort(sccs).items(): diff --git a/iree/turbine/kernel/wave/scheduling/resources.py b/iree/turbine/kernel/wave/scheduling/resources.py index 15dd76d3..128a59fd 100644 --- a/iree/turbine/kernel/wave/scheduling/resources.py +++ b/iree/turbine/kernel/wave/scheduling/resources.py @@ -15,6 +15,13 @@ get_custom, CustomOp, CastOp, + UnaryPyOp, + BinaryPyOp, + ShuffleOp, + Permute, + Extract, + Broadcast, + Reshape, ) import torch.fx as fx from enum import Enum @@ -23,7 +30,13 @@ # This table contains the number of functional units available for each operation. def get_available_resources() -> list[int]: - resources = [GLOBAL_MEMORY_UNITS, SHARED_MEMORY_UNITS, MMA_UNITS] + resources = [ + GLOBAL_MEMORY_UNITS, + SHARED_MEMORY_UNITS, + MMA_UNITS, + VALU_UNITS, + SHUFFLE_UNITS, + ] return np.array([int(subs_idxc(x)) for x in resources]) @@ -37,8 +50,11 @@ class Operation(Enum): VALU = "valu" SALU = "salu" NOOP = "noop" + SHUFFLE = "shuffle" +SCHEDULING_NOOPS = (IterArg, Permute, Extract, Broadcast, CastOp, Reshape) + # This table contains the cycles required to execute each operation. delay_table = { Operation.READ_SHARED: READ_SHARED_DELAY, @@ -47,17 +63,21 @@ class Operation(Enum): Operation.WRITE_GLOBAL: WRITE_GLOBAL_DELAY, Operation.MMA: MMA_DELAY, Operation.NOOP: 0, + Operation.VALU: VALU_DELAY, + Operation.SHUFFLE: SHUFFLE_DELAY, } # This table contains the resource usage for each operation. # Operations can use more than one resource for more than one cycle. 
resource_reservation_table = { - Operation.READ_SHARED: np.array([[0, 1, 0]]), - Operation.WRITE_SHARED: np.array([[0, 1, 0]]), - Operation.READ_GLOBAL: np.array([[1, 0, 0]]), - Operation.WRITE_GLOBAL: np.array([[1, 0, 0]]), - Operation.MMA: np.array([[0, 0, 1]]), - Operation.NOOP: np.array([[0, 0, 0]]), + Operation.READ_SHARED: np.array([[0, 1, 0, 0, 0]]), + Operation.WRITE_SHARED: np.array([[0, 1, 0, 0, 0]]), + Operation.READ_GLOBAL: np.array([[1, 0, 0, 0, 0]]), + Operation.WRITE_GLOBAL: np.array([[1, 0, 0, 0, 0]]), + Operation.MMA: np.array([[0, 0, 1, 0, 0]]), + Operation.NOOP: np.array([[0, 0, 0, 0, 0]]), + Operation.VALU: np.array([[0, 0, 0, 1, 0]]), + Operation.SHUFFLE: np.array([[0, 0, 0, 0, 1]]), } @@ -76,10 +96,12 @@ def get_custom_operation_type(custom: CustomOp) -> Operation: ) elif isinstance(custom, MMA): return Operation.MMA - elif isinstance(custom, IterArg): - return Operation.NOOP - elif isinstance(custom, Output): + elif isinstance(custom, SCHEDULING_NOOPS + (Output,)): return Operation.NOOP + elif isinstance(custom, (UnaryPyOp, BinaryPyOp)): + return Operation.VALU + elif isinstance(custom, ShuffleOp): + return Operation.SHUFFLE else: return None @@ -106,8 +128,13 @@ def annotate_resource_usage( ) elif isinstance(custom, MMA): custom.rrt = resource_reservation_table[Operation.MMA] - elif isinstance(custom, (IterArg, CastOp)): - iter_args.append(node) + elif isinstance(custom, ShuffleOp): + custom.rrt = resource_reservation_table[Operation.SHUFFLE] + elif isinstance(custom, (UnaryPyOp, BinaryPyOp)): + custom.rrt = resource_reservation_table[Operation.VALU] + elif isinstance(custom, SCHEDULING_NOOPS): + if isinstance(custom, IterArg): + iter_args.append(node) custom.rrt = resource_reservation_table[Operation.NOOP] elif isinstance(custom, Output): output = node diff --git a/iree/turbine/kernel/wave/scheduling/schedule.py b/iree/turbine/kernel/wave/scheduling/schedule.py index 9cf6eb19..6e44b5f8 100644 --- a/iree/turbine/kernel/wave/scheduling/schedule.py +++ b/iree/turbine/kernel/wave/scheduling/schedule.py @@ -6,7 +6,7 @@ from ..constraints import Constraint from ..._support.tracing import CapturedTrace -from ...ops.wave_ops import Reduction, IterArg, get_custom +from ...ops.wave_ops import Reduction, IterArg, get_custom, CustomOp from .modulo_scheduling import ModuloScheduler from .graph_utils import create_scheduling_edges, Edge from .resources import get_available_resources, annotate_resource_usage @@ -15,6 +15,7 @@ from ..utils import graph_copy, erase_graph, get_tiling_constraint, subs_idxc import torch.fx as fx from ....support.logging import get_logger +import math logger = get_logger("turbine.wave.scheduling.schedule") @@ -59,6 +60,7 @@ def schedule_reduction( # that each node is scheduled in as well as the cycle in # each stage when the node should be issued. inverse_node_map = {v: k for k, v in node_map.items()} + iter_args: list[CustomOp] = [] for node, cycle in schedule.items(): if node not in inverse_node_map: continue @@ -72,6 +74,16 @@ def schedule_reduction( # Erase edges between outputs and iter args. 
if isinstance(get_custom(node), IterArg): node.args = () + iter_args.append(custom) + + for custom in iter_args: + cycle = min([x.scheduling_parameters["absolute_cycle"] for x in custom.users]) + custom.scheduling_parameters = { + "absolute_cycle": cycle, + "cycle": cycle % scheduler.initiation_interval, + "stage": cycle // scheduler.initiation_interval, + "initiation_interval": scheduler.initiation_interval, + } erase_graph(graph) diff --git a/iree/turbine/kernel/wave/visualization.py b/iree/turbine/kernel/wave/visualization.py index d6438bfc..adf0990f 100644 --- a/iree/turbine/kernel/wave/visualization.py +++ b/iree/turbine/kernel/wave/visualization.py @@ -178,7 +178,11 @@ def visualize_mapped_graphs( # Draw edges between rotating registers for the same variable. for node in rotating_registers: - all_registers = [k for k, v in flat_inverse_map.items() if v == node] + all_registers = [ + k + for k, v in flat_inverse_map.items() + if v == node and k in second_numbering + ] for second, first in zip(all_registers[:-1], all_registers[1:]): G.add_edge( second_numbering[id(first)], diff --git a/lit_tests/kernel/wave/attention.py b/lit_tests/kernel/wave/attention.py index d691379a..3c58cc15 100644 --- a/lit_tests/kernel/wave/attention.py +++ b/lit_tests/kernel/wave/attention.py @@ -31,6 +31,130 @@ STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD +@run_test +def test_attention_pipelined(): + shape = (8, 128, 128, 64, 256) + # Expose user-constraints + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] + constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] + + mfma_variant = tkw.MMAType.F32_16x16x16_F16 + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(2, 2, 1), + mma_type=mfma_variant, + vector_shapes={B: 0, M: 16, N: 16}, + ) + ] + + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + k = tkw.IndexMapping.iterator(2) + mapping = tkw.IndexMapping( + num_iterators=3, inputs={B: i, N: j, M: k}, outputs={B: i, M: k, N: j} + ) + + @tkw.wave(constraints) + def base_attention_pipelined( + q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16], + k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16], + v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], + ): + c_reg = tkl.Register[B, N, M, tkl.f32](0.0) + init_sum = tkl.Register[B, M, tkl.f32](0.0) + init_max = tkl.Register[B, M, tkl.f32](-1e6) + + # This microkernel encodes the fact that if the reduction + # dimension were tiled, then we would need to materialize a loop. 
+ @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg]) + def repeat( + partial_max: tkl.Register[B, M, tkl.f32], + partial_sum: tkl.Register[B, M, tkl.f32], + acc: tkl.Register[B, N, M, tkl.f32], + ) -> ( + tkl.Register[B, M, tkl.f32], + tkl.Register[B, M, tkl.f32], + tkl.Register[B, N, M, tkl.f32], + ): + imm_reg = tkl.Register[B, K2, M, tkl.f32](0.0) + q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD) + k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD) + inner_acc = tkw.mma(k_reg, q_reg, imm_reg) + x_j = tkw.permute(inner_acc, target_shape=[B, M, K2]) + m_j = tkw.max(x_j, partial_max, dim=K2) + e_delta_max = tkw.exp2(partial_max - m_j) + e_delta = tkw.exp2(x_j - m_j) + e_init = partial_sum * e_delta_max + d_j = tkw.sum(e_delta, e_init, dim=K2) + imm_f16 = tkw.cast(e_delta, tkl.f16) + v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD) + new_acc = acc * e_delta_max + acc = tkw.mma(v_reg, imm_f16, new_acc) + return m_j, d_j, acc + + # repeat represents the results of the loop + res_max, res_sum, res_mm = repeat + res = res_mm / res_sum + tkw.write(res, c, mapping=mapping, elements_per_thread=STORE_ELEMS_PER_THREAD) + + hyperparams = { + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant), + STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant), + BLOCK_B: 1, + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_K2: 32, + B: shape[0], + M: shape[1], + N: shape[2], + K1: shape[3], + K2: shape[4], + READ_SHARED_DELAY: 1, + WRITE_SHARED_DELAY: 1, + READ_GLOBAL_DELAY: 2, + WRITE_GLOBAL_DELAY: 2, + MMA_DELAY: 1, + VALU_DELAY: 1, + SHUFFLE_DELAY: 1, + SHARED_MEMORY_UNITS: 4, + GLOBAL_MEMORY_UNITS: 4, + MMA_UNITS: 4, + VALU_UNITS: 2, + SHUFFLE_UNITS: 2, + } + + with tk.gen.TestLaunchContext( + hyperparams, + canonicalize=True, + run=False, + run_bench=False, + schedule=True, + use_scheduling_barriers=False, + ): + torch.manual_seed(0) + q = torch.randn(shape[0], shape[1], shape[3], dtype=torch.float16) + k = torch.randn(shape[0], shape[4], shape[3], dtype=torch.float16) + v = torch.randn(shape[0], shape[4], shape[2], dtype=torch.float16) + output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32) + print(base_attention_pipelined(q, k, v, output).module_op) + + # CHECK: func.func @base_attention_pipelined + # CHECK: {{.*}} = scf.for + # CHECK-COUNT-13: {{.*}} = amdgpu.mfma + # CHECK-COUNT-2: {{.*}} = gpu.shuffle xor {{.*}} + # CHECK-COUNT-8: {{.*}} = amdgpu.mfma + # CHECK-COUNT-2: {{.*}} = gpu.shuffle xor {{.*}} + # CHECK-COUNT-3: {{.*}} = amdgpu.mfma + # CHECK-COUNT-4: {{.*}} = gpu.shuffle xor {{.*}} + + @run_test def test_attention_32x32x8(): shape = (8, 128, 128, 64, 256) diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index 1aaab53e..c011d6cd 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -1141,9 +1141,13 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: READ_GLOBAL_DELAY: 2, WRITE_GLOBAL_DELAY: 2, MMA_DELAY: 1, + VALU_DELAY: 1, + SHUFFLE_DELAY: 1, SHARED_MEMORY_UNITS: 4, GLOBAL_MEMORY_UNITS: 4, MMA_UNITS: 4, + VALU_UNITS: 8, + SHUFFLE_UNITS: 8, }, canonicalize=True, schedule=True, diff --git a/lit_tests/kernel/wave/scheduling.py b/lit_tests/kernel/wave/scheduling.py index 2f7780bc..00810403 100644 --- a/lit_tests/kernel/wave/scheduling.py +++ b/lit_tests/kernel/wave/scheduling.py @@ -89,6 +89,10 @@ def test_gemm_pipelined(): SHARED_MEMORY_UNITS: 2, GLOBAL_MEMORY_UNITS: 2, MMA_UNITS: 
2, + VALU_DELAY: 1, + VALU_UNITS: 2, + SHUFFLE_DELAY: 1, + SHUFFLE_UNITS: 2, } ): trace: CapturedTrace = gemm_pipelined() diff --git a/tests/kernel/wave/scheduling_test.py b/tests/kernel/wave/scheduling_test.py index 80f01963..4da75483 100644 --- a/tests/kernel/wave/scheduling_test.py +++ b/tests/kernel/wave/scheduling_test.py @@ -179,9 +179,8 @@ def testGraphUtils(self): def testAPLP(self): graph, weighted_edges, nodes = self.create_weighted_graph() - D = all_pairs_longest_paths(graph, weighted_edges) T = 4 - D3 = evaluate_all_pairs_longest_paths(D, T) + D3 = all_pairs_longest_paths(graph, weighted_edges, T) assert D3[(nodes["a"], nodes["b"])] == 2 assert D3[(nodes["a"], nodes["c"])] == 3 assert D3[(nodes["a"], nodes["d"])] == 4 @@ -274,6 +273,10 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: SHARED_MEMORY_UNITS: 2, GLOBAL_MEMORY_UNITS: 2, MMA_UNITS: 2, + VALU_DELAY: 1, + VALU_UNITS: 2, + SHUFFLE_DELAY: 1, + SHUFFLE_UNITS: 2, } with tk.gen.TestLaunchContext(hyperparams, canonicalize=True, schedule=True): trace: CapturedTrace = gemm() @@ -290,21 +293,21 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: initiation_interval = 5 correct_schedule = { "acc_1_1_0": { - "absolute_cycle": 14, - "cycle": 4, + "absolute_cycle": 10, + "cycle": 0, "stage": 2, "initiation_interval": initiation_interval, }, "acc_1_0_0": { - "absolute_cycle": 14, - "cycle": 4, + "absolute_cycle": 10, + "cycle": 0, "stage": 2, "initiation_interval": initiation_interval, }, "acc_0_1_0": { - "absolute_cycle": 13, - "cycle": 3, - "stage": 2, + "absolute_cycle": 9, + "cycle": 4, + "stage": 1, "initiation_interval": initiation_interval, }, "read_4": { diff --git a/tests/kernel/wave/wave_attention_test.py b/tests/kernel/wave/wave_attention_test.py index c9b67705..067619da 100644 --- a/tests/kernel/wave/wave_attention_test.py +++ b/tests/kernel/wave/wave_attention_test.py @@ -351,7 +351,7 @@ def repeat( @require_e2e @pytest.mark.parametrize("shape", get_test_shapes("test_attention")) -@pytest.mark.parametrize("enable_scheduling", [False]) +@pytest.mark.parametrize("enable_scheduling", [False, True]) @pytest.mark.parametrize( "mfma_variant", [ @@ -476,9 +476,13 @@ def repeat( READ_GLOBAL_DELAY: 2, WRITE_GLOBAL_DELAY: 2, MMA_DELAY: 1, + VALU_DELAY: 1, + SHUFFLE_DELAY: 1, SHARED_MEMORY_UNITS: 4, GLOBAL_MEMORY_UNITS: 4, MMA_UNITS: 4, + VALU_UNITS: 2, + SHUFFLE_UNITS: 2, } config = get_default_run_config() if run_bench: diff --git a/tests/kernel/wave/wave_gemm_test.py b/tests/kernel/wave/wave_gemm_test.py index 35d989ab..c7df2037 100644 --- a/tests/kernel/wave/wave_gemm_test.py +++ b/tests/kernel/wave/wave_gemm_test.py @@ -143,9 +143,13 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: READ_GLOBAL_DELAY: 2, WRITE_GLOBAL_DELAY: 2, MMA_DELAY: 1, + VALU_DELAY: 1, + SHUFFLE_DELAY: 1, SHARED_MEMORY_UNITS: 4, GLOBAL_MEMORY_UNITS: 4, MMA_UNITS: 4, + VALU_UNITS: 8, + SHUFFLE_UNITS: 8, } config = get_default_run_config() if run_bench: @@ -262,9 +266,13 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: READ_GLOBAL_DELAY: 2, WRITE_GLOBAL_DELAY: 2, MMA_DELAY: 1, + VALU_DELAY: 1, + SHUFFLE_DELAY: 1, SHARED_MEMORY_UNITS: 4, GLOBAL_MEMORY_UNITS: 4, MMA_UNITS: 4, + VALU_UNITS: 8, + SHUFFLE_UNITS: 8, } config = get_default_run_config() if run_bench: @@ -377,9 +385,13 @@ def repeat( READ_GLOBAL_DELAY: 2, WRITE_GLOBAL_DELAY: 2, MMA_DELAY: 1, + VALU_DELAY: 1, + SHUFFLE_DELAY: 1, SHARED_MEMORY_UNITS: 4, GLOBAL_MEMORY_UNITS: 4, 
         MMA_UNITS: 4,
+        VALU_UNITS: 8,
+        SHUFFLE_UNITS: 8,
     }
     config = get_default_run_config()
     if run_bench:
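Since the scheduler now models VALU and shuffle ops, every kernel compiled with schedule=True also has to bind the new delay and unit symbols, otherwise subs_idxc cannot resolve the delay and resource tables. The binding below is illustrative and mirrors the values used by the tests in this patch; tune per target.

# Assumes these symbols are importable from the module this patch extends.
from iree.turbine.kernel.lang.global_symbols import (
    READ_SHARED_DELAY, WRITE_SHARED_DELAY, READ_GLOBAL_DELAY, WRITE_GLOBAL_DELAY,
    MMA_DELAY, VALU_DELAY, SHUFFLE_DELAY,
    SHARED_MEMORY_UNITS, GLOBAL_MEMORY_UNITS, MMA_UNITS, VALU_UNITS, SHUFFLE_UNITS,
)

scheduling_hyperparams = {
    READ_SHARED_DELAY: 1,
    WRITE_SHARED_DELAY: 1,
    READ_GLOBAL_DELAY: 2,
    WRITE_GLOBAL_DELAY: 2,
    MMA_DELAY: 1,
    VALU_DELAY: 1,        # new in this patch
    SHUFFLE_DELAY: 1,     # new in this patch
    SHARED_MEMORY_UNITS: 4,
    GLOBAL_MEMORY_UNITS: 4,
    MMA_UNITS: 4,
    VALU_UNITS: 8,        # new in this patch
    SHUFFLE_UNITS: 8,     # new in this patch
}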