Add support for scheduling in attention operators (#253)
This PR adds support for scheduling in attention operators. In particular, the following changes are implemented:

1. A parallel Floyd-Warshall implementation for faster scheduling
2. Support for iter_args in rotating registers
3. Support for VALU and SHUFFLE delays and resources
4. infer_type implementations for the remaining wave ops

---------

Signed-off-by: Harsh Menon <[email protected]>
harsh-nod authored Nov 14, 2024
1 parent 90475c1 commit 79ec575
Showing 16 changed files with 488 additions and 97 deletions.
4 changes: 4 additions & 0 deletions iree/turbine/kernel/lang/global_symbols.py
@@ -27,6 +27,10 @@
READ_GLOBAL_DELAY = index_symbol("$READ_GLOBAL_DELAY")
WRITE_GLOBAL_DELAY = index_symbol("$WRITE_GLOBAL_DELAY")
MMA_DELAY = index_symbol("$MMA_DELAY")
VALU_DELAY = index_symbol("$VALU_DELAY")
SHUFFLE_DELAY = index_symbol("$SHUFFLE_DELAY")
SHARED_MEMORY_UNITS = index_symbol("$SHARED_MEMORY_UNITS")
GLOBAL_MEMORY_UNITS = index_symbol("$GLOBAL_MEMORY_UNITS")
MMA_UNITS = index_symbol("$MMA_UNITS")
VALU_UNITS = index_symbol("$VALU_UNITS")
SHUFFLE_UNITS = index_symbol("$SHUFFLE_UNITS")
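These delay and resource symbols are bound to concrete values before scheduling runs, typically through the hyperparameter substitutions supplied with a kernel. A minimal sketch of such a binding follows; the numeric values and the surrounding dict are illustrative assumptions, not part of this change.

# Hypothetical scheduler hyperparameters; the symbol names come from
# global_symbols.py, the values are purely illustrative.
from iree.turbine.kernel.lang.global_symbols import *

hyperparams = {
    READ_GLOBAL_DELAY: 2,
    WRITE_GLOBAL_DELAY: 2,
    MMA_DELAY: 1,
    VALU_DELAY: 1,        # added in this PR: latency modeled for vector-ALU ops
    SHUFFLE_DELAY: 1,     # added in this PR: latency modeled for cross-lane shuffles
    SHARED_MEMORY_UNITS: 4,
    GLOBAL_MEMORY_UNITS: 4,
    MMA_UNITS: 4,
    VALU_UNITS: 2,        # VALU issue slots available to the resource model
    SHUFFLE_UNITS: 2,     # shuffle issue slots available to the resource model
}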
27 changes: 10 additions & 17 deletions iree/turbine/kernel/ops/wave_ops.py
@@ -751,9 +751,8 @@ def custom_string(self, value_map: dict[str, str]) -> str:
    def indexing_dims(self) -> list[IndexSymbol]:
        return list(self._type.symbolic_shape) if self._type else []

    @property
    def type(self) -> "Memory":
        return self._type
    def infer_type(self):
        self.fx_node.type = self._type


@dataclass
@@ -854,9 +853,8 @@ class NewRegister(CustomOp):
    def indexing_dims(self) -> list[IndexSymbol]:
        return list(self.shape)

    @property
    def type(self) -> "Register":
        return Register[*self.shape, self.dtype]
    def infer_type(self):
        self.type = Register[*self.shape, self.dtype]


@define_op("mma")
@@ -1275,11 +1273,9 @@ def target_shape(self):
    def indexing_dims(self) -> list[IndexSymbol]:
        return self.target_shape

    @property
    def type(self) -> Memory:
    def infer_type(self):
        src_dtype = get_custom(self.arg).type.dtype
        dst_type = Register[*self.target_shape, src_dtype]
        return dst_type
        self.type = Register[*self.target_shape, src_dtype]


@define_interface_op("max")
@@ -1370,10 +1366,8 @@ class ShuffleOp(CustomOp):
    def indexing_dims(self) -> list[IndexSymbol]:
        return get_custom(self.arg).indexing_dims

    @property
    def type(self) -> Register:
        src_type = get_custom(self.arg).type
        return src_type
    def infer_type(self):
        self.type = get_custom(self.arg).type


@define_op("cast")
@@ -1438,6 +1432,5 @@ class Reshape(CustomOp, ABC):
    def indexing_dims(self) -> list[IndexExpr]:
        return get_custom(_to_sequence(self.args)[0]).indexing_dims

    @property
    def type(self) -> Register:
        return get_custom(_to_sequence(self.args)[0]).type
    def infer_type(self):
        self.type = get_custom(_to_sequence(self.args)[0]).type
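The wave_ops.py hunks above all follow the same pattern: the read-only type property is replaced by an infer_type method that derives the type from the op's operands and stores it on the op, so types can be populated by an explicit inference pass instead of being recomputed on every access. A rough sketch of how such a pass might drive these methods; the infer_types driver is an assumption for illustration, only get_custom and infer_type come from the code above.

# Hypothetical driver for the new infer_type methods (illustrative only).
from iree.turbine.kernel.ops.wave_ops import get_custom

def infer_types(graph):
    # fx graphs iterate in insertion order, so operands are visited before
    # their users; each custom op derives and stores its own type.
    for node in graph.nodes:
        get_custom(node).infer_type()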
2 changes: 1 addition & 1 deletion iree/turbine/kernel/wave/codegen.py
@@ -1122,7 +1122,7 @@ def handle_scheduling_barrier(emitter: WaveEmitter, node: fx.Node):
        mask |= get_scheduling_mask(operation)

    mask = arith_d.constant(IntegerType.get_signless(32), mask)
    llvm_d.call_intrinsic(None, "llvm.amdgcn.sched.barrier", [mask])
    llvm_d.call_intrinsic(None, "llvm.amdgcn.sched.barrier", [mask], [], [])


@handle_op(scheduling_group_barrier)
80 changes: 77 additions & 3 deletions iree/turbine/kernel/wave/scheduling/graph_utils.py
@@ -12,6 +12,8 @@
from dataclasses import dataclass
import sympy
import math
from functools import partial
import multiprocessing as mp

T = index_symbol("$INITIATION_INTERVAL")

@@ -157,7 +159,25 @@ def find_cycles_in_scc(scc: dict[fx.Node, list[fx.Node]]) -> list[list[fx.Node]]
    return circuits


def all_pairs_longest_paths(
def all_pairs_longest_paths_helper(
    graph: fx.Graph, u: fx.Node, dist: dict[tuple[fx.Node, fx.Node], IndexExpr], i: int
):
    v = list(graph.nodes)[i]
    for w in graph.nodes:
        dist[(v, w)] = sympy.Max(dist[(v, w)], dist[(v, u)] + dist[(u, w)])
    return v, dist


def all_pairs_longest_path_parallel(N: int, D: np.array, k: int, i: int):
    """
    This function is called in parallel, once for each value of i.
    """
    for j in range(N):
        D[i, j] = np.maximum(D[i, j], D[i, k] + D[k, j])
    return i, D[i]


def all_pairs_longest_paths_symbolic(
    graph: fx.Graph,
    edges: list[Edge],
) -> dict[tuple[fx.Node, fx.Node], IndexExpr]:
@@ -181,6 +201,53 @@ def all_pairs_longest_paths(
    return D


def all_pairs_longest_paths(
    graph: fx.Graph,
    edges: list[Edge],
    T: int,
) -> dict[tuple[fx.Node, fx.Node], IndexExpr]:
    """
    For each node in the graph, compute the longest path to all other nodes.
    Uses the Floyd-Warshall algorithm and assumes that the cycles don't
    have positive weights. This function computes the distances in parallel
    by parallelizing across the start nodes.
    T is the initiation interval that is computed during modulo scheduling.
    """
    N = len(graph.nodes)
    D = np.zeros((N, N), dtype=np.float32)
    negative_inf = -np.inf
    for i in range(N):
        for j in range(N):
            D[i, j] = negative_inf

    all_nodes = list(graph.nodes)
    for edge in edges:
        i = all_nodes.index(edge._from)
        j = all_nodes.index(edge._to)
        D[i, j] = edge.weight.delay - edge.weight.iteration_difference * T

    # Parallel implementation
    pool = mp.get_context("fork").Pool(processes=mp.cpu_count())
    for k in range(N):
        func = partial(all_pairs_longest_path_parallel, N, D, k)
        results = pool.map(func, range(N))
        for result in results:
            D[result[0]] = result[1]
    pool.close()
    pool.join()

    # Convert from index to node based representation.
    G: dict[tuple[fx.Node, fx.Node], int] = {}
    for i, from_node in enumerate(graph.nodes):
        for j, to_node in enumerate(graph.nodes):
            if np.isinf(D[i, j]) or i == j:
                continue
            G[(from_node, to_node)] = int(D[i, j])

    return G


def evaluate_all_pairs_longest_paths(
    D: dict[tuple[fx.Node, fx.Node], IndexExpr], initiation_interval: int
) -> dict[tuple[fx.Node, fx.Node], int]:
@@ -190,7 +257,8 @@ def evaluate_all_pairs_longest_paths(
"""
D_static = dict(D)
for key in D_static:
D_static[key] = D_static[key].subs(T, initiation_interval)
if isinstance(D_static[key], sympy.Expr):
D_static[key] = D_static[key].subs(T, initiation_interval)
# Remove the negative infinity values and edges to self.
for k in list(D_static.keys()):
if math.isinf(D_static[k]) or k[0] == k[1]:
@@ -244,8 +312,14 @@ def get_scheduling_weight(node: fx.Node) -> EdgeWeight:
            weight = EdgeWeight(0, delay_table[Operation.MMA])
        case IterArg():
            weight = EdgeWeight(1, 0)
        case CastOp():
        case CastOp() | Extract() | Permute() | Broadcast() | Reshape():
            weight = EdgeWeight(0, delay_table[Operation.NOOP])
        case UnaryPyOp():
            weight = EdgeWeight(0, delay_table[Operation.VALU])
        case BinaryPyOp():
            weight = EdgeWeight(0, delay_table[Operation.VALU])
        case ShuffleOp():
            weight = EdgeWeight(0, delay_table[Operation.SHUFFLE])
        case _:
            raise ValueError(f"Unsupported node type: {custom_node}")
    weight.delay = subs_idxc(weight.delay)
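Taken together, graph_utils.py keeps the original symbolic formulation as all_pairs_longest_paths_symbolic and adds a numeric Floyd-Warshall longest-path variant that is parallelized across rows of the distance matrix. Below is a standalone toy reproduction of the same max-plus relaxation, using only numpy and multiprocessing; the four-node graph and its edge delays are invented for illustration.

# Row-parallel Floyd-Warshall longest paths on a toy graph (illustrative only).
import numpy as np
import multiprocessing as mp
from functools import partial

def relax_row(N, D, k, i):
    # Relax row i through intermediate node k: a path i -> k -> j may be
    # longer than the current estimate for i -> j.
    for j in range(N):
        D[i, j] = np.maximum(D[i, j], D[i, k] + D[k, j])
    return i, D[i]

if __name__ == "__main__":
    N = 4
    D = np.full((N, N), -np.inf, dtype=np.float32)
    # Edge delays for the chain 0 -> 1 -> 2 -> 3 plus a shortcut 0 -> 2.
    D[0, 1], D[1, 2], D[2, 3], D[0, 2] = 2, 3, 1, 4
    with mp.get_context("fork").Pool(processes=mp.cpu_count()) as pool:
        for k in range(N):
            for i, row in pool.map(partial(relax_row, N, D, k), range(N)):
                D[i] = row
    print(D[0, 3])  # longest 0 -> 3 delay: 2 + 3 + 1 = 6.0

In the commit, the same k-loop drives all_pairs_longest_path_parallel, and evaluate_all_pairs_longest_paths then substitutes the initiation interval only when an entry is still a sympy expression.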
(Diff for the remaining changed files is not shown.)
