Add type inference #252

Merged: 2 commits, merged on Nov 6, 2024. Changes shown from 1 commit.
94 changes: 53 additions & 41 deletions iree/turbine/kernel/ops/wave_ops.py
@@ -338,7 +338,7 @@ def custom_string(self, value_map: dict[str, str]) -> str:
vars_str = ", ".join(vars_list)
return f"{self.tkw_op_name}({vars_str})"

def add_to_graph(self, region_graph: RegionGraph) -> fx.Node:
def add_to_graph(self, region_graph: RegionGraph, type: Any = None) -> fx.Node:
arg_list = tuple([value for _, value in vars(self).items()])
self.graph = region_graph
self.fx_node = region_graph.create_node(
@@ -350,6 +350,10 @@ def add_to_graph(self, region_graph: RegionGraph) -> fx.Node:
self.fx_node.tkw_op = self.__class__
self.fx_node.tkw_op_name = self.tkw_op_name
self.fx_node.index = None
if type is None:
get_custom(self.fx_node).infer_type()
else:
self.fx_node.type = type
return self.fx_node

def _add_proxy_to_graph(self, region_graph: RegionGraph):
@@ -556,6 +560,23 @@ def vector_shapes(self) -> dict[IndexSymbol, int]:
def vector_shapes(self, value: dict[IndexSymbol, int]):
self.fx_node.vector_shapes = value

@property
def type(self) -> Any:
if hasattr(self.fx_node, "type"):
return self.fx_node.type
return None

@type.setter
def type(self, value: Any):
self.fx_node.type = value

def infer_type(self):
"""
Infer the type of this operator using the types
of its arguments.
"""
pass
Contributor: maybe we can raise NotImplemented here?

Contributor Author: Good question. So the idea is that most operators will implement infer_type, except for Placeholders, etc. For those operators, since they already have a type, we want this to just pass through instead of throwing an error.
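To make that pass-through behavior concrete, here is a minimal standalone sketch of the contract described above; the class names and the Register-like string are illustrative placeholders, not the actual turbine types:

from typing import Any, Optional


class CustomOpSketch:
    # Placeholder for the CustomOp base class; not the turbine implementation.
    type: Optional[Any] = None

    def infer_type(self) -> None:
        # Default is a no-op: ops like Placeholder already carry a type,
        # so silently passing through is preferable to raising.
        pass


class UnaryOpSketch(CustomOpSketch):
    # Placeholder for a concrete op that derives its type from its argument.
    def __init__(self, arg: CustomOpSketch):
        self.arg = arg

    def infer_type(self) -> None:
        # Concrete ops override infer_type and set self.type from their
        # arguments' already-known types.
        self.type = self.arg.type


placeholder = CustomOpSketch()
placeholder.type = "Register[M, f16]"  # pretend this was set at placeholder creation
neg = UnaryOpSketch(placeholder)
neg.infer_type()
assert neg.type == placeholder.type  # type flows from argument to result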


def align_index(self, constraints: list["Constraint"]) -> None:
"""
Align index to WG/Tile sizes.
@@ -602,21 +623,21 @@ def indexing_dims(self) -> list[IndexSymbol]:
def py_operator(self) -> str:
return self.tkw_op_name

@property
def type(self) -> Memory:
def infer_type(self):
lhs_type = get_custom(self.lhs).type
rhs_type = get_custom(self.rhs).type
has_same_type = has_same_custom_type(lhs_type, rhs_type)
if has_same_type:
return lhs_type
self.type = lhs_type
return
lhs_dim_set = set(lhs_type.symbolic_shape)
rhs_dim_set = set(rhs_type.symbolic_shape)
if lhs_dim_set.isdisjoint(rhs_dim_set):
raise ValueError(
"BinaryPyOp requires lhs and rhs shape to be at least broadcastable."
)
broadcasted_type = lhs_type if lhs_dim_set > rhs_dim_set else rhs_type
return broadcasted_type
self.type = broadcasted_type


@define_interface_op("exp2")
@@ -637,10 +658,9 @@ def indexing_dims(self) -> list[IndexSymbol]:
def py_operator(self) -> str:
return self.tkw_op_name

@property
def type(self) -> Memory:
def infer_type(self):
src_type = get_custom(self.arg).type
return src_type
self.type = src_type


@final
@@ -868,9 +888,8 @@ def rhs_type(self) -> Memory:
def acc_type(self) -> Memory:
return get_custom(self.acc).type

@property
def type(self) -> Memory:
return self.acc_type
def infer_type(self):
self.type = self.acc_type

def operand_index(
self, operand_map: dict[IndexSymbol, int], shape: list[IndexExpr]
@@ -925,6 +944,7 @@ def reduction_dim(self, value: IndexSymbol):
@define_op("read")
@dataclass
class Read(CustomOp):

memory: fx.Proxy
elements_per_thread: Optional[Any] = None
mapping: Optional[IndexMapping] = None
@@ -937,10 +957,9 @@ def indexing_dims(self) -> list[IndexSymbol]:
# TODO: This could contain ints.
return list(self.memory_type.symbolic_shape)

@property
def type(self) -> "Register":
def infer_type(self):
dtype = self.memory_type.dtype
return Register[*self.indexing_dims, dtype]
self.type = Register[*self.indexing_dims, dtype]

@property
def memory_type(self) -> "Memory":
@@ -1052,12 +1071,11 @@ def captured_vars(self, graph: fx.Graph) -> list[fx.Node]:
captured_vars.append(nested_node)
return captured_vars

@property
def type(self) -> Memory | Register | list[Memory | Register]:
def infer_type(self):
res_types = [get_custom(x).type for x in self.init_args]
if len(res_types) == 1:
res_types = res_types[0]
return res_types
self.type = res_types

def outputs(self, graph: fx.Graph) -> list[fx.Node]:
for node in graph.nodes:
@@ -1110,11 +1128,12 @@ def indexing_dims(self) -> list[IndexSymbol]:
if self.mapping is not None:
return list(self.mapping.input_shape)
# TODO: This could contain ints.
return list(self.type.symbolic_shape)
return list(self.memory_type.symbolic_shape)

@property
def type(self) -> "Memory":
return get_custom(self.memory).type
def infer_type(self):
address_space = self.memory_type.address_space
dtype = self.memory_type.dtype
self.type = Memory[*self.indexing_dims, address_space, dtype]

@property
def memory_type(self) -> "Memory":
@@ -1144,13 +1163,12 @@ class GetResult(CustomOp):
value: fx.Node
res_idx: int

@property
def type(self) -> "Memory":
def infer_type(self):
src_type = get_custom(self.value).type
if isinstance(src_type, list):
return src_type[self.res_idx]
self.type = src_type[self.res_idx]
else:
return src_type
self.type = src_type

@property
def indexing_dims(self) -> list[IndexExpr]:
@@ -1200,14 +1218,14 @@ class Extract(CustomOp):
register_: fx.Proxy
offset: IndexExpr | int

@property
def type(self) -> "Register":
def infer_type(self):
# Intuition here is we are trying to extract an element
# from fastest dim => we reduce the fastest dim.
src_type = get_custom(self.register_).type
# Return itself if just 0-D/1-D symbolic.
if len(src_type.symbolic_shape) <= 1:
return src_type
self.type = src_type
return

# Typically fastest dim is the last dimension,
# If non-unit dim exists => non-unit dim is fastest dim.
@@ -1220,7 +1238,7 @@ def type(self) -> "Register":
dim_to_remove = dst_shape[-1] if not non_unit_dim else non_unit_dim[0]
dst_shape.remove(dim_to_remove)
dst_type = Register[*dst_shape, src_type.dtype]
return dst_type
self.type = dst_type


@define_op("extract_slice")
@@ -1297,12 +1315,8 @@ def indexing_dims(self) -> list[IndexSymbol]:
dst_indexing = [dim for dim in src_indexing if dim != self.dim]
return dst_indexing

@property
def type(self) -> Memory:
def infer_type(self):
if isinstance(self.arg, Sequence):
# Local import to break circular dep.
from ..wave.utils import all_equal

src_types = [get_custom(arg).type for arg in self.arg]
ref_shape = src_types[0].symbolic_shape
ref_dtype = src_types[0].dtype
@@ -1318,7 +1332,7 @@ def type(self) -> Memory:
src_type = get_custom(self.arg).type
reduced_dims = [dims for dims in src_type.symbolic_shape if dims != self.dim]
dst_type = Register[*reduced_dims, src_type.dtype]
return dst_type
self.type = dst_type

@property
def num_reduction_dims(self) -> int:
@@ -1376,10 +1390,9 @@ class CastOp(CustomOp, ABC):
def indexing_dims(self) -> list[IndexSymbol]:
return get_custom(self.arg).indexing_dims

@property
def type(self) -> Memory:
def infer_type(self):
src_shape = get_custom(self.arg).type.symbolic_shape
return Register[*src_shape, self.dtype]
self.type = Register[*src_shape, self.dtype]


@define_op("permute")
@@ -1397,13 +1410,12 @@ class Permute(CustomOp, ABC):
def indexing_dims(self) -> list[IndexExpr]:
return self.target_shape

@property
def type(self) -> Register:
def infer_type(self):
src_type = get_custom(self.arg).type
assert set(src_type.symbolic_shape) == set(
self.target_shape
), f"Target shape {self.target_shape} must be a permutation of source shape {src_type.symbolic_shape}"
return Register[*self.target_shape, src_type.dtype]
self.type = Register[*self.target_shape, src_type.dtype]


def _to_sequence(input: Any | Sequence[Any]) -> Sequence[Any]:
2 changes: 1 addition & 1 deletion iree/turbine/kernel/wave/expansion.py
@@ -336,7 +336,7 @@ def _expand_reduction(
# Add GetResult nodes for the corresponding dimensions
reduction.graph.inserting_after(reduction.fx_node)
new_node = GetResult(reduction.fx_node, len(new_output_args))
new_node.add_to_graph(reduction.graph)
new_node.add_to_graph(reduction.graph, arg.type)
Contributor (@raikonenfnu, Nov 6, 2024): Just curious, what/why is the regular GetResult infer_type not working for this case? Can we add a comment to explain a bit?

Contributor Author: Sure, I can add that. The short summary is that it would require us to find the output node to get the right type, since this is during expansion and the init_args have not been set yet. Since we already have access to arg, we just use its type instead.
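A minimal sketch of the situation described in this thread, using a hypothetical NodeSketch class rather than the real expansion code: inference needs the producer's type, which is not available mid-expansion, so the caller supplies the type it already has and add_to_graph short-circuits inference.

from typing import Any, Optional


class NodeSketch:
    # Hypothetical stand-in for a node whose producer may not be wired up yet.
    def __init__(self, producer_type: Optional[Any] = None):
        self.producer_type = producer_type
        self.type: Optional[Any] = None

    def infer_type(self) -> None:
        # Inference needs the producer's type; mid-expansion it is not set yet.
        if self.producer_type is None:
            raise ValueError("producer type unavailable (init_args not set yet)")
        self.type = self.producer_type

    def add_to_graph(self, type: Any = None) -> None:
        # Mirrors the PR's add_to_graph: an explicit type bypasses inference.
        if type is None:
            self.infer_type()
        else:
            self.type = type


node = NodeSketch(producer_type=None)           # producer not yet wired up
node.add_to_graph(type="Register[M, N, f32]")   # pass the type we already know (arg.type)
assert node.type == "Register[M, N, f32]"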

new_node.fx_node.name = get_expanded_name(new_node, dims)
context[
(reduction, get_indexed_dims(dims, expand_dims), arg_idx)
21 changes: 21 additions & 0 deletions iree/turbine/kernel/wave/type_inference.py
@@ -0,0 +1,21 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ..ops.wave_ops import *
from .._support.tracing import CapturedTrace
import torch.fx as fx
from ...support.logging import get_logger

logger = get_logger("turbine.wave.type_inference")


def infer_types(trace: CapturedTrace | fx.Graph):
# Infer and set the types for all nodes in the graph.
for subgraph in trace.region_graph.subgraphs.values():
for node in subgraph.nodes:
custom = get_custom(node)
custom.infer_type()
logger.debug(f"Setting type for {custom.fx_node} = {custom.type}")
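For orientation, a hedged usage sketch of the new pass; the import paths mirror this diff, and the traced-kernel callable is assumed to be one of the tracing helpers defined in the lit tests further below:

from iree.turbine.kernel._support.indexing import IndexingContext
from iree.turbine.kernel._support.tracing import CapturedTrace
from iree.turbine.kernel.wave.type_inference import infer_types


def run_type_inference(traced_kernel_fn) -> CapturedTrace:
    # traced_kernel_fn is assumed to be a tracing callable such as the
    # read_write_same_size / gemm helpers in the lit tests below.
    trace: CapturedTrace = traced_kernel_fn()
    IndexingContext.current().finalize()  # resolve symbolic indices first
    infer_types(trace)                    # then set the type of every node
    return trace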
4 changes: 4 additions & 0 deletions iree/turbine/kernel/wave/wave.py
@@ -49,6 +49,7 @@
from .thread_shape_analysis import determine_thread_shapes
from .scheduling.schedule import schedule_graph
from .._support.indexing import IndexingContext, IndexExpr
from .type_inference import infer_types
import iree.turbine.kernel.lang as tkl
from .._support.tracing import (
CapturedTrace,
@@ -224,6 +225,9 @@ def _trace_and_get_kernel_signature(
# Initialize Vector shapes
self.hardware_constraints[0].subs_vector_shapes(idxc.subs)

# Do type inference.
infer_types(graph)

# Promote the placeholders to the appropriate address space.
promote_placeholders(graph, self.constraints)
hoist_allocs(graph)
3 changes: 3 additions & 0 deletions lit_tests/kernel/wave/barriers.py
@@ -14,6 +14,7 @@
from iree.turbine.kernel.wave.barriers import add_shared_memory_barriers
from iree.turbine.kernel.wave.hoisting import hoist_allocs
from iree.turbine.kernel.wave.expansion import expand_graph
from iree.turbine.kernel.wave.type_inference import infer_types
from iree.turbine.kernel.lang.global_symbols import *
from iree.turbine.kernel._support.tracing import CapturedTrace
from iree.turbine.kernel._support.indexing import IndexingContext
@@ -86,6 +87,7 @@ def test_read_write_equal_sizes():
graph: fx.Graph = trace.get_root_graph()
read_node = get_read_nodes(graph)[0]
IndexingContext.current().finalize()
infer_types(trace)
promote_node(read_node, SHARED_ADDRESS_SPACE, constraints)
set_node_indices(trace, constraints)
expand_graph(trace, constraints)
@@ -171,6 +173,7 @@ def test_gemm():
trace: CapturedTrace = gemm()
graph: fx.Graph = trace.get_subgraph("region_0")
IndexingContext.current().finalize()
infer_types(trace)
read_nodes = get_read_nodes(graph)
for read_node in read_nodes:
promote_node(read_node, SHARED_ADDRESS_SPACE, constraints)
10 changes: 10 additions & 0 deletions lit_tests/kernel/wave/expansion.py
@@ -6,6 +6,7 @@
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
from iree.turbine.kernel.wave.expansion import expand_graph
from iree.turbine.kernel.wave.type_inference import infer_types
from iree.turbine.kernel.wave.index_sequence_analysis import (
set_node_indices,
set_post_expansion_indices,
@@ -69,6 +70,7 @@ def test_read_write_equal_sizes():
):
graph = read_write_same_size()
IndexingContext.current().finalize()
infer_types(graph)
set_node_indices(graph, constraints)
expand_graph(graph, constraints)
set_post_expansion_indices(graph, constraints)
@@ -150,6 +152,7 @@ def test_read_write():
):
graph = read_write_different_dims()
IndexingContext.current().finalize()
infer_types(graph)
set_node_indices(graph, constraints)
expand_graph(graph, constraints)
set_post_expansion_indices(graph, constraints)
@@ -227,6 +230,7 @@ def test_gemm():
):
graph = gemm()
IndexingContext.current().finalize()
infer_types(graph)
set_node_indices(graph, constraints)
expand_graph(graph, constraints)
set_post_expansion_indices(graph, constraints)
@@ -413,6 +417,7 @@ def test_batched_gemm():
):
graph = batched_gemm()
IndexingContext.current().finalize()
infer_types(graph)
set_node_indices(graph, constraints)
expand_graph(graph, constraints)
set_post_expansion_indices(graph, constraints)
@@ -591,6 +596,7 @@ def test_gemm_non_direct_acc():
):
graph = gemm_non_direct_acc()
IndexingContext.current().finalize()
infer_types(graph)
set_node_indices(graph, constraints)
expand_graph(graph, constraints)
set_post_expansion_indices(graph, constraints)
@@ -657,6 +663,7 @@ def test_tiled_max():
):
graph = tiled_max()
IndexingContext.current().finalize()
infer_types(graph)
set_node_indices(graph, constraints)
expand_graph(graph, constraints)
set_post_expansion_indices(graph, constraints)
@@ -688,6 +695,7 @@ def test_gemm_reduction_expansion_only():
):
graph = gemm()
IndexingContext.current().finalize()
infer_types(graph)
set_node_indices(graph, constraints)
expand_graph(graph, constraints)
set_post_expansion_indices(graph, constraints)
@@ -791,6 +799,7 @@ def py_arithmetic_different_dims():
):
graph = py_arithmetic_different_dims()
IndexingContext.current().finalize()
infer_types(graph)
set_node_indices(graph, constraints)
expand_graph(graph, constraints)
set_post_expansion_indices(graph, constraints)
@@ -896,6 +905,7 @@ def test_chained_gemm_32x32x8():
):
graph = chained_gemm_32x32x8()
IndexingContext.current().finalize()
infer_types(graph)
set_node_indices(graph, constraints)
expand_graph(graph, constraints)
set_post_expansion_indices(graph, constraints)