
Add support for scheduling in attention operators #253

Merged: 2 commits merged into iree-org:main from fa_sched on Nov 14, 2024

Conversation

@harsh-nod (Contributor) commented Nov 5, 2024

This PR adds support for scheduling in attention operators. In particular, the following changes are implemented:

  1. Parallel Floyd-Warshall for faster scheduling
  2. Support for iter_args in rotating registers
  3. Support for VALU and SHUFFLE delays and resources
  4. infer_type implementations for the remaining wave ops

@harsh-nod force-pushed the fa_sched branch 18 times, most recently from 51d25a6 to dd182ab on November 10, 2024 at 20:33

@martin-luecke (Contributor) left a comment

LGTM, just a few comments.

Resolved review threads on:
  iree/turbine/kernel/wave/scheduling/loop_reconstruction.py (outdated)
  iree/turbine/kernel/wave/scheduling/graph_utils.py (outdated)
  iree/turbine/kernel/ops/wave_ops.py
@@ -227,11 +249,11 @@ def liveness_analysis(
             logger.debug(
                 f"Node: {node}, User: {user.fx_node}, lifetime: {user.scheduling_parameters['stage'] - custom.scheduling_parameters['stage']}"
             )
-            lifetime[node] = max(
+            user_lifetime = (
Contributor:

Is there a purpose to this refactoring aside from styling?

Contributor Author:

I was doing something with it, but dropped that and decided to keep it like this for styling.

    Operation.WRITE_GLOBAL: np.array([[1, 0, 0]]),
    Operation.MMA: np.array([[0, 0, 1]]),
    Operation.NOOP: np.array([[0, 0, 0]]),
    Operation.READ_SHARED: np.array([[0, 1, 0, 0, 0]]),
Contributor:

Not related to this PR, but does this mean we cannot read and write at the same time?

Contributor Author:

This just specifies how many resources the instruction uses in one cycle. We can have multiple reads/writes by increasing the total number of resources available.
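
For illustration, a minimal sketch of how such per-cycle resource vectors can be checked during scheduling. The RESOURCE_USAGE table and total_resources values below are hypothetical, not the PR's actual numbers; the point is that raising an entry in the totals is what permits multiple reads or writes per cycle:

import numpy as np

# Hypothetical per-cycle usage, one column per resource kind
# (e.g. GLOBAL, SHARED, MMA, VALU, SHUFFLE after this PR).
RESOURCE_USAGE = {
    "WRITE_GLOBAL": np.array([[1, 0, 0, 0, 0]]),
    "READ_SHARED": np.array([[0, 1, 0, 0, 0]]),
    "MMA": np.array([[0, 0, 1, 0, 0]]),
}

# Raising an entry here is what allows multiple reads/writes per cycle.
total_resources = np.array([[2, 2, 1, 2, 2]])

def fits(cycle_usage, op):
    # An op fits in a cycle if adding its usage stays within the totals.
    return bool(np.all(cycle_usage + RESOURCE_USAGE[op] <= total_resources))

cycle = np.zeros_like(total_resources)
assert fits(cycle, "READ_SHARED")
cycle += RESOURCE_USAGE["READ_SHARED"]
assert fits(cycle, "READ_SHARED")  # a second read in the same cycle is fine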

            if (
                pipelining_stage == PipelineStage.KERNEL
                or pipelining_stage == PipelineStage.EPILOGUE
                or pipelining_stage == PipelineStage.PROLOGUE
Contributor:

Can you explain this change a little bit more?

Contributor Author:

Sure, will also add a comment:

# In situations where we have an iter_arg as a rotating register,
# we also have the output as a rotating register. So when we
# are updating the output, we update the iter_arg as well with the
# old value of the output rotating register. Consider the following
# example:
#
# Stage 0:
#   iter_arg0
#
#   output = compute(...) -> here we update iter_arg0 to have the
#   output value for the next stage, so that it gets picked up in
#   stage 1.
#
# Stage 1:
#   b = use(iter_arg0)
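
A minimal sketch of that update rule (hypothetical names, not the PR's actual data structures): when the output rotating register is written, the paired iter_arg takes the value the output held before the update, so the next stage consumes the right generation.

def update_rotating_pair(rotating_registers, output, iter_arg, new_value):
    # The iter_arg inherits the output's pre-update value, which is
    # what the next stage expects to read.
    old_output = rotating_registers[output]
    rotating_registers[output] = new_value
    rotating_registers[iter_arg] = old_output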

            new_node.index[dim] = new_node.index[dim].subs(
                {induction_variable: current_induction_variables[iteration]}
            )
        if new_node.index:
Contributor:

Are these for the new ops, such as Register, which don't have an index?

Contributor Author:

I hit this for the ExtractOps that don't have an index.

            new_node.index[dim] = new_node.index[dim].subs(
                {induction_variable: current_induction_variables[iteration]}
            )
        if custom_node.expanded_dims:
Contributor:

Can you explain why this is needed now?

Contributor Author:

The expanded dims are needed for the reshape ops during codegen.

@Hardcode84 (Contributor) left a comment

Overall LGTM, just a few comments on parallelism and subs.

D[i, j] = edge.weight.delay - edge.weight.iteration_difference * T

# Parallel implementation
pool = mp.get_context("fork").Pool(processes=mp.cpu_count())
Contributor:

Pool creation/destruction can be expensive, as it starts new processes and then waits for them to die in close/join. We should probably have a shared global pool, created on demand.

Contributor Author:

Yes, that makes sense. Right now we instantiate this pool multiple times within the modulo scheduling loop. Instead, we could do this just once in the constructor and tear it down in the destructor. How does that sound?
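
A sketch of the on-demand shared pool under discussion (names and placement are hypothetical; the PR would presumably hang this off the scheduler's constructor/destructor):

import multiprocessing as mp

_POOL = None

def get_shared_pool():
    # Create the worker pool on first use and reuse it afterwards,
    # avoiding repeated process spawn/teardown inside the scheduling loop.
    global _POOL
    if _POOL is None:
        _POOL = mp.get_context("fork").Pool(processes=mp.cpu_count())
    return _POOL

def shutdown_shared_pool():
    # Tear the pool down exactly once, e.g. from the destructor.
    global _POOL
    if _POOL is not None:
        _POOL.close()
        _POOL.join()
        _POOL = None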

    pool = mp.get_context("fork").Pool(processes=mp.cpu_count())
    for k in range(N):
        func = partial(all_pairs_longest_path_parallel, N, D, k)
        results = pool.map(func, range(N))
Contributor:

We can potentially parallelize this even more by having two loops: the first calling pool.map_async and the second aggregating the results.

Contributor Author:

Nice, that makes sense. Thanks!

Contributor Author:

So I tried this, and it turns out we can't: each iteration depends on the results of the previous one, so the iterations cannot all use the same value of the D matrix. After every iteration we need to update D and then use the updated matrix in the next iteration.

@@ -190,7 +257,8 @@ def evaluate_all_pairs_longest_paths(
     """
     D_static = dict(D)
     for key in D_static:
-        D_static[key] = D_static[key].subs(T, initiation_interval)
+        if isinstance(D_static[key], sympy.Expr):
Contributor:

We have utils.safe_subs.

Contributor Author:

Sure, will change.
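
For reference, a sketch of what the guarded substitution might collapse to with such a helper (assuming utils.safe_subs substitutes only into sympy expressions and passes plain values through; its actual signature may differ):

import sympy

def safe_subs(value, *args):
    # Hypothetical stand-in for utils.safe_subs: substitute into sympy
    # expressions, return non-symbolic values unchanged.
    return value.subs(*args) if isinstance(value, sympy.Expr) else value

# The loop above then becomes:
# for key in D_static:
#     D_static[key] = safe_subs(D_static[key], T, initiation_interval)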

@harsh-nod merged commit 79ec575 into iree-org:main on Nov 14, 2024 (7 checks passed)