Add support for scheduling in attention operators #253
Conversation
51d25a6 to dd182ab (Compare)
This PR adds support for scheduling in attention operators. In particular, the following changes are implemented:
1. Parallel Floyd-Warshall allows for faster scheduling
2. Support for iter_args in rotating registers
3. Support for VALU and SHUFFLE delays and resources
4. Add infer_type for remaining wave ops

Signed-off-by: Harsh Menon <[email protected]>
LGTM, just a few comments
@@ -227,11 +249,11 @@ def liveness_analysis(
    logger.debug(
        f"Node: {node}, User: {user.fx_node}, lifetime: {user.scheduling_parameters['stage'] - custom.scheduling_parameters['stage']}"
    )
    lifetime[node] = max(
    user_lifetime = (
is there a purpose to this refactoring aside from styling?
I was doing something with it, but dropped that and decided to keep it like this for styling.
Operation.WRITE_GLOBAL: np.array([[1, 0, 0]]),
Operation.MMA: np.array([[0, 0, 1]]),
Operation.NOOP: np.array([[0, 0, 0]]),
Operation.READ_SHARED: np.array([[0, 1, 0, 0, 0]]),
Not related to this PR, but does this mean we cannot read and write at the same time?
This just specifies how many resources the instruction uses in one cycle. We can have multiple reads/writes by increasing the total number of resources available.
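For illustration, here is a minimal sketch (not the PR's actual code; the resource categories and totals are assumed) of how such per-cycle resource vectors are typically checked in resource-constrained modulo scheduling:

```python
import numpy as np

# Assumed resource categories: [global_mem, shared_mem, mma, valu, shuffle].
TOTAL_RESOURCES = np.array([2, 2, 1, 2, 1])

# Each op reserves some units for one cycle (rows = cycles after issue).
OP_USAGE = {
    "READ_SHARED": np.array([[0, 1, 0, 0, 0]]),
    "MMA": np.array([[0, 0, 1, 0, 0]]),
}

def fits(table: np.ndarray, op: str, cycle: int) -> bool:
    """True if issuing `op` at `cycle` keeps every resource within budget."""
    usage = OP_USAGE[op]
    window = table[cycle : cycle + usage.shape[0]]
    return bool(np.all(window + usage <= TOTAL_RESOURCES))

# With a shared-memory budget of 2, two READ_SHARED ops fit in one cycle,
# but a third would exceed the budget and be rejected.
table = np.zeros((4, 5), dtype=int)
for expected in (True, True, False):
    assert fits(table, "READ_SHARED", 0) == expected
    table[0] += OP_USAGE["READ_SHARED"][0]
```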
if (
    pipelining_stage == PipelineStage.KERNEL
    or pipelining_stage == PipelineStage.EPILOGUE
    or pipelining_stage == PipelineStage.PROLOGUE
Can you explain this change a little bit more?
Sure, will also add a comment.
# In situations where we have an iter_arg as a rotating register,
# we also have the output as a rotating register. So when we
# update the output, we also update the iter_arg with the old
# value of the output rotating register. Consider this example:
#
# Stage 0:
#     iter_arg0
#
#     output = compute(...) -> here we update iter_arg0 to have the output value
#                              for the next stage, so that it gets picked up in
#                              stage 1.
#
# Stage 1:
#     b = use(iter_arg0)
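A small, hypothetical sketch of the update order described in that comment (the real pass operates on fx graph nodes, not a plain dict):

```python
def rotate_output(registers: dict, output: str, iter_arg: str, new_value):
    # The iter_arg picks up the old value of the output rotating register ...
    registers[iter_arg] = registers[output]
    # ... before the output register is overwritten with the newly computed
    # value, so the next stage reads the value from the previous iteration.
    registers[output] = new_value
```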
new_node.index[dim] = new_node.index[dim].subs(
    {induction_variable: current_induction_variables[iteration]}
)
if new_node.index:
Are these for the new ops, such as Register, which don't have an index?
I hit this for the ExtractOps that don't have an index.
new_node.index[dim] = new_node.index[dim].subs(
    {induction_variable: current_induction_variables[iteration]}
)
if custom_node.expanded_dims:
Can you explain why this is needed now?
The expanded dims are needed for the reshape ops during codegen.
Signed-off-by: Harsh Menon <[email protected]>
Overall LGTM, just a few comments on parallelism and subs
D[i, j] = edge.weight.delay - edge.weight.iteration_difference * T

# Parallel implementation
pool = mp.get_context("fork").Pool(processes=mp.cpu_count())
Pool creation/destruction can be expensive (it starts new processes and then waits for them to die in close/join). We should probably have a shared global pool (created on demand).
Yes that makes sense. Right now, we instantiate this pool multiple times within the modulo scheduling loop. Instead, we could do this just once in the constructor and tear it down in the destructor. How does that sound?
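A minimal sketch of the suggested lazily created shared pool (the names here are illustrative, not the PR's actual implementation):

```python
import multiprocessing as mp

_POOL = None

def get_pool():
    """Create the fork pool on first use and reuse it across scheduling runs."""
    global _POOL
    if _POOL is None:
        _POOL = mp.get_context("fork").Pool(processes=mp.cpu_count())
    return _POOL

def shutdown_pool():
    """Tear the pool down explicitly, e.g. when the scheduler is destroyed."""
    global _POOL
    if _POOL is not None:
        _POOL.close()
        _POOL.join()
        _POOL = None
```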
pool = mp.get_context("fork").Pool(processes=mp.cpu_count())
for k in range(N):
    func = partial(all_pairs_longest_path_parallel, N, D, k)
    results = pool.map(func, range(N))
We can potentially parallelize it even more by having two loops: the first calling pool.map_async and the second aggregating the results from them.
Nice, that makes sense. Thanks!
So I tried this, and it doesn't work because each iteration depends on the results of the previous one. The iterations of the loop cannot all use the same value of the D matrix; after every iteration we need to update the D matrix and then use the updated matrix in the next iteration.
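To make the dependency explicit, here is a plain NumPy sketch (not the PR's exact code) of the all-pairs longest-path recurrence: the outer k loop must stay sequential because step k reads the D matrix produced by step k-1, while the per-row work inside a single k step is what pool.map distributes.

```python
import numpy as np

def relax_row(D: np.ndarray, k: int, i: int) -> np.ndarray:
    # Longest path i -> j that may pass through intermediate nodes {0..k}.
    return np.maximum(D[i], D[i, k] + D[k])

def all_pairs_longest_paths(D: np.ndarray) -> np.ndarray:
    N = D.shape[0]
    for k in range(N):                                 # sequential by necessity
        rows = [relax_row(D, k, i) for i in range(N)]  # this part maps onto pool.map
        D = np.vstack(rows)                            # publish updated D before the next k
    return D
```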
@@ -190,7 +257,8 @@ def evaluate_all_pairs_longest_paths(
    """
    D_static = dict(D)
    for key in D_static:
        D_static[key] = D_static[key].subs(T, initiation_interval)
        if isinstance(D_static[key], sympy.Expr):
We have utils.safe_subs
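For reference, a hedged sketch of what such a helper typically looks like; the actual utils.safe_subs in the repository may differ in signature:

```python
import sympy

def safe_subs(value, old, new):
    """Substitute only when `value` is a sympy expression; pass plain numbers through."""
    if isinstance(value, sympy.Expr):
        return value.subs(old, new)
    return value

# Usage, replacing the explicit isinstance check in the loop:
# D_static[key] = safe_subs(D_static[key], T, initiation_interval)
```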
Sure, will change.