Add type inference
This PR adds a type inference pass to wave. Previously, types
were inferred on demand by looking them up from neighboring
nodes, which made type inference repetitive and inefficient.

Instead, we introduce a pass that infers the types of all
operators in the graph once and stores the inferred type on
each node. Downstream passes that construct new nodes are
responsible for annotating the types of the new operators.

Signed-off-by: Harsh Menon <[email protected]>
harsh-nod committed Nov 5, 2024
1 parent ee62366 commit a606651
Showing 19 changed files with 385 additions and 50 deletions.
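
To illustrate the contract described in the commit message -- one inference pass annotates every node, and any pass that later constructs a node must annotate it itself -- here is a minimal standalone Python sketch. The names (Node, infer_types, downstream_pass) and the inference rule are hypothetical stand-ins, not the actual wave API; the real pass is added in iree/turbine/kernel/wave/type_inference.py below.

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Node:
    # Stand-in for an fx.Node: "type" is annotated once, then read by later passes.
    name: str
    inputs: list["Node"] = field(default_factory=list)
    type: Optional[str] = None  # filled in by the inference pass or by the creating pass


def infer_types(nodes: list[Node]) -> None:
    # One walk over the graph (assumed topologically ordered): derive each node's
    # type from its already-typed inputs instead of re-deriving it at every use.
    for node in nodes:
        if node.type is None:
            node.type = node.inputs[0].type  # toy rule: propagate the first input's type


def downstream_pass(nodes: list[Node]) -> list[Node]:
    # A pass that constructs a new node is responsible for annotating its type at
    # construction time, mirroring e.g. "write.type = custom.memory.type" in the diffs below.
    new_node = Node("broadcast", inputs=[nodes[-1]])
    new_node.type = nodes[-1].type
    return nodes + [new_node]


if __name__ == "__main__":
    a = Node("placeholder", type="Register[M, f32]")
    b = Node("exp", inputs=[a])
    graph = [a, b]
    infer_types(graph)
    graph = downstream_pass(graph)
    print([(n.name, n.type) for n in graph])  # every node now carries a type
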
75 changes: 35 additions & 40 deletions iree/turbine/kernel/ops/wave_ops.py
@@ -556,6 +556,16 @@ def vector_shapes(self) -> dict[IndexSymbol, int]:
def vector_shapes(self, value: dict[IndexSymbol, int]):
self.fx_node.vector_shapes = value

@property
def type(self) -> Any:
if hasattr(self.fx_node, "type"):
return self.fx_node.type
return None

@type.setter
def type(self, value: Any):
self.fx_node.type = value

def align_index(self, constraints: list["Constraint"]) -> None:
"""
Align index to WG/Tile sizes.
@@ -602,10 +612,7 @@ def indexing_dims(self) -> list[IndexSymbol]:
def py_operator(self) -> str:
return self.tkw_op_name

@property
def type(self) -> Memory:
lhs_type = get_custom(self.lhs).type
rhs_type = get_custom(self.rhs).type
def infer_type(self, lhs_type: Register, rhs_type: Register) -> Register:
has_same_type = has_same_custom_type(lhs_type, rhs_type)
if has_same_type:
return lhs_type
@@ -637,9 +644,7 @@ def indexing_dims(self) -> list[IndexSymbol]:
def py_operator(self) -> str:
return self.tkw_op_name

@property
def type(self) -> Memory:
src_type = get_custom(self.arg).type
def infer_type(self, src_type: Register) -> Register:
return src_type


@@ -868,10 +873,6 @@ def rhs_type(self) -> Memory:
def acc_type(self) -> Memory:
return get_custom(self.acc).type

@property
def type(self) -> Memory:
return self.acc_type

def operand_index(
self, operand_map: dict[IndexSymbol, int], shape: list[IndexExpr]
) -> dict[IndexSymbol, IndexSequence]:
@@ -925,6 +926,7 @@ def reduction_dim(self, value: IndexSymbol):
@define_op("read")
@dataclass
class Read(CustomOp):

memory: fx.Proxy
elements_per_thread: Optional[Any] = None
mapping: Optional[IndexMapping] = None
@@ -937,10 +939,12 @@ def indexing_dims(self) -> list[IndexSymbol]:
# TODO: This could contain ints.
return list(self.memory_type.symbolic_shape)

@property
def type(self) -> "Register":
dtype = self.memory_type.dtype
return Register[*self.indexing_dims, dtype]
def infer_type(self, memory_type: Memory) -> "Register":
dtype = memory_type.dtype
shape = memory_type.symbolic_shape
if self.mapping is not None:
shape = self.mapping.output_shape
return Register[*shape, dtype]

@property
def memory_type(self) -> "Memory":
@@ -1052,9 +1056,7 @@ def captured_vars(self, graph: fx.Graph) -> list[fx.Node]:
captured_vars.append(nested_node)
return captured_vars

@property
def type(self) -> Memory | Register | list[Memory | Register]:
res_types = [get_custom(x).type for x in self.init_args]
def infer_type(self, res_types: list[Register]) -> Register | list[Register]:
if len(res_types) == 1:
res_types = res_types[0]
return res_types
@@ -1112,9 +1114,13 @@ def indexing_dims(self) -> list[IndexSymbol]:
# TODO: This could contain ints.
return list(self.type.symbolic_shape)

@property
def type(self) -> "Memory":
return get_custom(self.memory).type
def infer_type(self, memory_type: Memory) -> "Memory":
dtype = memory_type.dtype
shape = memory_type.symbolic_shape
address_space = memory_type.address_space
if self.mapping is not None:
shape = self.mapping.input_shape
return Memory[*shape, address_space, dtype]

@property
def memory_type(self) -> "Memory":
@@ -1144,9 +1150,7 @@ class GetResult(CustomOp):
value: fx.Node
res_idx: int

@property
def type(self) -> "Memory":
src_type = get_custom(self.value).type
def infer_type(self, src_type: Register, idx: int) -> "Memory":
if isinstance(src_type, list):
return src_type[self.res_idx]
else:
@@ -1200,8 +1204,7 @@ class Extract(CustomOp):
register_: fx.Proxy
offset: IndexExpr | int

@property
def type(self) -> "Register":
def infer_type(self, src_type) -> "Register":
# Intuition here is we are trying to extract an element
# from fastest dim => we reduce the fastest dim.
src_type = get_custom(self.register_).type
@@ -1297,13 +1300,8 @@ def indexing_dims(self) -> list[IndexSymbol]:
dst_indexing = [dim for dim in src_indexing if dim != self.dim]
return dst_indexing

@property
def type(self) -> Memory:
if isinstance(self.arg, Sequence):
# Local import to break circular dep.
from ..wave.utils import all_equal

src_types = [get_custom(arg).type for arg in self.arg]
def infer_type(self, src_types: list[Register] | Register) -> Register:
if isinstance(src_types, Sequence):
ref_shape = src_types[0].symbolic_shape
ref_dtype = src_types[0].dtype
if not all(
@@ -1315,7 +1313,7 @@ def type(self) -> Memory:
)
src_type = src_types[0]
else:
src_type = get_custom(self.arg).type
src_type = src_types
reduced_dims = [dims for dims in src_type.symbolic_shape if dims != self.dim]
dst_type = Register[*reduced_dims, src_type.dtype]
return dst_type
@@ -1376,9 +1374,8 @@ class CastOp(CustomOp, ABC):
def indexing_dims(self) -> list[IndexSymbol]:
return get_custom(self.arg).indexing_dims

@property
def type(self) -> Memory:
src_shape = get_custom(self.arg).type.symbolic_shape
def infer_type(self, src_type: Register) -> Register:
src_shape = src_type.symbolic_shape
return Register[*src_shape, self.dtype]


@@ -1397,9 +1394,7 @@ class Permute(CustomOp, ABC):
def indexing_dims(self) -> list[IndexExpr]:
return self.target_shape

@property
def type(self) -> Register:
src_type = get_custom(self.arg).type
def infer_type(self, src_type: Register) -> Register:
assert set(src_type.symbolic_shape) == set(
self.target_shape
), f"Target shape {self.target_shape} must be a permutation of source shape {src_type.symbolic_shape}"
6 changes: 6 additions & 0 deletions iree/turbine/kernel/wave/decompose_reduce_ops.py
@@ -87,8 +87,10 @@ def emit_sources_reduction(
binary_fn: Callable, src: list[fx.Node], graph: fx.Graph
) -> fx.Node:
init = src[0]
op_type = init.type
for i in range(1, len(src)):
init = get_graph_node(binary_fn(init, src[i]), graph)
init.type = op_type
init.index = src[0].index
return init

@@ -97,9 +99,12 @@ def emit_local_reduction(
binary_fn: Callable, src: fx.Node, graph: fx.Graph, local_reduction_size: int
) -> fx.Node:
init = get_graph_node(Extract(src, [0]), graph)
init.type = get_custom(init).infer_type(src.type)
for i in range(1, local_reduction_size):
cur_slice = get_graph_node(Extract(src, [i]), graph)
cur_slice.type = get_custom(cur_slice).infer_type(src.type)
init = get_graph_node(binary_fn(init, cur_slice), graph)
init.type = cur_slice.type
return init


@@ -117,6 +122,7 @@ def emit_global_reduction(
shuffle_val = ShuffleOp(init, cluster_stride, subgroup_size)
shuffle_node = get_graph_node(shuffle_val, graph)
init = get_graph_node(binary_fn(init, shuffle_node), graph)
init.type = src.type
cluster_stride <<= 1
return init

1 change: 1 addition & 0 deletions iree/turbine/kernel/wave/expansion.py
@@ -338,6 +338,7 @@ def _expand_reduction(
new_node = GetResult(reduction.fx_node, len(new_output_args))
new_node.add_to_graph(reduction.graph)
new_node.fx_node.name = get_expanded_name(new_node, dims)
new_node.type = arg.type
context[
(reduction, get_indexed_dims(dims, expand_dims), arg_idx)
] = new_node
1 change: 1 addition & 0 deletions iree/turbine/kernel/wave/index_sequence_analysis.py
@@ -125,6 +125,7 @@ def has_strided_access(node: fx.Node) -> bool:
)
for j, dim in enumerate(custom.register_type.symbolic_shape)
}
write.type = custom.memory.type

custom.graph.erase_node(operator)

4 changes: 3 additions & 1 deletion iree/turbine/kernel/wave/minimize_global_loads.py
@@ -14,7 +14,7 @@
from .._support.indexing import IndexingContext, IndexSequence, IndexSymbol, IndexExpr
from ..ops.wave_ops import Read, Write, get_custom
from ..lang.global_symbols import *
from .utils import delinearize_index, DCE, subs_idxc, ceildiv
from .utils import delinearize_index, DCE, subs_idxc, ceildiv, memory_to_register
from math import prod
import torch.fx as fx
from collections import defaultdict
@@ -140,6 +140,7 @@ def add_optimized_nodes(
load_elems_per_thread,
materialized_shape,
)
read.type = memory_to_register(memory.type)
for custom_user in custom.users:
if (
isinstance(custom_user, Write)
@@ -149,6 +150,7 @@ def add_optimized_nodes(
read, custom_user.memory, load_elems_per_thread
).add_to_graph(custom.graph)
write.index = read.index
write.type = custom_user.type
optimized_writes[custom_user.memory].append(write)
break
return optimized_writes
3 changes: 3 additions & 0 deletions iree/turbine/kernel/wave/promotion.py
@@ -49,6 +49,9 @@ def apply_promotion_pattern(custom_node: Read | Write, allocate_node: Allocate):
).add_to_graph(custom_node.graph)
custom_read = get_custom(promoted_read)
custom_read.write_dependency = [promoted_write]
custom_read.type = custom_node.type
custom_write = get_custom(promoted_write)
custom_write.type = allocate_node.type
custom_node.memory_type.address_space = GLOBAL_ADDRESS_SPACE


10 changes: 5 additions & 5 deletions iree/turbine/kernel/wave/scheduling/loop_reconstruction.py
@@ -442,11 +442,11 @@ def construct_epilogue(
rotating_registers_get_results = []
offset = len(existing_get_results)
for i in range(len(flatten_dict_values(rotating_registers))):
rotating_registers_get_results.append(
GetResult(pipelined_reduction.fx_node, i + offset).add_to_graph(
pipelined_reduction.graph
)
)
get_result = GetResult(
pipelined_reduction.fx_node, i + offset
).add_to_graph(pipelined_reduction.graph)
get_result.type = pipelined_reduction.init_args[i + offset].type
rotating_registers_get_results.append(get_result)
rotating_registers = unflatten_dict_values(
num_rotating_registers, rotating_registers_get_results
)
89 changes: 89 additions & 0 deletions iree/turbine/kernel/wave/type_inference.py
@@ -0,0 +1,89 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ..ops.wave_ops import *
from .._support.tracing import CapturedTrace
import torch.fx as fx
from typing import Sequence
from ...support.logging import get_logger

logger = get_logger("turbine.wave.type_inference")


class TypeInferer:
def __init__(self):
self.type_table: dict[CustomOp, Memory | Register | list[Register]] = {}

def get_type(self, op: fx.Node) -> Memory | Register | list[Register]:
custom = get_custom(op)
custom_type = self.type_table.get(custom, None)
if custom_type is None and custom not in self.type_table:
raise ValueError(f"No type found for {op}")
return custom_type

def infer_types(self, op: CustomOp):
match op:
case BinaryPyOp():
s = self.get_type(op.lhs)
t = self.get_type(op.rhs)
self.type_table[op] = op.infer_type(s, t)
case GetResult():
s = self.get_type(op.value)
self.type_table[op] = op.infer_type(s, op.res_idx)
case Read() | Write():
s = self.get_type(op.memory)
self.type_table[op] = op.infer_type(s)
case MMA():
s = self.get_type(op.lhs)
t = self.get_type(op.rhs)
u = self.get_type(op.acc)
self.type_table[op] = u
case Placeholder() | NewRegister():
self.type_table[op] = op.type
case Reduction():
s = []
for init_arg in op.init_args:
s.append(self.get_type(init_arg))
self.type_table[op] = op.infer_type(s)
case ReduceOp():
args = op.arg
if not isinstance(op.arg, Sequence):
args = [op.arg]
s = []
for arg in args:
s.append(self.get_type(arg))
self.type_table[op] = op.infer_type(s)
case CastOp() | Permute() | UnaryPyOp():
s = self.get_type(op.arg)
self.type_table[op] = op.infer_type(s)
case Output():
s = []
for ret_vals in op.return_vals:
if ret_vals is None:
s = None
break
if not isinstance(ret_vals, Sequence):
ret_vals = [ret_vals]
for ret_val in ret_vals:
s.append(self.get_type(ret_val))
self.type_table[op] = s
return


def infer_types(trace: CapturedTrace | fx.Graph):
inferer = TypeInferer()
# First, infer the types for all the nodes.
for subgraph in trace.region_graph.subgraphs.values():
for node in subgraph.nodes:
custom = get_custom(node)
inferer.infer_types(custom)
# Then, set the types.
for subgraph in trace.region_graph.subgraphs.values():
for node in subgraph.nodes:
custom = get_custom(node)
if not isinstance(custom, (Placeholder, NewRegister)):
custom.type = inferer.get_type(custom.fx_node)
logger.debug(f"Setting type for {custom.fx_node} = {custom.type}")
12 changes: 10 additions & 2 deletions iree/turbine/kernel/wave/utils.py
@@ -26,6 +26,8 @@
GetResult,
IterArg,
Reshape,
Memory,
Register,
)
from .constraints import (
Constraint,
@@ -136,8 +138,8 @@ def is_removable_operator(node: fx.Node) -> bool:
custom = get_custom(node)
idxc = IndexingContext.current()
is_global_write = isinstance(custom, Write) and (
custom.type.address_space.subs(idxc.subs) == GLOBAL_ADDRESS_SPACE
or custom.type.address_space.subs(idxc.subs)
custom.memory_type.address_space.subs(idxc.subs) == GLOBAL_ADDRESS_SPACE
or custom.memory_type.address_space.subs(idxc.subs)
== tkl.AddressSpace.GLOBAL_MEMORY.value
)

@@ -824,3 +826,9 @@ def all_equal(input_list: list[Any]) -> bool:
if len(input_list) == 0:
return True
return all(elem == input_list[0] for elem in input_list)


def memory_to_register(memory_type: Memory) -> Register:
dtype = memory_type.dtype
shape = memory_type.symbolic_shape
return Register[*shape, dtype]