Commit

fixes
harsh-nod committed Nov 9, 2024
1 parent 7904f5e commit da6d458
Showing 5 changed files with 72 additions and 53 deletions.
10 changes: 4 additions & 6 deletions iree/turbine/kernel/ops/wave_ops.py
@@ -751,9 +751,8 @@ def custom_string(self, value_map: dict[str, str]) -> str:
     def indexing_dims(self) -> list[IndexSymbol]:
         return list(self._type.symbolic_shape) if self._type else []
 
-    @property
-    def type(self) -> "Memory":
-        return self._type
+    def infer_type(self):
+        self.fx_node.type = self._type
 
 
 @dataclass
@@ -854,9 +854,8 @@ class NewRegister(CustomOp):
     def indexing_dims(self) -> list[IndexSymbol]:
         return list(self.shape)
 
-    @property
-    def type(self) -> "Register":
-        return Register[*self.shape, self.dtype]
+    def infer_type(self):
+        self.type = Register[*self.shape, self.dtype]
 
 
 @define_op("mma")
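Note: both hunks above swap a lazily computed type property for an explicit
infer_type() pass that writes the type onto the underlying fx node. A minimal
standalone sketch of that pattern follows; FakeFxNode and SimpleOp are invented
stand-ins for illustration, not the real wave_ops classes.

from dataclasses import dataclass, field
from typing import Any, Optional

@dataclass
class FakeFxNode:
    # Stand-in for torch.fx.Node; the inferred type is stored on the node.
    type: Optional[Any] = None

@dataclass
class SimpleOp:
    # Stand-in for a CustomOp that knows its own type.
    _type: Optional[Any] = None
    fx_node: FakeFxNode = field(default_factory=FakeFxNode)

    def infer_type(self):
        # One-shot inference pass: stamp the type onto the traced node so
        # later passes can read fx_node.type without consulting the op.
        self.fx_node.type = self._type

op = SimpleOp(_type="Memory[M, N]")
op.infer_type()
assert op.fx_node.type == "Memory[M, N]"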
80 changes: 47 additions & 33 deletions iree/turbine/kernel/wave/scheduling/loop_reconstruction.py
@@ -78,16 +78,19 @@ def add_nodes_by_schedule(
             logger.debug(f"Node: {node}, Stage: {stage}, Iteration: {iteration}")
             custom_node = get_custom(node)
             logger.debug(f"Node args: {node.args}")
+            preferred_stage = (
+                stage if pipelining_stage == PipelineStage.KERNEL else None
+            )
             for arg in node.args:
                 if arg_context.contains_in_iteration(iteration, arg):
                     logger.debug(
-                        f"Found arg: {arg} at iteration {iteration} in partitioned argument map. Using {arg_context.get_from_iteration(iteration, arg)}."
+                        f"Found arg: {arg} at iteration {iteration} in partitioned argument map. Using {arg_context.get_from_iteration(iteration, arg, preferred_stage)}."
                     )
                     continue
             new_node = custom_node.copy(
                 new_graph=reduction_graph,
                 arg_transform=lambda x: (
-                    arg_context.get_from_iteration(iteration, x)
+                    arg_context.get_from_iteration(iteration, x, preferred_stage)
                     if arg_context.contains_in_iteration(iteration, x)
                     else x
                 ),
@@ -109,28 +112,38 @@ def add_nodes_by_schedule(
             # Add scheduling parameters for debugging.
             new_node.scheduling_parameters = node.scheduling_parameters
             # Update the rotating registers and argument context for the current node (if applicable).
+            old_node = None
             if node in rotating_registers:
                 rotating_registers[node].append(new_node.fx_node)
-                rotating_registers[node].popleft()
+                old_node = rotating_registers[node].popleft()
                 # If draining, then override the rotating registers and update the argument context.
                 if fill_or_drain:
                     for next_stage in range(stage + 1, len(stages)):
                         arg_context[(iteration, next_stage, node)] = new_node.fx_node
 
-            # Update the init args in the argument context whenever a result is computed.
+            # Update the iter and init args in the argument context whenever a result is computed.
             if node in arg_context.results:
+                iter_arg = arg_context.result_to_iter_arg[node]
+                logger.debug(
+                    f"Updating result: {node} -> {iter_arg} to {new_node.fx_node}."
+                )
+                arg_context.map_arg_all_after_iteration(
+                    iter_arg,
+                    new_node.fx_node,
+                    iteration,
+                )
                 if (
-                    pipelining_stage == PipelineStage.EPILOGUE
-                    or pipelining_stage == PipelineStage.KERNEL
+                    pipelining_stage == PipelineStage.KERNEL
+                    or pipelining_stage == PipelineStage.PROLOGUE
                ):
-                    logger.debug(
-                        f"Updating result: {node} -> {arg_context.result_to_iter_arg[node]} to {new_node.fx_node}."
-                    )
-                    arg_context.map_arg_all_after_iteration(
-                        arg_context.result_to_iter_arg[node],
-                        new_node.fx_node,
-                        iteration,
-                    )
+                    if iter_arg in rotating_registers and old_node:
+                        logger.debug(
+                            f"Updating rotating register iter arg {iter_arg} -> {old_node}."
+                        )
+                        rotating_registers[iter_arg].append(old_node)
+                        rotating_registers[iter_arg].popleft()
+                        for next_stage in range(stage + 1, len(stages)):
+                            arg_context[(iteration, next_stage, iter_arg)] = old_node
                 if pipelining_stage == PipelineStage.PROLOGUE:
                     logger.debug(
                         f"Updating result: {node} -> {arg_context.result_to_init_arg[node]} to {new_node.fx_node}."
@@ -140,14 +153,6 @@ def add_nodes_by_schedule(
                         new_node.fx_node,
                         iteration,
                     )
-                    logger.debug(
-                        f"Updating result: {node} -> {arg_context.result_to_iter_arg[node]} to {new_node.fx_node}."
-                    )
-                    arg_context.map_arg_all_after_iteration(
-                        arg_context.result_to_iter_arg[node],
-                        new_node.fx_node,
-                        iteration,
-                    )
 
         if pipelining_stage == PipelineStage.KERNEL and use_scheduling_barriers:
             SchedulingGroupBarrier(instructions, 0).add_to_graph(reduction_graph)
@@ -213,6 +218,7 @@ def construct_prologue(
     )
 
     # Map iter args to init args in the prologue.
+    original_init_args = list(reduction.init_args)
     for iter_arg, init_arg in zip(
         reduction.iter_args(reduction_subgraph), reduction.init_args
     ):
@@ -243,6 +249,14 @@ def construct_prologue(
         new_init_args.append(mapped_init_arg)
     reduction.init_args = new_init_args
 
+    # We may also have some iter_args as rotating registers. These will need
+    # to be initialized to the original init args which we do here.
+    iter_args = reduction.iter_args(reduction_subgraph)
+    for node, registers in rotating_registers.items():
+        if node in iter_args:
+            if all(x is None for x in registers) and len(registers) == 1:
+                registers[0] = original_init_args[iter_args.index(node)]
+
     # Add missing registers. Since registers are not present
     # in the scheduling code, we could end up with a situation where
     # we move mma ops outside the reduction that do not have a corresponding
@@ -305,26 +319,25 @@ def push_rotating_registers(
         custom = get_custom(node)
         stage = custom.scheduling_parameters["stage"]
         iteration = arg_context.get_kernel_iteration(stage)
-        arg_context[(iteration, stage, node)] = registers[-1]
+        if node not in arg_context.iter_args:
+            arg_context[(iteration, stage, node)] = registers[-1]
         for i, register in enumerate(registers):
             if create_new_nodes:
                 mapped_stage = stage + len(registers) - i
                 mapped_iteration = arg_context.get_kernel_iteration(mapped_stage)
                 iter_arg = IterArg(f"rotating_reg_{count}").add_to_graph(graph)
                 iter_arg.type = get_custom(node).type
                 iter_arg.index = get_custom(node).index
-                arg_context[(mapped_iteration, mapped_stage, node)] = iter_arg
                 new_registers.append(iter_arg)
-                logger.debug(
-                    f"Mapped orig: {node_map[node]} / mapped: {iter_arg} to stage {mapped_stage}."
-                )
+                mapped_value = iter_arg
             else:
                 mapped_stage = stage + len(registers) - i - 1
                 mapped_iteration = arg_context.get_kernel_iteration(mapped_stage)
-                arg_context[(mapped_iteration, mapped_stage, node)] = register
-                logger.debug(
-                    f"Mapped orig: {node_map[node]} / mapped: {register} to stage {mapped_stage} at iteration {mapped_iteration}."
-                )
+                mapped_value = register
+            arg_context[(mapped_iteration, mapped_stage, node)] = mapped_value
+            logger.debug(
+                f"Mapped orig: {node_map[node]} / mapped: {mapped_value} to stage {mapped_stage} at iteration {mapped_iteration}."
+            )
             count += 1
         if new_registers:
             new_rotating_registers[node] = new_registers
@@ -505,7 +518,8 @@ def construct_epilogue(
     for i in range(len(flattened_rotating_registers)):
         rotating_registers_get_results.append(
             GetResult(pipelined_reduction.fx_node, i + offset).add_to_graph(
-                pipelined_reduction.graph, type=flattened_rotating_registers[i].type
+                pipelined_reduction.graph,
+                type=flattened_rotating_registers[i].type,
             )
         )
     rotating_registers = unflatten_dict_values(
@@ -566,7 +580,7 @@ def construct_pipelined_loop(
     with a prologue, kernel and epilogue.
     """
     induction_variable = get_induction_variable(reduction, constraints)
-    num_rotating_registers = liveness_analysis(graph, constraints, scheduler)
+    num_rotating_registers = liveness_analysis(graph, reduction)
     rotating_registers: dict[fx.Node, deque[fx.Node]] = {
         k: deque([None for _ in range(v)]) for k, v in num_rotating_registers.items()
    }
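Note: the rotating registers threaded through these hunks are fixed-length
deques: every time a stage produces a new value, the newest fx node is pushed
on the right and the oldest slot is popped on the left, which is the old_node
the kernel hunk now captures and recycles for iter args. A toy sketch of that
rotation, with made-up string values in place of fx nodes:

from collections import deque

rotating = deque([None, None])  # two live slots, as liveness analysis might request

for iteration, new_value in enumerate(["v0", "v1", "v2"]):
    rotating.append(new_value)     # newest value enters on the right
    old_node = rotating.popleft()  # oldest value rotates out on the left
    print(f"iteration {iteration}: live={list(rotating)}, old_node={old_node}")

# iteration 0: live=[None, 'v0'], old_node=None
# iteration 1: live=['v0', 'v1'], old_node=None
# iteration 2: live=['v1', 'v2'], old_node='v0'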
17 changes: 9 additions & 8 deletions iree/turbine/kernel/wave/scheduling/loop_reconstruction_utils.py
@@ -158,7 +158,7 @@ def lookup(self, key: fx.Node) -> Optional[fx.Node]:
         """
         Looks up the argument mapping for the given node.
         """
-        for iteration in range(self.num_iterations):
+        for iteration in range(self.num_iterations - 1, -1, -1):
             for stage in range(self.num_stages):
                 if key in self.argument_map[iteration][stage]:
                     return self.argument_map[iteration][stage][key]
@@ -174,10 +174,15 @@ def contains_in_iteration(self, iteration: int, key: fx.Node) -> bool:
             for stage in range(self.num_stages)
         )
 
-    def get_from_iteration(self, iteration: int, key: fx.Node) -> fx.Node:
+    def get_from_iteration(self, iteration: int, key: fx.Node, stage: int) -> fx.Node:
         """
-        Gets the argument mapping for the given iteration.
+        Gets the argument mapping for the given iteration with
+        preference to the given stage.
         """
+
+        if stage and key in self.argument_map[iteration][stage]:
+            return self.argument_map[iteration][stage][key]
+
         for stage in range(self.num_stages):
             if key in self.argument_map[iteration][stage]:
                 return self.argument_map[iteration][stage][key]
@@ -226,9 +231,7 @@ def create_drain_stage_schedule(n: int) -> list[list[int]]:
     return schedule
 
 
-def liveness_analysis(
-    graph: fx.Graph, constraints: list[Constraint], scheduler: ModuloScheduler
-) -> dict[fx.Node, int]:
+def liveness_analysis(graph: fx.Graph, reduction: Reduction) -> dict[fx.Node, int]:
     """
     Perform liveness analysis on the graph to determine the live ranges of
     variables and use that to deduce how many rotating registers we need.
@@ -240,8 +243,6 @@ def liveness_analysis(
             continue
         if node not in lifetime:
            lifetime[node] = 0
-        if isinstance(custom, Placeholder):
-            continue
        for user in custom.users:
            if user.scheduling_parameters is None:
                continue
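Note: the new three-argument get_from_iteration prefers a mapping from the
caller's stage before falling back to a scan over all stages, and lookup now
scans iterations newest-first. A standalone sketch of the stage-preference
logic, with argument_map simplified to nested dicts of strings; note that the
truthiness check here, as in the commit, treats stage 0 the same as None:

from typing import Optional

def get_from_iteration(argument_map, num_stages: int, iteration: int,
                       key: str, stage: Optional[int]):
    # Preferred stage wins when it already maps the key.
    if stage and key in argument_map[iteration][stage]:
        return argument_map[iteration][stage][key]
    # Otherwise fall back to the first stage that maps the key.
    for stage in range(num_stages):
        if key in argument_map[iteration][stage]:
            return argument_map[iteration][stage][key]
    return None

# argument_map[iteration][stage] -> {key: value}
amap = [[{"x": "stage0_val"}, {"x": "stage1_val"}]]
assert get_from_iteration(amap, 2, 0, "x", 1) == "stage1_val"   # preferred stage
assert get_from_iteration(amap, 2, 0, "x", None) == "stage0_val"  # fallback scan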
6 changes: 5 additions & 1 deletion iree/turbine/kernel/wave/visualization.py
@@ -178,7 +178,11 @@ def visualize_mapped_graphs(
 
     # Draw edges between rotating registers for the same variable.
     for node in rotating_registers:
-        all_registers = [k for k, v in flat_inverse_map.items() if v == node]
+        all_registers = [
+            k
+            for k, v in flat_inverse_map.items()
+            if v == node and k in second_numbering
+        ]
         for second, first in zip(all_registers[:-1], all_registers[1:]):
             G.add_edge(
                 second_numbering[id(first)],
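Note: the visualization fix restricts rotating-register edges to registers that
were actually assigned a number in the second graph, so the later
second_numbering lookups cannot hit a missing key. A minimal sketch of the
guard with invented stand-in data (real keys are fx nodes):

flat_inverse_map = {"r0": "acc", "r1": "acc", "r2": "acc"}
second_numbering = {"r0": 0, "r1": 1}  # "r2" never appears in the second graph

node = "acc"
all_registers = [
    k
    for k, v in flat_inverse_map.items()
    if v == node and k in second_numbering
]
assert all_registers == ["r0", "r1"]  # "r2" is filtered out before edges are drawn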
12 changes: 7 additions & 5 deletions lit_tests/kernel/wave/attention.py
@@ -33,7 +33,7 @@
 
 @run_test
 def test_attention_pipelined():
-    shape = (8, 128, 128, 32, 256)
+    shape = (8, 128, 128, 32, 96)
     # Expose user-constraints
     constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
     constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
@@ -108,16 +108,16 @@ def repeat(
         LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant),
         STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant),
         BLOCK_B: 1,
-        BLOCK_M: 32,
-        BLOCK_N: 32,
+        BLOCK_M: 64,
+        BLOCK_N: 64,
         BLOCK_K2: 32,
         B: shape[0],
         M: shape[1],
         N: shape[2],
         K1: shape[3],
         K2: shape[4],
-        READ_SHARED_DELAY: 1,
-        WRITE_SHARED_DELAY: 1,
+        READ_SHARED_DELAY: 2,
+        WRITE_SHARED_DELAY: 2,
         READ_GLOBAL_DELAY: 2,
         WRITE_GLOBAL_DELAY: 2,
         MMA_DELAY: 1,
@@ -326,7 +326,9 @@
     ):
         imm_reg = tkl.Register[B, K2, M, tkl.f32](0.0)
         q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD)
+        # b_reg: tkw.Register[B, N, K, tkl.f16]
         k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD)
+        # acc: tkw.Register[B, N, M, tkl.f32]
         inner_acc = tkw.mma(k_reg, q_reg, imm_reg)
         x_j = tkw.permute(inner_acc, target_shape=[B, M, K2])
         m_j = tkw.max(x_j, partial_max, dim=K2)
