-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This PR adds the code to schedule tkw gemms. We have to construct edges for the modulo scheduler and define resources that are available for the scheduler. These resources are tunable. Signed-off-by: Harsh Menon <[email protected]>
- Loading branch information
Showing
11 changed files
with
590 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .schedule import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
from ...lang.global_symbols import * | ||
from ..utils import subs_idxc | ||
from ...ops.wave_ops import Read, Write, MMA, IterArg, Output, get_custom | ||
import torch.fx as fx | ||
from enum import Enum | ||
import numpy as np | ||
|
||
|
||
# This table contains the number of functional units available for each operation. | ||
def get_available_resources() -> list[int]: | ||
resources = [GLOBAL_MEMORY_UNITS, SHARED_MEMORY_UNITS, MMA_UNITS] | ||
return np.array([int(subs_idxc(x)) for x in resources]) | ||
|
||
|
||
class Operation(Enum): | ||
READ_SHARED = "read_shared" | ||
WRITE_SHARED = "write_shared" | ||
READ_GLOBAL = "read_global" | ||
WRITE_GLOBAL = "write_global" | ||
MMA = "mma" | ||
NOOP = "noop" | ||
|
||
|
||
# This table contains the cycles required to execute each operation. | ||
delay_table = { | ||
Operation.READ_SHARED: READ_SHARED_DELAY, | ||
Operation.WRITE_SHARED: WRITE_SHARED_DELAY, | ||
Operation.READ_GLOBAL: READ_GLOBAL_DELAY, | ||
Operation.WRITE_GLOBAL: WRITE_GLOBAL_DELAY, | ||
Operation.MMA: MMA_DELAY, | ||
Operation.NOOP: 0, | ||
} | ||
|
||
# This table contains the resource usage for each operation. | ||
# Operations can use more than one resource for more than one cycle. | ||
resource_reservation_table = { | ||
Operation.READ_SHARED: np.array([0, 1, 0]), | ||
Operation.WRITE_SHARED: np.array([0, 1, 0]), | ||
Operation.READ_GLOBAL: np.array([1, 0, 0]), | ||
Operation.WRITE_GLOBAL: np.array([1, 0, 0]), | ||
Operation.MMA: np.array([0, 0, 1]), | ||
Operation.NOOP: np.array([0, 0, 0]), | ||
} | ||
|
||
|
||
def annotate_resource_usage( | ||
graph: fx.Graph, | ||
) -> tuple[set[fx.Node], list[fx.Node], fx.Node]: | ||
ignore_nodes = set() | ||
iter_args = [] | ||
output = None | ||
for node in graph.nodes: | ||
custom = get_custom(node) | ||
if isinstance(custom, Read): | ||
custom.rrt = ( | ||
resource_reservation_table[Operation.READ_GLOBAL] | ||
if custom.memory_type.address_space == GLOBAL_ADDRESS_SPACE | ||
else resource_reservation_table[Operation.READ_SHARED] | ||
) | ||
elif isinstance(custom, Write): | ||
custom.rrt = ( | ||
resource_reservation_table[Operation.WRITE_GLOBAL] | ||
if custom.memory_type.address_space == GLOBAL_ADDRESS_SPACE | ||
else resource_reservation_table[Operation.WRITE_SHARED] | ||
) | ||
elif isinstance(custom, MMA): | ||
custom.rrt = resource_reservation_table[Operation.MMA] | ||
elif isinstance(custom, IterArg): | ||
iter_args.append(node) | ||
custom.rrt = resource_reservation_table[Operation.NOOP] | ||
elif isinstance(custom, Output): | ||
output = node | ||
else: | ||
ignore_nodes.add(node) | ||
return ignore_nodes, iter_args, output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
from ..constraints import Constraint | ||
from ..._support.tracing import CapturedTrace | ||
from ...ops.wave_ops import Reduction, IterArg, get_custom | ||
from .modulo_scheduling import ModuloScheduler | ||
from .graph_utils import create_scheduling_edges, Edge | ||
from .resources import get_available_resources, annotate_resource_usage | ||
from ..visualization import visualize_edges, visualize_graph, visualize_schedule | ||
from ..utils import subs_idxc, graph_copy, erase_graph | ||
import torch.fx as fx | ||
|
||
|
||
def visualize_scheduling_graph(edges: list[Edge]): | ||
visualize_edges(edges, "reduction_graph.png") | ||
|
||
|
||
def schedule_reduction( | ||
reduction: Reduction, trace: CapturedTrace, constraints: list[Constraint] | ||
): | ||
""" | ||
Clones the reduction graph and does the following: | ||
1. Annotates resource usage for each node. | ||
2. Creates edges between outputs and return args for scheduling | ||
and assigns weights to all edges. | ||
Does scheduling on the cloned graph and applies the schedule | ||
to the original graph. Finally, erases the cloned graph. | ||
""" | ||
reduction_graph = trace.get_subgraph(reduction.subgraph_name) | ||
graph, node_map = graph_copy(reduction_graph) | ||
ignore_nodes, iter_args, output = annotate_resource_usage(graph) | ||
edges = create_scheduling_edges(graph, ignore_nodes, iter_args, output) | ||
|
||
visualize = False | ||
if visualize: | ||
visualize_scheduling_graph(edges) | ||
visualize_graph(graph, "scheduling_fx_graph.png") | ||
|
||
scheduler = ModuloScheduler(graph, edges, get_available_resources()) | ||
schedule, success = scheduler.schedule_graph() | ||
if not success: | ||
raise ValueError("Scheduling failed.") | ||
if visualize: | ||
visualize_schedule(schedule, scheduler.initiation_interval, "schedule.html") | ||
|
||
# Apply schedule to original graph, specifying the stage | ||
# that each node is scheduled in as well as the cycle in | ||
# each stage when the node should be issued. | ||
inverse_node_map = {v: k for k, v in node_map.items()} | ||
for node, cycle in schedule.items(): | ||
if node not in inverse_node_map: | ||
continue | ||
custom = get_custom(inverse_node_map[node]) | ||
custom.scheduling_parameters = { | ||
"cycle": cycle % scheduler.initiation_interval, | ||
"stage": cycle // scheduler.initiation_interval, | ||
"initiation_interval": scheduler.initiation_interval, | ||
} | ||
# Erase edges between outputs and iter args. | ||
if isinstance(get_custom(node), IterArg): | ||
node.args = () | ||
|
||
erase_graph(graph) | ||
|
||
|
||
def schedule_graph(trace: CapturedTrace, constraints: list[Constraint]): | ||
""" | ||
Given a graph, pipelines the reductions in the graph. | ||
""" | ||
|
||
def is_reduction(node: fx.Node) -> bool: | ||
return isinstance(get_custom(node), Reduction) | ||
|
||
reduction_nodes = trace.walk(is_reduction) | ||
if not reduction_nodes: | ||
return | ||
|
||
for reduction_node in reduction_nodes: | ||
schedule_reduction(get_custom(reduction_node), trace, constraints) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.