From 3a2a6ba9ea7c14d26fc39cf74fd69a5485d756b2 Mon Sep 17 00:00:00 2001 From: Harsh Menon Date: Mon, 9 Sep 2024 11:18:27 -0700 Subject: [PATCH] Add code to schedule GEMMs This PR adds the code to schedule tkw gemms. We have to construct edges for the modulo scheduler and define resources that are available for the scheduler. These resources are tunable. Signed-off-by: Harsh Menon --- shark_turbine/kernel/lang/global_symbols.py | 12 +- shark_turbine/kernel/ops/wave_ops.py | 33 ++- .../kernel/wave/scheduling/__init__.py | 1 + .../kernel/wave/scheduling/graph_utils.py | 68 ++++++ .../wave/scheduling/modulo_scheduling.py | 6 +- .../kernel/wave/scheduling/resources.py | 75 ++++++ .../kernel/wave/scheduling/schedule.py | 78 +++++++ shark_turbine/kernel/wave/utils.py | 27 +++ shark_turbine/kernel/wave/visualization.py | 73 ++++++ shark_turbine/kernel/wave/wave.py | 11 +- tests/kernel/wave/scheduling_test.py | 216 +++++++++++++++++- 11 files changed, 590 insertions(+), 10 deletions(-) create mode 100644 shark_turbine/kernel/wave/scheduling/__init__.py create mode 100644 shark_turbine/kernel/wave/scheduling/resources.py create mode 100644 shark_turbine/kernel/wave/scheduling/schedule.py diff --git a/shark_turbine/kernel/lang/global_symbols.py b/shark_turbine/kernel/lang/global_symbols.py index 91178ebf..efe112e8 100644 --- a/shark_turbine/kernel/lang/global_symbols.py +++ b/shark_turbine/kernel/lang/global_symbols.py @@ -15,7 +15,17 @@ THREAD_1 = index_symbol("$T1") THREAD_2 = index_symbol("$T2") -# MMA symbols +# MMA symbols. MMA_LHS = index_symbol("$MMA_LHS") MMA_RHS = index_symbol("$MMA_RHS") MMA_ACC = index_symbol("$MMA_ACC") + +# Scheduling symbols. +READ_SHARED_DELAY = index_symbol("$READ_SHARED_DELAY") +WRITE_SHARED_DELAY = index_symbol("$WRITE_SHARED_DELAY") +READ_GLOBAL_DELAY = index_symbol("$READ_GLOBAL_DELAY") +WRITE_GLOBAL_DELAY = index_symbol("$WRITE_GLOBAL_DELAY") +MMA_DELAY = index_symbol("$MMA_DELAY") +SHARED_MEMORY_UNITS = index_symbol("$SHARED_MEMORY_UNITS") +GLOBAL_MEMORY_UNITS = index_symbol("$GLOBAL_MEMORY_UNITS") +MMA_UNITS = index_symbol("$MMA_UNITS") diff --git a/shark_turbine/kernel/ops/wave_ops.py b/shark_turbine/kernel/ops/wave_ops.py index f468e602..cebf4dde 100644 --- a/shark_turbine/kernel/ops/wave_ops.py +++ b/shark_turbine/kernel/ops/wave_ops.py @@ -23,6 +23,7 @@ from .._support.dtype import DataType from .._support.regions import RegionGraph from .base import OpDispatcher +import numpy as np if TYPE_CHECKING: from ..wave.constraints import Constraint @@ -360,17 +361,21 @@ def update_arg(self, idx_or_name: int | str, value: CustomOp | fx.Node): raise IndexError("Index out of range") def copy( - self, new_name: Optional[str] = None, new_graph: Optional[fx.Graph] = None + self, + new_name: Optional[str] = None, + new_graph: Optional[fx.Graph] = None, + arg_transform: Optional[Callable[[Any], Any]] = lambda x: x, ) -> Self: """Returns a duplicate of this node.""" graph = new_graph if new_graph is None: graph = self.graph graph.inserting_after(self.fx_node) - new_node = graph.node_copy(self.fx_node) + new_node = graph.node_copy(self.fx_node, arg_transform=arg_transform) new_node.tkw_op = self new_node.tkw_op_name = self.tkw_op_name - new_node.index = copy.deepcopy(self.fx_node.index) + if hasattr(self.fx_node, "index"): + new_node.index = copy.deepcopy(self.fx_node.index) if new_name: new_node.name = new_name return get_custom(new_node) @@ -447,6 +452,28 @@ def index(self, value: Any): else: raise ValueError("Index must be a dict") + @property + def rrt(self): + if 
hasattr(self.fx_node, "rrt"): + return self.fx_node.rrt + + @rrt.setter + def rrt(self, value): + if not isinstance(value, np.ndarray): + raise ValueError("RRT must be a numpy array") + self.fx_node.rrt = value + + @property + def scheduling_parameters(self): + if hasattr(self.fx_node, "scheduling_parameters"): + return self.fx_node.scheduling_parameters + + @scheduling_parameters.setter + def scheduling_parameters(self, value: Any): + if not isinstance(value, dict): + raise ValueError("Scheduling parameters must be a dict") + self.fx_node.scheduling_parameters = value + def post_expansion(self, constraints: list["Constraint"]) -> None: """ Hook for post-expansion operations. This is called after the arguments diff --git a/shark_turbine/kernel/wave/scheduling/__init__.py b/shark_turbine/kernel/wave/scheduling/__init__.py new file mode 100644 index 00000000..b7ee5706 --- /dev/null +++ b/shark_turbine/kernel/wave/scheduling/__init__.py @@ -0,0 +1 @@ +from .schedule import * diff --git a/shark_turbine/kernel/wave/scheduling/graph_utils.py b/shark_turbine/kernel/wave/scheduling/graph_utils.py index 41bd3a66..ffb12e50 100644 --- a/shark_turbine/kernel/wave/scheduling/graph_utils.py +++ b/shark_turbine/kernel/wave/scheduling/graph_utils.py @@ -2,9 +2,11 @@ from random import shuffle, seed, Random from collections import defaultdict from ..._support.indexing import index_symbol, IndexExpr +from .resources import * from dataclasses import dataclass import sympy import math +from ..utils import DCE T = index_symbol("$INITIATION_INTERVAL") @@ -214,3 +216,69 @@ def topological_sort_nodes( filtered_nodes.add(edge._from) sorted_nodes = sorted(filtered_nodes, key=lambda x: x.f) return sorted_nodes + + +def get_scheduling_weight(node: fx.Node) -> EdgeWeight: + """ + Get the scheduling weight of a node. + """ + custom_node = get_custom(node) + match custom_node: + case Read(): + if custom_node.memory_type.address_space == GLOBAL_ADDRESS_SPACE: + weight = EdgeWeight(0, delay_table[Operation.READ_GLOBAL]) + else: + weight = EdgeWeight(0, delay_table[Operation.READ_SHARED]) + case Write(): + if custom_node.memory_type.address_space == GLOBAL_ADDRESS_SPACE: + weight = EdgeWeight(0, delay_table[Operation.WRITE_GLOBAL]) + else: + weight = EdgeWeight(0, delay_table[Operation.WRITE_SHARED]) + case MMA(): + weight = EdgeWeight(0, delay_table[Operation.MMA]) + case IterArg(): + weight = EdgeWeight(1, 0) + case _: + raise ValueError(f"Unsupported node type: {custom_node}") + weight.delay = subs_idxc(weight.delay) + weight.iteration_difference = subs_idxc(weight.iteration_difference) + return weight + + +def erase_placeholder_nodes(graph: fx.Graph, ignore_nodes: set[fx.Node]) -> None: + """ + This function erases nodes in the ignore list from the graph. We replace uses + of the node with None. + """ + for node in ignore_nodes: + for user in list(node.users): + idx = user.args.index(node) + user.update_arg(idx, None) + graph.erase_node(node) + + +def create_scheduling_edges( + graph: fx.Graph, + ignore_nodes: set[fx.Node], + iter_args: list[fx.Node], + output: fx.Node, +) -> list[Edge]: + """ + Create scheduling edges from the graph including back edges + from the outputs to the iter args. Also remove output + and placeholder nodes. + """ + # Create edges from outputs to iter args. 
+ for return_val, iter_arg in zip(get_custom(output).return_vals[0], iter_args): + iter_arg.args = (return_val,) + graph.erase_node(output) + edges = [] + for node in graph.nodes: + if node in ignore_nodes: + continue + weight = get_scheduling_weight(node) + for user in node.users: + edge = Edge(node, user, weight) + edges.append(edge) + erase_placeholder_nodes(graph, ignore_nodes) + return edges diff --git a/shark_turbine/kernel/wave/scheduling/modulo_scheduling.py b/shark_turbine/kernel/wave/scheduling/modulo_scheduling.py index 94323c2c..1b9b19d2 100644 --- a/shark_turbine/kernel/wave/scheduling/modulo_scheduling.py +++ b/shark_turbine/kernel/wave/scheduling/modulo_scheduling.py @@ -79,7 +79,7 @@ def all_scc_scheduled(self, sccs: dict[fx.Node, list[fx.Node]]) -> bool: return False return True - def schedule(self) -> dict[fx.Node, int]: + def schedule_graph(self) -> tuple[dict[fx.Node, int], bool]: """ Schedule the graph using the Modulo Scheduler. Returns a schedule which maps each node to a cycle. @@ -105,6 +105,7 @@ def schedule(self) -> dict[fx.Node, int]: # Generate the schedule. # TODO: Come up with a better heuristic on an upper bound for the initiation interval. T_max_range = 3 * T0 + success = False for T in range(T0, T0 + T_max_range): logger.debug(f"Trying initiation interval: {T}.") self.RT = np.zeros((T, len(self.resources))) @@ -136,6 +137,7 @@ def schedule(self) -> dict[fx.Node, int]: logger.debug(f"Failed to schedule SCC: {scc}.") break if self.all_scc_scheduled(sccs): + success = True logger.debug( f"Successfully scheduled all SCCs with initiation interval: {T}." ) @@ -144,7 +146,7 @@ def schedule(self) -> dict[fx.Node, int]: raise Exception("Failed to schedule the graph.") self._initiation_interval = T - return self.schedule + return self.schedule, success def scc_scheduled( self, diff --git a/shark_turbine/kernel/wave/scheduling/resources.py b/shark_turbine/kernel/wave/scheduling/resources.py new file mode 100644 index 00000000..c7e18609 --- /dev/null +++ b/shark_turbine/kernel/wave/scheduling/resources.py @@ -0,0 +1,75 @@ +from ...lang.global_symbols import * +from ..utils import subs_idxc +from ...ops.wave_ops import Read, Write, MMA, IterArg, Output, get_custom +import torch.fx as fx +from enum import Enum +import numpy as np + + +# This table contains the number of functional units available for each operation. +def get_available_resources() -> list[int]: + resources = [GLOBAL_MEMORY_UNITS, SHARED_MEMORY_UNITS, MMA_UNITS] + return np.array([int(subs_idxc(x)) for x in resources]) + + +class Operation(Enum): + READ_SHARED = "read_shared" + WRITE_SHARED = "write_shared" + READ_GLOBAL = "read_global" + WRITE_GLOBAL = "write_global" + MMA = "mma" + NOOP = "noop" + + +# This table contains the cycles required to execute each operation. +delay_table = { + Operation.READ_SHARED: READ_SHARED_DELAY, + Operation.WRITE_SHARED: WRITE_SHARED_DELAY, + Operation.READ_GLOBAL: READ_GLOBAL_DELAY, + Operation.WRITE_GLOBAL: WRITE_GLOBAL_DELAY, + Operation.MMA: MMA_DELAY, + Operation.NOOP: 0, +} + +# This table contains the resource usage for each operation. +# Operations can use more than one resource for more than one cycle. 
+resource_reservation_table = { + Operation.READ_SHARED: np.array([0, 1, 0]), + Operation.WRITE_SHARED: np.array([0, 1, 0]), + Operation.READ_GLOBAL: np.array([1, 0, 0]), + Operation.WRITE_GLOBAL: np.array([1, 0, 0]), + Operation.MMA: np.array([0, 0, 1]), + Operation.NOOP: np.array([0, 0, 0]), +} + + +def annotate_resource_usage( + graph: fx.Graph, +) -> tuple[set[fx.Node], list[fx.Node], fx.Node]: + ignore_nodes = set() + iter_args = [] + output = None + for node in graph.nodes: + custom = get_custom(node) + if isinstance(custom, Read): + custom.rrt = ( + resource_reservation_table[Operation.READ_GLOBAL] + if custom.memory_type.address_space == GLOBAL_ADDRESS_SPACE + else resource_reservation_table[Operation.READ_SHARED] + ) + elif isinstance(custom, Write): + custom.rrt = ( + resource_reservation_table[Operation.WRITE_GLOBAL] + if custom.memory_type.address_space == GLOBAL_ADDRESS_SPACE + else resource_reservation_table[Operation.WRITE_SHARED] + ) + elif isinstance(custom, MMA): + custom.rrt = resource_reservation_table[Operation.MMA] + elif isinstance(custom, IterArg): + iter_args.append(node) + custom.rrt = resource_reservation_table[Operation.NOOP] + elif isinstance(custom, Output): + output = node + else: + ignore_nodes.add(node) + return ignore_nodes, iter_args, output diff --git a/shark_turbine/kernel/wave/scheduling/schedule.py b/shark_turbine/kernel/wave/scheduling/schedule.py new file mode 100644 index 00000000..b4eae0e1 --- /dev/null +++ b/shark_turbine/kernel/wave/scheduling/schedule.py @@ -0,0 +1,78 @@ +from ..constraints import Constraint +from ..._support.tracing import CapturedTrace +from ...ops.wave_ops import Reduction, IterArg, get_custom +from .modulo_scheduling import ModuloScheduler +from .graph_utils import create_scheduling_edges, Edge +from .resources import get_available_resources, annotate_resource_usage +from ..visualization import visualize_edges, visualize_graph, visualize_schedule +from ..utils import subs_idxc, graph_copy, erase_graph +import torch.fx as fx + + +def visualize_scheduling_graph(edges: list[Edge]): + visualize_edges(edges, "reduction_graph.png") + + +def schedule_reduction( + reduction: Reduction, trace: CapturedTrace, constraints: list[Constraint] +): + """ + Clones the reduction graph and does the following: + 1. Annotates resource usage for each node. + 2. Creates edges between outputs and return args for scheduling + and assigns weights to all edges. + Does scheduling on the cloned graph and applies the schedule + to the original graph. Finally, erases the cloned graph. + + """ + reduction_graph = trace.get_subgraph(reduction.subgraph_name) + graph, node_map = graph_copy(reduction_graph) + ignore_nodes, iter_args, output = annotate_resource_usage(graph) + edges = create_scheduling_edges(graph, ignore_nodes, iter_args, output) + + visualize = False + if visualize: + visualize_scheduling_graph(edges) + visualize_graph(graph, "scheduling_fx_graph.png") + + scheduler = ModuloScheduler(graph, edges, get_available_resources()) + schedule, success = scheduler.schedule_graph() + if not success: + raise ValueError("Scheduling failed.") + if visualize: + visualize_schedule(schedule, scheduler.initiation_interval, "schedule.html") + + # Apply schedule to original graph, specifying the stage + # that each node is scheduled in as well as the cycle in + # each stage when the node should be issued. 
+ inverse_node_map = {v: k for k, v in node_map.items()} + for node, cycle in schedule.items(): + if node not in inverse_node_map: + continue + custom = get_custom(inverse_node_map[node]) + custom.scheduling_parameters = { + "cycle": cycle % scheduler.initiation_interval, + "stage": cycle // scheduler.initiation_interval, + "initiation_interval": scheduler.initiation_interval, + } + # Erase edges between outputs and iter args. + if isinstance(get_custom(node), IterArg): + node.args = () + + erase_graph(graph) + + +def schedule_graph(trace: CapturedTrace, constraints: list[Constraint]): + """ + Given a graph, pipelines the reductions in the graph. + """ + + def is_reduction(node: fx.Node) -> bool: + return isinstance(get_custom(node), Reduction) + + reduction_nodes = trace.walk(is_reduction) + if not reduction_nodes: + return + + for reduction_node in reduction_nodes: + schedule_reduction(get_custom(reduction_node), trace, constraints) diff --git a/shark_turbine/kernel/wave/utils.py b/shark_turbine/kernel/wave/utils.py index a1883691..f526b3f1 100644 --- a/shark_turbine/kernel/wave/utils.py +++ b/shark_turbine/kernel/wave/utils.py @@ -325,3 +325,30 @@ def subs_idxc(input: Any) -> Any: """ idxc = IndexingContext.current() return safe_subs(input, idxc.subs) + + +def graph_copy(graph: fx.Graph) -> tuple[fx.Graph, dict[fx.Node, fx.Node]]: + """ + Copy the graph and return the new graph with the nodes in node_map. + Also return the mapping of old nodes to new nodes. + """ + new_graph = fx.Graph() + node_map = {} + for node in graph.nodes: + custom = get_custom(node) + new_node = custom.copy( + new_graph=new_graph, + arg_transform=lambda x: node_map[x] if x in node_map else x, + ) + node_map[node] = new_node.fx_node + return new_graph, node_map + + +def erase_graph(graph: fx.Graph): + """ + Erase all nodes in the graph. 
+ """ + for node in reversed(graph.nodes): + for user in node.users: + graph.erase_node(user) + graph.erase_node(node) diff --git a/shark_turbine/kernel/wave/visualization.py b/shark_turbine/kernel/wave/visualization.py index 7469c59c..ef4fe83e 100644 --- a/shark_turbine/kernel/wave/visualization.py +++ b/shark_turbine/kernel/wave/visualization.py @@ -3,7 +3,14 @@ import pygraphviz as pgv except: graphviz_disabled = True +pandas_disabled = False +try: + import pandas as pd +except: + matplotlib_disabled = True from torch import fx +from .scheduling.graph_utils import Edge +import math def number_nodes(graph: fx.Graph) -> dict[int, int]: @@ -22,3 +29,69 @@ def visualize_graph(graph: fx.Graph, file_name: str): G.add_edge(node_numbering[id(node)], node_numbering[id(user)]) G.layout(prog="dot") G.draw(file_name) + + +def visualize_edges(edges: list[Edge], file_name: str): + if graphviz_disabled: + raise ImportError("pygraphviz not installed, cannot visualize graph") + G = pgv.AGraph(directed=True) + node_map = {} + count = 0 + for edge in edges: + if edge._from not in node_map: + node_map[edge._from] = count + count += 1 + G.add_node(node_map[edge._from], label=f"{edge._from}") + if edge._to not in node_map: + node_map[edge._to] = count + count += 1 + G.add_node(node_map[edge._to], label=f"{edge._to}") + G.add_edge( + node_map[edge._from], + node_map[edge._to], + label=f"({edge.weight.iteration_difference}, {edge.weight.delay})", + ) + G.layout(prog="dot") + G.draw(file_name) + + +def visualize_schedule( + schedule: dict[fx.Graph, int], initiation_interval: int, file_name: str +): + if pandas_disabled: + raise ImportError("pandas not installed, cannot visualize schedule") + + max_time = max(schedule.values()) + max_stage = math.ceil(max_time / initiation_interval) + rows = max_time + 1 + max_stage * initiation_interval + cols = max_stage + + table = [["" for _ in range(cols)] for _ in range(rows)] + for stage in range(max_stage): + for key, value in schedule.items(): + table[value + stage * initiation_interval][stage] += f"{key}
" + + df = pd.DataFrame(table, columns=[f"Stage {i}" for i in range(cols)]) + s = df.style.set_properties(**{"text-align": "center"}) + s = s.set_table_styles( + [ + {"selector": "", "props": [("border", "1px solid grey")]}, + {"selector": "tbody td", "props": [("border", "1px solid grey")]}, + {"selector": "th", "props": [("border", "1px solid grey")]}, + {"selector": "th", "props": [("min-width", "300px")]}, + ] + ) + output = s.apply( + lambda x: [ + ( + "background: lightgreen" + if int(x.name) >= (max_stage - 1) * initiation_interval + and int(x.name) < max_stage * initiation_interval + else "" + ) + for _ in x + ], + axis=1, + ).to_html() + with open(f"{file_name}", "w") as f: + f.write(output) diff --git a/shark_turbine/kernel/wave/wave.py b/shark_turbine/kernel/wave/wave.py index 5b121984..7aecace0 100644 --- a/shark_turbine/kernel/wave/wave.py +++ b/shark_turbine/kernel/wave/wave.py @@ -28,6 +28,7 @@ from .index_sequence_analysis import partition_strided_operators from .shared_memory_indexing import apply_shared_memory_indexing_corrections from .register_analysis import determine_register_shape +from .scheduling.schedule import schedule_graph from .._support.indexing import IndexingContext, IndexExpr import shark_turbine.kernel.lang as tkl from .._support.tracing import ( @@ -210,12 +211,16 @@ def _trace_and_get_kernel_signature( # Partition strided operators. partition_strided_operators(graph, self.constraints) - # Add shared memory barriers. - add_shared_memory_barriers(graph) - # Decompose reduce Ops. decompose_reduce_ops(graph, self.constraints, idxc.subs) + # Schedule the reduction ops. + if kwargs.get("schedule", False): + schedule_graph(graph, self.constraints) + + # Add shared memory barriers. + add_shared_memory_barriers(graph) + # Determine grid shape. 
self.grid_type.dims = [1, 1, 1] for constraint in self.workgroup_constraints: diff --git a/tests/kernel/wave/scheduling_test.py b/tests/kernel/wave/scheduling_test.py index 08f71697..94641099 100644 --- a/tests/kernel/wave/scheduling_test.py +++ b/tests/kernel/wave/scheduling_test.py @@ -14,6 +14,18 @@ all_pairs_longest_paths, evaluate_all_pairs_longest_paths, ) +import shark_turbine.kernel as tk +import shark_turbine.kernel.lang as tkl +import shark_turbine.kernel.wave as tkw +from shark_turbine.kernel.lang.global_symbols import * +from shark_turbine.kernel._support.tracing import CapturedTrace +from shark_turbine.kernel._support.indexing import IndexingContext +from shark_turbine.kernel.wave.promotion import promote_placeholders +from shark_turbine.kernel.wave.hoisting import hoist_allocs +from shark_turbine.kernel.wave.expansion import expand_graph +from shark_turbine.kernel.wave.minimize_global_loads import minimize_global_loads +from shark_turbine.kernel.wave.scheduling.schedule import schedule_graph +from shark_turbine.kernel.ops.wave_ops import get_custom class SchedulingTest(unittest.TestCase): @@ -176,7 +188,8 @@ def testModuloScheduling(self): visualize_graph(graph, "scheduling_test_graph.png") resources = np.array([1, 1]) scheduler = ModuloScheduler(graph, weighted_edges, resources) - schedule = scheduler.schedule() + schedule, success = scheduler.schedule_graph() + assert success == True assert schedule[nodes["a"]] == 0 assert schedule[nodes["b"]] == 4 assert schedule[nodes["c"]] == 5 @@ -187,6 +200,207 @@ def testModuloScheduling(self): == np.array([[1, 1], [1, 0], [1, 0], [0, 1]]) ) + def testGemmScheduling(self): + + # Input sizes + M = tkl.sym.M + N = tkl.sym.N + K = tkl.sym.K + # Workgroup tile sizes + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_K = tkl.sym.BLOCK_K + # Address space (for GPU, shared(1) or global(0)) + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + # Other hyperparameters + LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD + STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD + ARGK = tkl.sym.ARGK + + # Expose user-constraints + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.TilingConstraint(K, BLOCK_K, ARGK)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2, 0)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2, 1)] + + constraints += [ + tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(2, 2, 1)) + ] + + @tkw.wave_trace_only(constraints) + def gemm( + a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16], + b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], + ): + c_reg = tkl.Register[M, N, tkl.f32](0.0) + + @tkw.reduction(K, init_args=[c_reg]) + def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: + a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD) + b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD) + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD) + + hyperparams = { + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + LOAD_ELEMS_PER_THREAD: 4, + STORE_ELEMS_PER_THREAD: 4, + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_K: 32, + M: 2048, + N: 10240, + K: 1280, + READ_SHARED_DELAY: 1, + WRITE_SHARED_DELAY: 1, + READ_GLOBAL_DELAY: 5, + WRITE_GLOBAL_DELAY: 5, + MMA_DELAY: 2, + SHARED_MEMORY_UNITS: 2, + GLOBAL_MEMORY_UNITS: 2, + MMA_UNITS: 2, + } + with 
tk.gen.TestLaunchContext(hyperparams, canonicalize=True, schedule=True): + trace: CapturedTrace = gemm() + IndexingContext.current().finalize() + promote_placeholders(trace, constraints) + hoist_allocs(trace) + expand_graph(trace, constraints) + minimize_global_loads(trace, constraints) + schedule_graph(trace, constraints) + subgraph = trace.get_subgraph("region_0") + initiation_interval = 11 + correct_schedule = { + "acc_0_0_0": { + "cycle": 3, + "stage": 1, + "initiation_interval": initiation_interval, + }, + "acc_1_1_0": { + "cycle": 4, + "stage": 1, + "initiation_interval": initiation_interval, + }, + "acc_1_0_0": { + "cycle": 3, + "stage": 1, + "initiation_interval": initiation_interval, + }, + "acc_0_1_0": { + "cycle": 4, + "stage": 1, + "initiation_interval": initiation_interval, + }, + "read_4": { + "cycle": 0, + "stage": 0, + "initiation_interval": initiation_interval, + }, + "write_2": { + "cycle": 5, + "stage": 0, + "initiation_interval": initiation_interval, + }, + "read_shared_0_0_0": { + "cycle": 8, + "stage": 0, + "initiation_interval": initiation_interval, + }, + "read_shared_0_0_1": { + "cycle": 7, + "stage": 0, + "initiation_interval": initiation_interval, + }, + "read_shared_1_0_0": { + "cycle": 9, + "stage": 0, + "initiation_interval": initiation_interval, + }, + "read_shared_1_0_1": { + "cycle": 9, + "stage": 0, + "initiation_interval": initiation_interval, + }, + "read_5": { + "cycle": 0, + "stage": 0, + "initiation_interval": initiation_interval, + }, + "write_3": { + "cycle": 5, + "stage": 0, + "initiation_interval": initiation_interval, + }, + "read_shared_0_0_0": { + "cycle": 8, + "stage": 0, + "initiation_interval": initiation_interval, + }, + "read_shared_0_0_1": { + "cycle": 7, + "stage": 0, + "initiation_interval": initiation_interval, + }, + "read_shared_0_1_0": { + "cycle": 6, + "stage": 0, + "initiation_interval": initiation_interval, + }, + "read_shared_0_1_1": { + "cycle": 6, + "stage": 0, + "initiation_interval": initiation_interval, + }, + "mma_0_0_0": { + "cycle": 10, + "stage": 0, + "initiation_interval": initiation_interval, + }, + "mma_0_0_1": { + "cycle": 1, + "stage": 1, + "initiation_interval": initiation_interval, + }, + "mma_1_1_0": { + "cycle": 0, + "stage": 1, + "initiation_interval": initiation_interval, + }, + "mma_1_1_1": { + "cycle": 2, + "stage": 1, + "initiation_interval": initiation_interval, + }, + "mma_1_0_0": { + "cycle": 10, + "stage": 0, + "initiation_interval": initiation_interval, + }, + "mma_1_0_1": { + "cycle": 1, + "stage": 1, + "initiation_interval": initiation_interval, + }, + "mma_0_1_0": { + "cycle": 0, + "stage": 1, + "initiation_interval": initiation_interval, + }, + "mma_0_1_1": { + "cycle": 2, + "stage": 1, + "initiation_interval": initiation_interval, + }, + } + for node in subgraph.nodes: + custom = get_custom(node) + if custom.name in correct_schedule: + assert custom.scheduling_parameters == correct_schedule[custom.name] + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG)
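
Editor's note, not part of the patch: the standalone sketch below illustrates how the scheduling output added in this PR is meant to be consumed. `ModuloScheduler.schedule_graph()` returns a map from nodes to absolute cycles, `schedule_reduction` folds each cycle into per-stage scheduling parameters using the initiation interval, and the resource reservation tables from `resources.py` must never oversubscribe the tunable unit counts in any modulo slot. All concrete numbers and node names here are illustrative (they mirror the expectations in `testGemmScheduling`), not values produced by running the patch.

import numpy as np

# Available units, in the same order as get_available_resources():
# [GLOBAL_MEMORY_UNITS, SHARED_MEMORY_UNITS, MMA_UNITS]. Values are the
# tunable hyperparameters used in the test, assumed here for illustration.
available = np.array([2, 2, 2])

# Single-row reservation vectors, as in resource_reservation_table.
rrt = {
    "read_4": np.array([1, 0, 0]),            # global read
    "write_2": np.array([0, 1, 0]),            # shared write
    "read_shared_0_0_0": np.array([0, 1, 0]),  # shared read
    "mma_0_0_0": np.array([0, 0, 1]),          # mma
}

# A toy schedule (node name -> absolute cycle), standing in for the dict
# returned by ModuloScheduler.schedule_graph().
schedule = {"read_4": 0, "write_2": 5, "read_shared_0_0_0": 8, "mma_0_0_0": 10}
initiation_interval = 11

# Fold absolute cycles into (cycle-within-stage, stage), the same
# transformation schedule_reduction applies before annotating nodes.
scheduling_parameters = {
    name: {
        "cycle": cycle % initiation_interval,
        "stage": cycle // initiation_interval,
        "initiation_interval": initiation_interval,
    }
    for name, cycle in schedule.items()
}

# Invariant the scheduler maintains through its RT table: summed modulo
# reservations never exceed the available units in any slot.
RT = np.zeros((initiation_interval, len(available)), dtype=int)
for name, cycle in schedule.items():
    RT[cycle % initiation_interval] += rrt[name]
assert (RT <= available).all()

print(scheduling_parameters["mma_0_0_0"])
# {'cycle': 10, 'stage': 0, 'initiation_interval': 11}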