-
Notifications
You must be signed in to change notification settings - Fork 25
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add tensor parallelism to the paged llama model #185
Conversation
971c5c9
to
f2e9e81
Compare
9d931c3
to
0f82101
Compare
d583e55
to
e896d79
Compare
@@ -123,3 +166,49 @@ def theta_sharding(self) -> ThetaSharding: | |||
), | |||
} | |||
) | |||
|
|||
|
|||
class LinearSplitReductionDimSharding(ThetaLayerSharding): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we merge this with LinearSplitParallelWeightAndBiasSharding? Lots of this is repeated there with just minor bias and shard_dim differences
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They shard different dimensions. This one shards the reduction dimension and the other shards the parallel dimension. Here also the bias is replicated and in the other case it is split.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, but seems like LinearSplitReductionDimSharding could be replaced with LinearSplitParallelWeightAndBiasSharding by setting self.weight_and_bias_spit_dim = 1 and adding another arg to set whether the bias is split or replicated instead of creating a new function
expected_cache_state = cache_state[0] | ||
actual_cache_state = ops.unshard(sharded_cache_state[0]) | ||
# TODO: debug this numerical issue | ||
# torch.testing.assert_close(actual_cache_state, expected_cache_state) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of commenting out to force passing can you xfail instead with comment of numerical issue?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
fdb35ab
to
cfa705a
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you remove the WIP in the title before we merge? Otherwise just a nit comment but looks good to me overall
@@ -123,3 +166,49 @@ def theta_sharding(self) -> ThetaSharding: | |||
), | |||
} | |||
) | |||
|
|||
|
|||
class LinearSplitReductionDimSharding(ThetaLayerSharding): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, but seems like LinearSplitReductionDimSharding could be replaced with LinearSplitParallelWeightAndBiasSharding by setting self.weight_and_bias_spit_dim = 1 and adding another arg to set whether the bias is split or replicated instead of creating a new function
This adds one test that checks the sharded vs the unsharded veriants. Make `sharktank.examples.paged_llm_v1` support a tensor parallelism CLI option. This change adds a lot of sharded variants for PyTorch API-equivalent ops but some of them lack auto-testing. index_copy_, index_put_, slicing, flatten, unflatten and reshape have tests. Check that replication and splitting of un unsharded tensor is not an actual copy. It is probably unintuitive that when ran through PyTorch the sharded result shares the same memory. It may be better to change the semantics and require that it is actually a copy. During exporting this would insert copies that the compiler would need to optimize out. Add test for sharded paged KV cache.
df0bd7e
to
1063c60
Compare
No description provided.