diff --git a/shark_turbine/kernel/lang/global_symbols.py b/shark_turbine/kernel/lang/global_symbols.py index 91178ebf..efe112e8 100644 --- a/shark_turbine/kernel/lang/global_symbols.py +++ b/shark_turbine/kernel/lang/global_symbols.py @@ -15,7 +15,17 @@ THREAD_1 = index_symbol("$T1") THREAD_2 = index_symbol("$T2") -# MMA symbols +# MMA symbols. MMA_LHS = index_symbol("$MMA_LHS") MMA_RHS = index_symbol("$MMA_RHS") MMA_ACC = index_symbol("$MMA_ACC") + +# Scheduling symbols. +READ_SHARED_DELAY = index_symbol("$READ_SHARED_DELAY") +WRITE_SHARED_DELAY = index_symbol("$WRITE_SHARED_DELAY") +READ_GLOBAL_DELAY = index_symbol("$READ_GLOBAL_DELAY") +WRITE_GLOBAL_DELAY = index_symbol("$WRITE_GLOBAL_DELAY") +MMA_DELAY = index_symbol("$MMA_DELAY") +SHARED_MEMORY_UNITS = index_symbol("$SHARED_MEMORY_UNITS") +GLOBAL_MEMORY_UNITS = index_symbol("$GLOBAL_MEMORY_UNITS") +MMA_UNITS = index_symbol("$MMA_UNITS") diff --git a/shark_turbine/kernel/ops/wave_ops.py b/shark_turbine/kernel/ops/wave_ops.py index f468e602..cebf4dde 100644 --- a/shark_turbine/kernel/ops/wave_ops.py +++ b/shark_turbine/kernel/ops/wave_ops.py @@ -23,6 +23,7 @@ from .._support.dtype import DataType from .._support.regions import RegionGraph from .base import OpDispatcher +import numpy as np if TYPE_CHECKING: from ..wave.constraints import Constraint @@ -360,17 +361,21 @@ def update_arg(self, idx_or_name: int | str, value: CustomOp | fx.Node): raise IndexError("Index out of range") def copy( - self, new_name: Optional[str] = None, new_graph: Optional[fx.Graph] = None + self, + new_name: Optional[str] = None, + new_graph: Optional[fx.Graph] = None, + arg_transform: Optional[Callable[[Any], Any]] = lambda x: x, ) -> Self: """Returns a duplicate of this node.""" graph = new_graph if new_graph is None: graph = self.graph graph.inserting_after(self.fx_node) - new_node = graph.node_copy(self.fx_node) + new_node = graph.node_copy(self.fx_node, arg_transform=arg_transform) new_node.tkw_op = self new_node.tkw_op_name = self.tkw_op_name - new_node.index = copy.deepcopy(self.fx_node.index) + if hasattr(self.fx_node, "index"): + new_node.index = copy.deepcopy(self.fx_node.index) if new_name: new_node.name = new_name return get_custom(new_node) @@ -447,6 +452,28 @@ def index(self, value: Any): else: raise ValueError("Index must be a dict") + @property + def rrt(self): + if hasattr(self.fx_node, "rrt"): + return self.fx_node.rrt + + @rrt.setter + def rrt(self, value): + if not isinstance(value, np.ndarray): + raise ValueError("RRT must be a numpy array") + self.fx_node.rrt = value + + @property + def scheduling_parameters(self): + if hasattr(self.fx_node, "scheduling_parameters"): + return self.fx_node.scheduling_parameters + + @scheduling_parameters.setter + def scheduling_parameters(self, value: Any): + if not isinstance(value, dict): + raise ValueError("Scheduling parameters must be a dict") + self.fx_node.scheduling_parameters = value + def post_expansion(self, constraints: list["Constraint"]) -> None: """ Hook for post-expansion operations. 
This is called after the arguments diff --git a/shark_turbine/kernel/wave/scheduling/__init__.py b/shark_turbine/kernel/wave/scheduling/__init__.py new file mode 100644 index 00000000..b7ee5706 --- /dev/null +++ b/shark_turbine/kernel/wave/scheduling/__init__.py @@ -0,0 +1 @@ +from .schedule import * diff --git a/shark_turbine/kernel/wave/scheduling/graph_utils.py b/shark_turbine/kernel/wave/scheduling/graph_utils.py index 41bd3a66..ffb12e50 100644 --- a/shark_turbine/kernel/wave/scheduling/graph_utils.py +++ b/shark_turbine/kernel/wave/scheduling/graph_utils.py @@ -2,9 +2,11 @@ from random import shuffle, seed, Random from collections import defaultdict from ..._support.indexing import index_symbol, IndexExpr +from .resources import * from dataclasses import dataclass import sympy import math +from ..utils import DCE T = index_symbol("$INITIATION_INTERVAL") @@ -214,3 +216,69 @@ def topological_sort_nodes( filtered_nodes.add(edge._from) sorted_nodes = sorted(filtered_nodes, key=lambda x: x.f) return sorted_nodes + + +def get_scheduling_weight(node: fx.Node) -> EdgeWeight: + """ + Get the scheduling weight of a node. + """ + custom_node = get_custom(node) + match custom_node: + case Read(): + if custom_node.memory_type.address_space == GLOBAL_ADDRESS_SPACE: + weight = EdgeWeight(0, delay_table[Operation.READ_GLOBAL]) + else: + weight = EdgeWeight(0, delay_table[Operation.READ_SHARED]) + case Write(): + if custom_node.memory_type.address_space == GLOBAL_ADDRESS_SPACE: + weight = EdgeWeight(0, delay_table[Operation.WRITE_GLOBAL]) + else: + weight = EdgeWeight(0, delay_table[Operation.WRITE_SHARED]) + case MMA(): + weight = EdgeWeight(0, delay_table[Operation.MMA]) + case IterArg(): + weight = EdgeWeight(1, 0) + case _: + raise ValueError(f"Unsupported node type: {custom_node}") + weight.delay = subs_idxc(weight.delay) + weight.iteration_difference = subs_idxc(weight.iteration_difference) + return weight + + +def erase_placeholder_nodes(graph: fx.Graph, ignore_nodes: set[fx.Node]) -> None: + """ + This function erases nodes in the ignore list from the graph. We replace uses + of the node with None. + """ + for node in ignore_nodes: + for user in list(node.users): + idx = user.args.index(node) + user.update_arg(idx, None) + graph.erase_node(node) + + +def create_scheduling_edges( + graph: fx.Graph, + ignore_nodes: set[fx.Node], + iter_args: list[fx.Node], + output: fx.Node, +) -> list[Edge]: + """ + Create scheduling edges from the graph including back edges + from the outputs to the iter args. Also remove output + and placeholder nodes. + """ + # Create edges from outputs to iter args. 
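+ # Each iter arg becomes a user of its return value here, so the loop back edge is created when edges are built below; edges that start at an IterArg carry an iteration difference of 1 (see get_scheduling_weight).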
+ for return_val, iter_arg in zip(get_custom(output).return_vals[0], iter_args): + iter_arg.args = (return_val,) + graph.erase_node(output) + edges = [] + for node in graph.nodes: + if node in ignore_nodes: + continue + weight = get_scheduling_weight(node) + for user in node.users: + edge = Edge(node, user, weight) + edges.append(edge) + erase_placeholder_nodes(graph, ignore_nodes) + return edges diff --git a/shark_turbine/kernel/wave/scheduling/modulo_scheduling.py b/shark_turbine/kernel/wave/scheduling/modulo_scheduling.py index 94323c2c..1b9b19d2 100644 --- a/shark_turbine/kernel/wave/scheduling/modulo_scheduling.py +++ b/shark_turbine/kernel/wave/scheduling/modulo_scheduling.py @@ -79,7 +79,7 @@ def all_scc_scheduled(self, sccs: dict[fx.Node, list[fx.Node]]) -> bool: return False return True - def schedule(self) -> dict[fx.Node, int]: + def schedule_graph(self) -> tuple[dict[fx.Node, int], bool]: """ Schedule the graph using the Modulo Scheduler. Returns a schedule which maps each node to a cycle. @@ -105,6 +105,7 @@ def schedule(self) -> dict[fx.Node, int]: # Generate the schedule. # TODO: Come up with a better heuristic on an upper bound for the initiation interval. T_max_range = 3 * T0 + success = False for T in range(T0, T0 + T_max_range): logger.debug(f"Trying initiation interval: {T}.") self.RT = np.zeros((T, len(self.resources))) @@ -136,6 +137,7 @@ def schedule(self) -> dict[fx.Node, int]: logger.debug(f"Failed to schedule SCC: {scc}.") break if self.all_scc_scheduled(sccs): + success = True logger.debug( f"Successfully scheduled all SCCs with initiation interval: {T}." ) @@ -144,7 +146,7 @@ def schedule(self) -> dict[fx.Node, int]: raise Exception("Failed to schedule the graph.") self._initiation_interval = T - return self.schedule + return self.schedule, success def scc_scheduled( self, diff --git a/shark_turbine/kernel/wave/scheduling/resources.py b/shark_turbine/kernel/wave/scheduling/resources.py new file mode 100644 index 00000000..c7e18609 --- /dev/null +++ b/shark_turbine/kernel/wave/scheduling/resources.py @@ -0,0 +1,75 @@ +from ...lang.global_symbols import * +from ..utils import subs_idxc +from ...ops.wave_ops import Read, Write, MMA, IterArg, Output, get_custom +import torch.fx as fx +from enum import Enum +import numpy as np + + +# This table contains the number of functional units available for each operation. +def get_available_resources() -> list[int]: + resources = [GLOBAL_MEMORY_UNITS, SHARED_MEMORY_UNITS, MMA_UNITS] + return np.array([int(subs_idxc(x)) for x in resources]) + + +class Operation(Enum): + READ_SHARED = "read_shared" + WRITE_SHARED = "write_shared" + READ_GLOBAL = "read_global" + WRITE_GLOBAL = "write_global" + MMA = "mma" + NOOP = "noop" + + +# This table contains the cycles required to execute each operation. +delay_table = { + Operation.READ_SHARED: READ_SHARED_DELAY, + Operation.WRITE_SHARED: WRITE_SHARED_DELAY, + Operation.READ_GLOBAL: READ_GLOBAL_DELAY, + Operation.WRITE_GLOBAL: WRITE_GLOBAL_DELAY, + Operation.MMA: MMA_DELAY, + Operation.NOOP: 0, +} + +# This table contains the resource usage for each operation. +# Operations can use more than one resource for more than one cycle. 
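+# Each vector is ordered as [GLOBAL_MEMORY_UNITS, SHARED_MEMORY_UNITS, MMA_UNITS], matching get_available_resources() above.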
+resource_reservation_table = { + Operation.READ_SHARED: np.array([0, 1, 0]), + Operation.WRITE_SHARED: np.array([0, 1, 0]), + Operation.READ_GLOBAL: np.array([1, 0, 0]), + Operation.WRITE_GLOBAL: np.array([1, 0, 0]), + Operation.MMA: np.array([0, 0, 1]), + Operation.NOOP: np.array([0, 0, 0]), +} + + +def annotate_resource_usage( + graph: fx.Graph, +) -> tuple[set[fx.Node], list[fx.Node], fx.Node]: + ignore_nodes = set() + iter_args = [] + output = None + for node in graph.nodes: + custom = get_custom(node) + if isinstance(custom, Read): + custom.rrt = ( + resource_reservation_table[Operation.READ_GLOBAL] + if custom.memory_type.address_space == GLOBAL_ADDRESS_SPACE + else resource_reservation_table[Operation.READ_SHARED] + ) + elif isinstance(custom, Write): + custom.rrt = ( + resource_reservation_table[Operation.WRITE_GLOBAL] + if custom.memory_type.address_space == GLOBAL_ADDRESS_SPACE + else resource_reservation_table[Operation.WRITE_SHARED] + ) + elif isinstance(custom, MMA): + custom.rrt = resource_reservation_table[Operation.MMA] + elif isinstance(custom, IterArg): + iter_args.append(node) + custom.rrt = resource_reservation_table[Operation.NOOP] + elif isinstance(custom, Output): + output = node + else: + ignore_nodes.add(node) + return ignore_nodes, iter_args, output diff --git a/shark_turbine/kernel/wave/scheduling/schedule.py b/shark_turbine/kernel/wave/scheduling/schedule.py new file mode 100644 index 00000000..b4eae0e1 --- /dev/null +++ b/shark_turbine/kernel/wave/scheduling/schedule.py @@ -0,0 +1,78 @@ +from ..constraints import Constraint +from ..._support.tracing import CapturedTrace +from ...ops.wave_ops import Reduction, IterArg, get_custom +from .modulo_scheduling import ModuloScheduler +from .graph_utils import create_scheduling_edges, Edge +from .resources import get_available_resources, annotate_resource_usage +from ..visualization import visualize_edges, visualize_graph, visualize_schedule +from ..utils import subs_idxc, graph_copy, erase_graph +import torch.fx as fx + + +def visualize_scheduling_graph(edges: list[Edge]): + visualize_edges(edges, "reduction_graph.png") + + +def schedule_reduction( + reduction: Reduction, trace: CapturedTrace, constraints: list[Constraint] ): + """ + Clones the reduction graph and does the following: + 1. Annotates resource usage for each node. + 2. Creates edges between the outputs and iter args for scheduling + and assigns weights to all edges. + Schedules the cloned graph and applies the schedule + to the original graph. Finally, erases the cloned graph. + + """ + reduction_graph = trace.get_subgraph(reduction.subgraph_name) + graph, node_map = graph_copy(reduction_graph) + ignore_nodes, iter_args, output = annotate_resource_usage(graph) + edges = create_scheduling_edges(graph, ignore_nodes, iter_args, output) + + visualize = False + if visualize: + visualize_scheduling_graph(edges) + visualize_graph(graph, "scheduling_fx_graph.png") + + scheduler = ModuloScheduler(graph, edges, get_available_resources()) + schedule, success = scheduler.schedule_graph() + if not success: + raise ValueError("Scheduling failed.") + if visualize: + visualize_schedule(schedule, scheduler.initiation_interval, "schedule.html") + + # Apply schedule to original graph, specifying the stage + # that each node is scheduled in as well as the cycle in + # each stage when the node should be issued.
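+ # For example, with an initiation interval of 11, a node scheduled at cycle 14 is issued at cycle 14 % 11 = 3 of stage 14 // 11 = 1.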
+ inverse_node_map = {v: k for k, v in node_map.items()} + for node, cycle in schedule.items(): + if node not in inverse_node_map: + continue + custom = get_custom(inverse_node_map[node]) + custom.scheduling_parameters = { + "cycle": cycle % scheduler.initiation_interval, + "stage": cycle // scheduler.initiation_interval, + "initiation_interval": scheduler.initiation_interval, + } + # Erase edges between outputs and iter args. + if isinstance(get_custom(node), IterArg): + node.args = () + + erase_graph(graph) + + +def schedule_graph(trace: CapturedTrace, constraints: list[Constraint]): + """ + Given a graph, pipelines the reductions in the graph. + """ + + def is_reduction(node: fx.Node) -> bool: + return isinstance(get_custom(node), Reduction) + + reduction_nodes = trace.walk(is_reduction) + if not reduction_nodes: + return + + for reduction_node in reduction_nodes: + schedule_reduction(get_custom(reduction_node), trace, constraints) diff --git a/shark_turbine/kernel/wave/utils.py b/shark_turbine/kernel/wave/utils.py index a1883691..f526b3f1 100644 --- a/shark_turbine/kernel/wave/utils.py +++ b/shark_turbine/kernel/wave/utils.py @@ -325,3 +325,30 @@ def subs_idxc(input: Any) -> Any: """ idxc = IndexingContext.current() return safe_subs(input, idxc.subs) + + +def graph_copy(graph: fx.Graph) -> tuple[fx.Graph, dict[fx.Node, fx.Node]]: + """ + Copy the graph and return the new graph with the nodes in node_map. + Also return the mapping of old nodes to new nodes. + """ + new_graph = fx.Graph() + node_map = {} + for node in graph.nodes: + custom = get_custom(node) + new_node = custom.copy( + new_graph=new_graph, + arg_transform=lambda x: node_map[x] if x in node_map else x, + ) + node_map[node] = new_node.fx_node + return new_graph, node_map + + +def erase_graph(graph: fx.Graph): + """ + Erase all nodes in the graph. 
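+ Users of a node are erased before the node itself so that no node is removed while it still has uses.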
+ """ + for node in reversed(graph.nodes): + for user in node.users: + graph.erase_node(user) + graph.erase_node(node) diff --git a/shark_turbine/kernel/wave/visualization.py b/shark_turbine/kernel/wave/visualization.py index 7469c59c..ef4fe83e 100644 --- a/shark_turbine/kernel/wave/visualization.py +++ b/shark_turbine/kernel/wave/visualization.py @@ -3,7 +3,14 @@ import pygraphviz as pgv except: graphviz_disabled = True +pandas_disabled = False +try: + import pandas as pd +except: + matplotlib_disabled = True from torch import fx +from .scheduling.graph_utils import Edge +import math def number_nodes(graph: fx.Graph) -> dict[int, int]: @@ -22,3 +29,69 @@ def visualize_graph(graph: fx.Graph, file_name: str): G.add_edge(node_numbering[id(node)], node_numbering[id(user)]) G.layout(prog="dot") G.draw(file_name) + + +def visualize_edges(edges: list[Edge], file_name: str): + if graphviz_disabled: + raise ImportError("pygraphviz not installed, cannot visualize graph") + G = pgv.AGraph(directed=True) + node_map = {} + count = 0 + for edge in edges: + if edge._from not in node_map: + node_map[edge._from] = count + count += 1 + G.add_node(node_map[edge._from], label=f"{edge._from}") + if edge._to not in node_map: + node_map[edge._to] = count + count += 1 + G.add_node(node_map[edge._to], label=f"{edge._to}") + G.add_edge( + node_map[edge._from], + node_map[edge._to], + label=f"({edge.weight.iteration_difference}, {edge.weight.delay})", + ) + G.layout(prog="dot") + G.draw(file_name) + + +def visualize_schedule( + schedule: dict[fx.Graph, int], initiation_interval: int, file_name: str +): + if pandas_disabled: + raise ImportError("pandas not installed, cannot visualize schedule") + + max_time = max(schedule.values()) + max_stage = math.ceil(max_time / initiation_interval) + rows = max_time + 1 + max_stage * initiation_interval + cols = max_stage + + table = [["" for _ in range(cols)] for _ in range(rows)] + for stage in range(max_stage): + for key, value in schedule.items(): + table[value + stage * initiation_interval][stage] += f"{key}
" + + df = pd.DataFrame(table, columns=[f"Stage {i}" for i in range(cols)]) + s = df.style.set_properties(**{"text-align": "center"}) + s = s.set_table_styles( + [ + {"selector": "", "props": [("border", "1px solid grey")]}, + {"selector": "tbody td", "props": [("border", "1px solid grey")]}, + {"selector": "th", "props": [("border", "1px solid grey")]}, + {"selector": "th", "props": [("min-width", "300px")]}, + ] + ) + output = s.apply( + lambda x: [ + ( + "background: lightgreen" + if int(x.name) >= (max_stage - 1) * initiation_interval + and int(x.name) < max_stage * initiation_interval + else "" + ) + for _ in x + ], + axis=1, + ).to_html() + with open(f"{file_name}", "w") as f: + f.write(output) diff --git a/shark_turbine/kernel/wave/wave.py b/shark_turbine/kernel/wave/wave.py index 5b121984..7aecace0 100644 --- a/shark_turbine/kernel/wave/wave.py +++ b/shark_turbine/kernel/wave/wave.py @@ -28,6 +28,7 @@ from .index_sequence_analysis import partition_strided_operators from .shared_memory_indexing import apply_shared_memory_indexing_corrections from .register_analysis import determine_register_shape +from .scheduling.schedule import schedule_graph from .._support.indexing import IndexingContext, IndexExpr import shark_turbine.kernel.lang as tkl from .._support.tracing import ( @@ -210,12 +211,16 @@ def _trace_and_get_kernel_signature( # Partition strided operators. partition_strided_operators(graph, self.constraints) - # Add shared memory barriers. - add_shared_memory_barriers(graph) - # Decompose reduce Ops. decompose_reduce_ops(graph, self.constraints, idxc.subs) + # Schedule the reduction ops. + if kwargs.get("schedule", False): + schedule_graph(graph, self.constraints) + + # Add shared memory barriers. + add_shared_memory_barriers(graph) + # Determine grid shape. 
self.grid_type.dims = [1, 1, 1] for constraint in self.workgroup_constraints: diff --git a/tests/kernel/wave/scheduling_test.py b/tests/kernel/wave/scheduling_test.py index 08f71697..94641099 100644 --- a/tests/kernel/wave/scheduling_test.py +++ b/tests/kernel/wave/scheduling_test.py @@ -14,6 +14,18 @@ all_pairs_longest_paths, evaluate_all_pairs_longest_paths, ) +import shark_turbine.kernel as tk +import shark_turbine.kernel.lang as tkl +import shark_turbine.kernel.wave as tkw +from shark_turbine.kernel.lang.global_symbols import * +from shark_turbine.kernel._support.tracing import CapturedTrace +from shark_turbine.kernel._support.indexing import IndexingContext +from shark_turbine.kernel.wave.promotion import promote_placeholders +from shark_turbine.kernel.wave.hoisting import hoist_allocs +from shark_turbine.kernel.wave.expansion import expand_graph +from shark_turbine.kernel.wave.minimize_global_loads import minimize_global_loads +from shark_turbine.kernel.wave.scheduling.schedule import schedule_graph +from shark_turbine.kernel.ops.wave_ops import get_custom class SchedulingTest(unittest.TestCase): @@ -176,7 +188,8 @@ def testModuloScheduling(self): visualize_graph(graph, "scheduling_test_graph.png") resources = np.array([1, 1]) scheduler = ModuloScheduler(graph, weighted_edges, resources) - schedule = scheduler.schedule() + schedule, success = scheduler.schedule_graph() + assert success == True assert schedule[nodes["a"]] == 0 assert schedule[nodes["b"]] == 4 assert schedule[nodes["c"]] == 5 @@ -187,6 +200,207 @@ def testModuloScheduling(self): == np.array([[1, 1], [1, 0], [1, 0], [0, 1]]) ) + def testGemmScheduling(self): + + # Input sizes + M = tkl.sym.M + N = tkl.sym.N + K = tkl.sym.K + # Workgroup tile sizes + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_K = tkl.sym.BLOCK_K + # Address space (for GPU, shared(1) or global(0)) + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + # Other hyperparameters + LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD + STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD + ARGK = tkl.sym.ARGK + + # Expose user-constraints + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.TilingConstraint(K, BLOCK_K, ARGK)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2, 0)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2, 1)] + + constraints += [ + tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(2, 2, 1)) + ] + + @tkw.wave_trace_only(constraints) + def gemm( + a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16], + b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], + ): + c_reg = tkl.Register[M, N, tkl.f32](0.0) + + @tkw.reduction(K, init_args=[c_reg]) + def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: + a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD) + b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD) + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD) + + hyperparams = { + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + LOAD_ELEMS_PER_THREAD: 4, + STORE_ELEMS_PER_THREAD: 4, + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_K: 32, + M: 2048, + N: 10240, + K: 1280, + READ_SHARED_DELAY: 1, + WRITE_SHARED_DELAY: 1, + READ_GLOBAL_DELAY: 5, + WRITE_GLOBAL_DELAY: 5, + MMA_DELAY: 2, + SHARED_MEMORY_UNITS: 2, + GLOBAL_MEMORY_UNITS: 2, + MMA_UNITS: 2, + } + with 
tk.gen.TestLaunchContext(hyperparams, canonicalize=True, schedule=True): + trace: CapturedTrace = gemm() + IndexingContext.current().finalize() + promote_placeholders(trace, constraints) + hoist_allocs(trace) + expand_graph(trace, constraints) + minimize_global_loads(trace, constraints) + schedule_graph(trace, constraints) + subgraph = trace.get_subgraph("region_0") + initiation_interval = 11 + correct_schedule = { + "acc_0_0_0": { + "cycle": 3, + "stage": 1, + "initiation_interval": initiation_interval, + }, + "acc_1_1_0": { + "cycle": 4, + "stage": 1, + "initiation_interval": initiation_interval, + }, + "acc_1_0_0": { + "cycle": 3, + "stage": 1, + "initiation_interval": initiation_interval, + }, + "acc_0_1_0": { + "cycle": 4, + "stage": 1, + "initiation_interval": initiation_interval, + }, + "read_4": { + "cycle": 0, + "stage": 0, + "initiation_interval": initiation_interval, + }, + "write_2": { + "cycle": 5, + "stage": 0, + "initiation_interval": initiation_interval, + }, + "read_shared_0_0_0": { + "cycle": 8, + "stage": 0, + "initiation_interval": initiation_interval, + }, + "read_shared_0_0_1": { + "cycle": 7, + "stage": 0, + "initiation_interval": initiation_interval, + }, + "read_shared_1_0_0": { + "cycle": 9, + "stage": 0, + "initiation_interval": initiation_interval, + }, + "read_shared_1_0_1": { + "cycle": 9, + "stage": 0, + "initiation_interval": initiation_interval, + }, + "read_5": { + "cycle": 0, + "stage": 0, + "initiation_interval": initiation_interval, + }, + "write_3": { + "cycle": 5, + "stage": 0, + "initiation_interval": initiation_interval, + }, + "read_shared_0_0_0": { + "cycle": 8, + "stage": 0, + "initiation_interval": initiation_interval, + }, + "read_shared_0_0_1": { + "cycle": 7, + "stage": 0, + "initiation_interval": initiation_interval, + }, + "read_shared_0_1_0": { + "cycle": 6, + "stage": 0, + "initiation_interval": initiation_interval, + }, + "read_shared_0_1_1": { + "cycle": 6, + "stage": 0, + "initiation_interval": initiation_interval, + }, + "mma_0_0_0": { + "cycle": 10, + "stage": 0, + "initiation_interval": initiation_interval, + }, + "mma_0_0_1": { + "cycle": 1, + "stage": 1, + "initiation_interval": initiation_interval, + }, + "mma_1_1_0": { + "cycle": 0, + "stage": 1, + "initiation_interval": initiation_interval, + }, + "mma_1_1_1": { + "cycle": 2, + "stage": 1, + "initiation_interval": initiation_interval, + }, + "mma_1_0_0": { + "cycle": 10, + "stage": 0, + "initiation_interval": initiation_interval, + }, + "mma_1_0_1": { + "cycle": 1, + "stage": 1, + "initiation_interval": initiation_interval, + }, + "mma_0_1_0": { + "cycle": 0, + "stage": 1, + "initiation_interval": initiation_interval, + }, + "mma_0_1_1": { + "cycle": 2, + "stage": 1, + "initiation_interval": initiation_interval, + }, + } + for node in subgraph.nodes: + custom = get_custom(node) + if custom.name in correct_schedule: + assert custom.scheduling_parameters == correct_schedule[custom.name] + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG)