Skip to content

Commit

Permalink
Add support for scheduling attention
Browse files Browse the repository at this point in the history
This PR adds support for scheduling in
attention operators. In particular, the following
changes are implemented:

1. Parallel Floyd Warshall allows for faster scheduling
2. Support for iter_args in rotating registers
3. Support for VALU and SHUFFLE delays and resources
4. Add infer_type for remaining wave ops

Signed-off-by: Harsh Menon <[email protected]>
  • Loading branch information
harsh-nod committed Nov 10, 2024
1 parent ac12191 commit dd182ab
Show file tree
Hide file tree
Showing 16 changed files with 456 additions and 89 deletions.
4 changes: 4 additions & 0 deletions iree/turbine/kernel/lang/global_symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
READ_GLOBAL_DELAY = index_symbol("$READ_GLOBAL_DELAY")
WRITE_GLOBAL_DELAY = index_symbol("$WRITE_GLOBAL_DELAY")
MMA_DELAY = index_symbol("$MMA_DELAY")
VALU_DELAY = index_symbol("$VALU_DELAY")
SHUFFLE_DELAY = index_symbol("$SHUFFLE_DELAY")
SHARED_MEMORY_UNITS = index_symbol("$SHARED_MEMORY_UNITS")
GLOBAL_MEMORY_UNITS = index_symbol("$GLOBAL_MEMORY_UNITS")
MMA_UNITS = index_symbol("$MMA_UNITS")
VALU_UNITS = index_symbol("$VALU_UNITS")
SHUFFLE_UNITS = index_symbol("$SHUFFLE_UNITS")
27 changes: 10 additions & 17 deletions iree/turbine/kernel/ops/wave_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,9 +751,8 @@ def custom_string(self, value_map: dict[str, str]) -> str:
def indexing_dims(self) -> list[IndexSymbol]:
return list(self._type.symbolic_shape) if self._type else []

@property
def type(self) -> "Memory":
return self._type
def infer_type(self):
self.fx_node.type = self._type


@dataclass
Expand Down Expand Up @@ -854,9 +853,8 @@ class NewRegister(CustomOp):
def indexing_dims(self) -> list[IndexSymbol]:
return list(self.shape)

@property
def type(self) -> "Register":
return Register[*self.shape, self.dtype]
def infer_type(self):
self.type = Register[*self.shape, self.dtype]


@define_op("mma")
Expand Down Expand Up @@ -1275,11 +1273,9 @@ def target_shape(self):
def indexing_dims(self) -> list[IndexSymbol]:
return self.target_shape

@property
def type(self) -> Memory:
def infer_type(self):
src_dtype = get_custom(self.arg).type.dtype
dst_type = Register[*self.target_shape, src_dtype]
return dst_type
self.type = Register[*self.target_shape, src_dtype]


@define_interface_op("max")
Expand Down Expand Up @@ -1370,10 +1366,8 @@ class ShuffleOp(CustomOp):
def indexing_dims(self) -> list[IndexSymbol]:
return get_custom(self.arg).indexing_dims

@property
def type(self) -> Register:
src_type = get_custom(self.arg).type
return src_type
def infer_type(self):
self.type = get_custom(self.arg).type


@define_op("cast")
Expand Down Expand Up @@ -1438,6 +1432,5 @@ class Reshape(CustomOp, ABC):
def indexing_dims(self) -> list[IndexExpr]:
return get_custom(_to_sequence(self.args)[0]).indexing_dims

@property
def type(self) -> Register:
return get_custom(_to_sequence(self.args)[0]).type
def infer_type(self):
self.type = get_custom(_to_sequence(self.args)[0]).type
2 changes: 1 addition & 1 deletion iree/turbine/kernel/wave/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1122,7 +1122,7 @@ def handle_scheduling_barrier(emitter: WaveEmitter, node: fx.Node):
mask |= get_scheduling_mask(operation)

mask = arith_d.constant(IntegerType.get_signless(32), mask)
llvm_d.call_intrinsic(None, "llvm.amdgcn.sched.barrier", [mask])
llvm_d.call_intrinsic(None, "llvm.amdgcn.sched.barrier", [mask], [], [])


@handle_op(scheduling_group_barrier)
Expand Down
78 changes: 75 additions & 3 deletions iree/turbine/kernel/wave/scheduling/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from dataclasses import dataclass
import sympy
import math
from functools import partial
import multiprocessing as mp

