Add code to schedule GEMMs (iree-org#134)
harsh-nod authored and IanNod committed Sep 30, 2024
1 parent d70abbb commit 2517210
Showing 11 changed files with 603 additions and 11 deletions.
12 changes: 11 additions & 1 deletion shark_turbine/kernel/lang/global_symbols.py
@@ -15,7 +15,17 @@
THREAD_1 = index_symbol("$T1")
THREAD_2 = index_symbol("$T2")

# MMA symbols
# MMA symbols.
MMA_LHS = index_symbol("$MMA_LHS")
MMA_RHS = index_symbol("$MMA_RHS")
MMA_ACC = index_symbol("$MMA_ACC")

# Scheduling symbols.
READ_SHARED_DELAY = index_symbol("$READ_SHARED_DELAY")
WRITE_SHARED_DELAY = index_symbol("$WRITE_SHARED_DELAY")
READ_GLOBAL_DELAY = index_symbol("$READ_GLOBAL_DELAY")
WRITE_GLOBAL_DELAY = index_symbol("$WRITE_GLOBAL_DELAY")
MMA_DELAY = index_symbol("$MMA_DELAY")
SHARED_MEMORY_UNITS = index_symbol("$SHARED_MEMORY_UNITS")
GLOBAL_MEMORY_UNITS = index_symbol("$GLOBAL_MEMORY_UNITS")
MMA_UNITS = index_symbol("$MMA_UNITS")
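These scheduling symbols are placeholders that the compiler resolves to concrete integers through the indexing context before scheduling runs. A minimal sketch of what a binding could look like; the dictionary name and the numeric values are illustrative assumptions, not part of this commit:

# Hypothetical bindings for the scheduling symbols (values chosen only to make
# the example concrete; real values are target- and kernel-dependent).
example_scheduling_hyperparams = {
    READ_SHARED_DELAY: 1,
    WRITE_SHARED_DELAY: 1,
    READ_GLOBAL_DELAY: 2,
    WRITE_GLOBAL_DELAY: 2,
    MMA_DELAY: 1,
    SHARED_MEMORY_UNITS: 2,
    GLOBAL_MEMORY_UNITS: 2,
    MMA_UNITS: 2,
}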
33 changes: 30 additions & 3 deletions shark_turbine/kernel/ops/wave_ops.py
@@ -23,6 +23,7 @@
from .._support.dtype import DataType
from .._support.regions import RegionGraph
from .base import OpDispatcher
import numpy as np

if TYPE_CHECKING:
    from ..wave.constraints import Constraint
@@ -360,17 +361,21 @@ def update_arg(self, idx_or_name: int | str, value: CustomOp | fx.Node):
        raise IndexError("Index out of range")

    def copy(
        self, new_name: Optional[str] = None, new_graph: Optional[fx.Graph] = None
        self,
        new_name: Optional[str] = None,
        new_graph: Optional[fx.Graph] = None,
        arg_transform: Optional[Callable[[Any], Any]] = lambda x: x,
    ) -> Self:
        """Returns a duplicate of this node."""
        graph = new_graph
        if new_graph is None:
            graph = self.graph
        graph.inserting_after(self.fx_node)
        new_node = graph.node_copy(self.fx_node)
        new_node = graph.node_copy(self.fx_node, arg_transform=arg_transform)
        new_node.tkw_op = self
        new_node.tkw_op_name = self.tkw_op_name
        new_node.index = copy.deepcopy(self.fx_node.index)
        if hasattr(self.fx_node, "index"):
            new_node.index = copy.deepcopy(self.fx_node.index)
        if new_name:
            new_node.name = new_name
        return get_custom(new_node)
@@ -447,6 +452,28 @@ def index(self, value: Any):
        else:
            raise ValueError("Index must be a dict")

    @property
    def rrt(self):
        if hasattr(self.fx_node, "rrt"):
            return self.fx_node.rrt

    @rrt.setter
    def rrt(self, value):
        if not isinstance(value, np.ndarray):
            raise ValueError("RRT must be a numpy array")
        self.fx_node.rrt = value

    @property
    def scheduling_parameters(self):
        if hasattr(self.fx_node, "scheduling_parameters"):
            return self.fx_node.scheduling_parameters

    @scheduling_parameters.setter
    def scheduling_parameters(self, value: Any):
        if not isinstance(value, dict):
            raise ValueError("Scheduling parameters must be a dict")
        self.fx_node.scheduling_parameters = value

    def post_expansion(self, constraints: list["Constraint"]) -> None:
        """
        Hook for post-expansion operations. This is called after the arguments
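The new `rrt` (resource reservation table) and `scheduling_parameters` attributes live directly on the underlying `fx.Node`, so the getters simply return `None` for nodes that have not been annotated. A short usage sketch; `node` is an assumed, pre-existing `fx.Node` and the numbers are illustrative:

import numpy as np

custom = get_custom(node)           # node: some existing fx.Node (assumed)
custom.rrt = np.array([[0, 0, 1]])  # e.g. occupy one MMA unit for one cycle
custom.scheduling_parameters = {
    "absolute_cycle": 5,
    "cycle": 5 % 4,
    "stage": 5 // 4,
    "initiation_interval": 4,
}
assert custom.scheduling_parameters["stage"] == 1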
1 change: 1 addition & 0 deletions shark_turbine/kernel/wave/scheduling/__init__.py
@@ -0,0 +1 @@
from .schedule import *
69 changes: 68 additions & 1 deletion shark_turbine/kernel/wave/scheduling/graph_utils.py
@@ -1,7 +1,8 @@
import torch.fx as fx
from random import shuffle, seed, Random
from random import Random
from collections import defaultdict
from ..._support.indexing import index_symbol, IndexExpr
from .resources import *
from dataclasses import dataclass
import sympy
import math
@@ -214,3 +215,69 @@ def topological_sort_nodes(
        filtered_nodes.add(edge._from)
    sorted_nodes = sorted(filtered_nodes, key=lambda x: x.f)
    return sorted_nodes


def get_scheduling_weight(node: fx.Node) -> EdgeWeight:
    """
    Get the scheduling weight of a node.
    """
    custom_node = get_custom(node)
    match custom_node:
        case Read():
            if custom_node.memory_type.address_space == GLOBAL_ADDRESS_SPACE:
                weight = EdgeWeight(0, delay_table[Operation.READ_GLOBAL])
            else:
                weight = EdgeWeight(0, delay_table[Operation.READ_SHARED])
        case Write():
            if custom_node.memory_type.address_space == GLOBAL_ADDRESS_SPACE:
                weight = EdgeWeight(0, delay_table[Operation.WRITE_GLOBAL])
            else:
                weight = EdgeWeight(0, delay_table[Operation.WRITE_SHARED])
        case MMA():
            weight = EdgeWeight(0, delay_table[Operation.MMA])
        case IterArg():
            weight = EdgeWeight(1, 0)
        case _:
            raise ValueError(f"Unsupported node type: {custom_node}")
    weight.delay = subs_idxc(weight.delay)
    weight.iteration_difference = subs_idxc(weight.iteration_difference)
    return weight


def erase_placeholder_nodes(graph: fx.Graph, ignore_nodes: set[fx.Node]) -> None:
    """
    This function erases nodes in the ignore list from the graph. We replace uses
    of the node with None.
    """
    for node in ignore_nodes:
        for user in list(node.users):
            idx = user.args.index(node)
            user.update_arg(idx, None)
        graph.erase_node(node)


def create_scheduling_edges(
    graph: fx.Graph,
    ignore_nodes: set[fx.Node],
    iter_args: list[fx.Node],
    output: fx.Node,
) -> list[Edge]:
    """
    Create scheduling edges from the graph including back edges
    from the outputs to the iter args. Also remove output
    and placeholder nodes.
    """
    # Create edges from outputs to iter args.
    for return_val, iter_arg in zip(get_custom(output).return_vals[0], iter_args):
        iter_arg.args = (return_val,)
    graph.erase_node(output)
    edges = []
    for node in graph.nodes:
        if node in ignore_nodes:
            continue
        weight = get_scheduling_weight(node)
        for user in node.users:
            edge = Edge(node, user, weight)
            edges.append(edge)
    erase_placeholder_nodes(graph, ignore_nodes)
    return edges
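For context, these edge weights feed the standard modulo-scheduling dependence constraint: for an edge u -> v with delay d and iteration difference p, a schedule with initiation interval II must satisfy cycle(v) >= cycle(u) + d - p * II. Edges originating at iter args carry an iteration difference of 1, which is what allows consecutive loop iterations to overlap. The helper below is an illustrative sketch of that check, not code from this commit:

# Illustrative only: the dependence constraint the scheduler enforces for an
# edge u -> v with the EdgeWeight fields used above.
def satisfies_dependence(
    cycle_u: int, cycle_v: int, delay: int, iteration_difference: int, ii: int
) -> bool:
    return cycle_v >= cycle_u + delay - iteration_difference * ii

# With II = 4, a cross-iteration edge (iteration_difference = 1, delay = 2)
# lets the consumer issue at cycle 4 even though the producer issues at cycle 6.
assert satisfies_dependence(cycle_u=6, cycle_v=4, delay=2, iteration_difference=1, ii=4)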
6 changes: 4 additions & 2 deletions shark_turbine/kernel/wave/scheduling/modulo_scheduling.py
@@ -79,7 +79,7 @@ def all_scc_scheduled(self, sccs: dict[fx.Node, list[fx.Node]]) -> bool:
                    return False
        return True

    def schedule(self) -> dict[fx.Node, int]:
    def schedule_graph(self) -> tuple[dict[fx.Node, int], bool]:
        """
        Schedule the graph using the Modulo Scheduler.
        Returns a schedule which maps each node to a cycle.
@@ -105,6 +105,7 @@ def schedule(self) -> dict[fx.Node, int]:
        # Generate the schedule.
        # TODO: Come up with a better heuristic on an upper bound for the initiation interval.
        T_max_range = 3 * T0
        success = False
        for T in range(T0, T0 + T_max_range):
            logger.debug(f"Trying initiation interval: {T}.")
            self.RT = np.zeros((T, len(self.resources)))
@@ -136,6 +137,7 @@ def schedule(self) -> dict[fx.Node, int]:
                    logger.debug(f"Failed to schedule SCC: {scc}.")
                    break
            if self.all_scc_scheduled(sccs):
                success = True
                logger.debug(
                    f"Successfully scheduled all SCCs with initiation interval: {T}."
                )
@@ -144,7 +146,7 @@ def schedule(self) -> dict[fx.Node, int]:
            raise Exception("Failed to schedule the graph.")

        self._initiation_interval = T
        return self.schedule
        return self.schedule, success

    def scc_scheduled(
        self,
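The search above starts from T0, a lower bound on the initiation interval. One standard ingredient of such a bound is the resource-constrained minimum initiation interval (ResMII); the sketch below illustrates the idea and is an assumption about how T0 could be derived, not the code this commit uses:

import math
import numpy as np

# Illustrative ResMII computation: total cycles requested per resource across
# all nodes, divided by the available units, rounded up; the tightest resource wins.
def resource_min_ii(rrts: list[np.ndarray], available: np.ndarray) -> int:
    usage = sum(rrt.sum(axis=0) for rrt in rrts)
    return max(1, int(math.ceil((usage / available).max())))

# Ten MMA ops, each holding one MMA unit for one cycle, on 2 MMA units -> ResMII = 5.
assert resource_min_ii([np.array([[0, 0, 1]])] * 10, np.array([2, 2, 2])) == 5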
75 changes: 75 additions & 0 deletions shark_turbine/kernel/wave/scheduling/resources.py
@@ -0,0 +1,75 @@
from ...lang.global_symbols import *
from ..utils import subs_idxc
from ...ops.wave_ops import Read, Write, MMA, IterArg, Output, get_custom
import torch.fx as fx
from enum import Enum
import numpy as np


# This table contains the number of functional units available for each operation.
def get_available_resources() -> list[int]:
    resources = [GLOBAL_MEMORY_UNITS, SHARED_MEMORY_UNITS, MMA_UNITS]
    return np.array([int(subs_idxc(x)) for x in resources])


class Operation(Enum):
    READ_SHARED = "read_shared"
    WRITE_SHARED = "write_shared"
    READ_GLOBAL = "read_global"
    WRITE_GLOBAL = "write_global"
    MMA = "mma"
    NOOP = "noop"


# This table contains the cycles required to execute each operation.
delay_table = {
    Operation.READ_SHARED: READ_SHARED_DELAY,
    Operation.WRITE_SHARED: WRITE_SHARED_DELAY,
    Operation.READ_GLOBAL: READ_GLOBAL_DELAY,
    Operation.WRITE_GLOBAL: WRITE_GLOBAL_DELAY,
    Operation.MMA: MMA_DELAY,
    Operation.NOOP: 0,
}

# This table contains the resource usage for each operation.
# Operations can use more than one resource for more than one cycle.
resource_reservation_table = {
    Operation.READ_SHARED: np.array([[0, 1, 0]]),
    Operation.WRITE_SHARED: np.array([[0, 1, 0]]),
    Operation.READ_GLOBAL: np.array([[1, 0, 0]]),
    Operation.WRITE_GLOBAL: np.array([[1, 0, 0]]),
    Operation.MMA: np.array([[0, 0, 1]]),
    Operation.NOOP: np.array([[0, 0, 0]]),
}


def annotate_resource_usage(
    graph: fx.Graph,
) -> tuple[set[fx.Node], list[fx.Node], fx.Node]:
    ignore_nodes = set()
    iter_args = []
    output = None
    for node in graph.nodes:
        custom = get_custom(node)
        if isinstance(custom, Read):
            custom.rrt = (
                resource_reservation_table[Operation.READ_GLOBAL]
                if custom.memory_type.address_space == GLOBAL_ADDRESS_SPACE
                else resource_reservation_table[Operation.READ_SHARED]
            )
        elif isinstance(custom, Write):
            custom.rrt = (
                resource_reservation_table[Operation.WRITE_GLOBAL]
                if custom.memory_type.address_space == GLOBAL_ADDRESS_SPACE
                else resource_reservation_table[Operation.WRITE_SHARED]
            )
        elif isinstance(custom, MMA):
            custom.rrt = resource_reservation_table[Operation.MMA]
        elif isinstance(custom, IterArg):
            iter_args.append(node)
            custom.rrt = resource_reservation_table[Operation.NOOP]
        elif isinstance(custom, Output):
            output = node
        else:
            ignore_nodes.add(node)
    return ignore_nodes, iter_args, output
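The columns of each reservation row follow the same order as get_available_resources(): global memory units, shared memory units, MMA units. For illustration, here is the kind of check a modulo scheduler performs when placing an operation at a candidate cycle; this helper is a sketch under that assumption, not the commit's own scheduling code:

import numpy as np

def fits(RT: np.ndarray, rrt: np.ndarray, cycle: int, available: np.ndarray, T: int) -> bool:
    # Row i of the rrt is what the op holds during cycle + i; reservations wrap
    # around modulo the initiation interval T.
    for i, row in enumerate(rrt):
        if ((RT[(cycle + i) % T] + row) > available).any():
            return False
    return True

# A shared-memory read ([[0, 1, 0]]) fits into an empty reservation table.
assert fits(np.zeros((4, 3)), resource_reservation_table[Operation.READ_SHARED],
            cycle=2, available=np.array([2, 2, 2]), T=4)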
79 changes: 79 additions & 0 deletions shark_turbine/kernel/wave/scheduling/schedule.py
@@ -0,0 +1,79 @@
from ..constraints import Constraint
from ..._support.tracing import CapturedTrace
from ...ops.wave_ops import Reduction, IterArg, get_custom
from .modulo_scheduling import ModuloScheduler
from .graph_utils import create_scheduling_edges, Edge
from .resources import get_available_resources, annotate_resource_usage
from ..visualization import visualize_edges, visualize_graph, visualize_schedule
from ..utils import subs_idxc, graph_copy, erase_graph
import torch.fx as fx


def visualize_scheduling_graph(edges: list[Edge]):
    visualize_edges(edges, "reduction_graph.png")


def schedule_reduction(
    reduction: Reduction, trace: CapturedTrace, constraints: list[Constraint]
):
    """
    Clones the reduction graph and does the following:
    1. Annotates resource usage for each node.
    2. Creates edges between outputs and return args for scheduling
       and assigns weights to all edges.
    Does scheduling on the cloned graph and applies the schedule
    to the original graph. Finally, erases the cloned graph.
    """
    reduction_graph = trace.get_subgraph(reduction.subgraph_name)
    graph, node_map = graph_copy(reduction_graph)
    ignore_nodes, iter_args, output = annotate_resource_usage(graph)
    edges = create_scheduling_edges(graph, ignore_nodes, iter_args, output)

    visualize = False
    if visualize:
        visualize_scheduling_graph(edges)
        visualize_graph(graph, "scheduling_fx_graph.png")

    scheduler = ModuloScheduler(graph, edges, get_available_resources())
    schedule, success = scheduler.schedule_graph()
    if not success:
        raise ValueError("Scheduling failed.")
    if visualize:
        visualize_schedule(schedule, scheduler.initiation_interval, "schedule.html")

    # Apply schedule to original graph, specifying the stage
    # that each node is scheduled in as well as the cycle in
    # each stage when the node should be issued.
    inverse_node_map = {v: k for k, v in node_map.items()}
    for node, cycle in schedule.items():
        if node not in inverse_node_map:
            continue
        custom = get_custom(inverse_node_map[node])
        custom.scheduling_parameters = {
            "absolute_cycle": cycle,
            "cycle": cycle % scheduler.initiation_interval,
            "stage": cycle // scheduler.initiation_interval,
            "initiation_interval": scheduler.initiation_interval,
        }
        # Erase edges between outputs and iter args.
        if isinstance(get_custom(node), IterArg):
            node.args = ()

    erase_graph(graph)


def schedule_graph(trace: CapturedTrace, constraints: list[Constraint]):
    """
    Given a graph, pipelines the reductions in the graph.
    """

    def is_reduction(node: fx.Node) -> bool:
        return isinstance(get_custom(node), Reduction)

    reduction_nodes = trace.walk(is_reduction)
    if not reduction_nodes:
        return

    for reduction_node in reduction_nodes:
        schedule_reduction(get_custom(reduction_node), trace, constraints)
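A worked example of the stage/cycle decomposition stored in scheduling_parameters above (numbers are illustrative): with an initiation interval of 4, a node placed at absolute cycle 9 issues at cycle 1 of pipeline stage 2.

ii = 4
absolute_cycle = 9
assert absolute_cycle % ii == 1   # cycle within the stage
assert absolute_cycle // ii == 2  # pipeline stage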
27 changes: 27 additions & 0 deletions shark_turbine/kernel/wave/utils.py
@@ -325,3 +325,30 @@ def subs_idxc(input: Any) -> Any:
"""
idxc = IndexingContext.current()
return safe_subs(input, idxc.subs)


def graph_copy(graph: fx.Graph) -> tuple[fx.Graph, dict[fx.Node, fx.Node]]:
    """
    Copy the graph and return the new graph with the nodes in node_map.
    Also return the mapping of old nodes to new nodes.
    """
    new_graph = fx.Graph()
    node_map = {}
    for node in graph.nodes:
        custom = get_custom(node)
        new_node = custom.copy(
            new_graph=new_graph,
            arg_transform=lambda x: node_map[x] if x in node_map else x,
        )
        node_map[node] = new_node.fx_node
    return new_graph, node_map


def erase_graph(graph: fx.Graph):
    """
    Erase all nodes in the graph.
    """
    for node in reversed(graph.nodes):
        for user in node.users:
            graph.erase_node(user)
        graph.erase_node(node)
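A minimal usage sketch of the two helpers together; `some_graph` is an assumed, pre-existing fx.Graph and is not defined in this commit:

cloned, node_map = graph_copy(some_graph)        # clone the graph node by node
originals = {v: k for k, v in node_map.items()}  # map cloned nodes back to originals
# ... analyze or schedule `cloned` here ...
erase_graph(cloned)                              # tear the clone down when finished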