T = index_symbol("$INITIATION_INTERVAL")

Expand Down Expand Up @@ -157,7 +159,25 @@ def find_cycles_in_scc(scc: dict[fx.Node, list[fx.Node]]) -> list[list[fx.Node]]
return circuits


def all_pairs_longest_paths(
def all_pairs_longest_paths_helper(
graph: fx.Graph, u: fx.Node, dist: dict[tuple[fx.Node, fx.Node], IndexExpr], i: int
):
v = list(graph.nodes)[i]
for w in graph.nodes:
dist[(v, w)] = sympy.Max(dist[(v, w)], dist[(v, u)] + dist[(u, w)])
return v, dist


def all_pairs_longest_path_parallel(N: int, D: np.array, k: int, i: int):
"""
This function is called once for a different value of i.
"""
for j in range(N):
D[i, j] = np.maximum(D[i, j], D[i, k] + D[k, j])
return i, D[i]


def all_pairs_longest_paths_symbolic(
graph: fx.Graph,
edges: list[Edge],
) -> dict[tuple[fx.Node, fx.Node], IndexExpr]:
Expand All @@ -181,6 +201,51 @@ def all_pairs_longest_paths(
return D


def all_pairs_longest_paths(
graph: fx.Graph,
edges: list[Edge],
T: int,
) -> dict[tuple[fx.Node, fx.Node], IndexExpr]:
"""
For each node in the graph, compute the longest path to all other nodes.
Uses the Floyd-Warshall algorithm and assumes that the cycles don't
have positive weights. This function computes the distances in parallel
by parallelizing across the start nodes.
"""
N = len(graph.nodes)
D = np.zeros((N, N), dtype=np.float32)
negative_inf = -np.inf
for i in range(N):
for j in range(N):
D[i, j] = negative_inf

all_nodes = list(graph.nodes)
for edge in edges:
i = all_nodes.index(edge._from)
j = all_nodes.index(edge._to)
D[i, j] = edge.weight.delay - edge.weight.iteration_difference * T

# Parallel implementation
pool = mp.get_context("fork").Pool(processes=mp.cpu_count())
for k in range(N):
func = partial(all_pairs_longest_path_parallel, N, D, k)
results = pool.map(func, range(N))
for result in results:
D[result[0]] = result[1]
pool.close()
pool.join()

# Convert from index to node based representation.
G: dict[tuple[fx.Node, fx.Node], int] = {}
for i, from_node in enumerate(graph.nodes):
for j, to_node in enumerate(graph.nodes):
if np.isinf(D[i, j]) or i == j:
continue
G[(from_node, to_node)] = int(D[i, j])

return G


def evaluate_all_pairs_longest_paths(
D: dict[tuple[fx.Node, fx.Node], IndexExpr], initiation_interval: int
) -> dict[tuple[fx.Node, fx.Node], int]:
Expand All @@ -190,7 +255,8 @@ def evaluate_all_pairs_longest_paths(
"""
D_static = dict(D)
for key in D_static:
D_static[key] = D_static[key].subs(T, initiation_interval)
if isinstance(D_static[key], sympy.Expr):
D_static[key] = D_static[key].subs(T, initiation_interval)
# Remove the negative infinity values and edges to self.
for k in list(D_static.keys()):
if math.isinf(D_static[k]) or k[0] == k[1]:
Expand Down Expand Up @@ -244,8 +310,14 @@ def get_scheduling_weight(node: fx.Node) -> EdgeWeight:
weight = EdgeWeight(0, delay_table[Operation.MMA])
case IterArg():
weight = EdgeWeight(1, 0)
case CastOp():
case CastOp() | Extract() | Permute() | Broadcast() | Reshape():
weight = EdgeWeight(0, delay_table[Operation.NOOP])
case UnaryPyOp():
weight = EdgeWeight(0, delay_table[Operation.VALU])
case BinaryPyOp():
weight = EdgeWeight(0, delay_table[Operation.VALU])
case ShuffleOp():
weight = EdgeWeight(0, delay_table[Operation.SHUFFLE])
case _:
raise ValueError(f"Unsupported node type: {custom_node}")
weight.delay = subs_idxc(weight.delay)
Expand Down
Loading

0 comments on commit dd182ab

Please sign in to comment.