Add tensor parallelism to the paged llama model #185

Merged: 4 commits merged into nod-ai:main from llama-sharding on Sep 25, 2024

Conversation

sogartar (Contributor): No description provided.

@sogartar force-pushed the llama-sharding branch 10 times, most recently from 971c5c9 to f2e9e81, on September 20, 2024.
@sogartar force-pushed the llama-sharding branch 7 times, most recently from 9d931c3 to 0f82101, on September 23, 2024.
@sogartar marked this pull request as ready for review on September 23, 2024.
Review thread on sharktank/sharktank/models/llama/sharding.py (resolved)
Review thread on sharktank/sharktank/types/sharding.py (outdated, resolved)
@@ -123,3 +166,49 @@ def theta_sharding(self) -> ThetaSharding:
),
}
)


class LinearSplitReductionDimSharding(ThetaLayerSharding):
Contributor:
Can we merge this with LinearSplitParallelWeightAndBiasSharding? A lot of this is repeated there, with only minor bias and shard_dim differences.

sogartar (Contributor, Author):
They shard different dimensions: this one shards the reduction dimension, while the other shards the parallel dimension. Here the bias is also replicated, whereas in the other case it is split.

Contributor:
Right, but it seems like LinearSplitReductionDimSharding could be replaced with LinearSplitParallelWeightAndBiasSharding by setting self.weight_and_bias_spit_dim = 1 and adding another argument to control whether the bias is split or replicated, instead of creating a new class.
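
For illustration, here is a rough, self-contained sketch of what such a merge might look like. It does not reproduce the real ThetaLayerSharding base class or the sharktank sharding spec types; the Split/Replicated placeholders, the attribute names, and the helper functions are hypothetical stand-ins, intended only to show how a single weight split dim plus a bias policy could cover both cases (parallel dim 0 with split bias, reduction dim 1 with replicated bias).

```python
from dataclasses import dataclass
from typing import Dict, Union


# Hypothetical, simplified stand-ins for the real sharding spec types.
@dataclass
class Split:
    shard_count: int
    shard_dim: int


@dataclass
class Replicated:
    shard_count: int


@dataclass
class LinearLayerSharding:
    """Single class covering both the parallel-dim and reduction-dim cases.

    weight_split_dim=0 splits the parallel (output) dimension, so the bias can
    be split as well; weight_split_dim=1 splits the reduction (input)
    dimension, so the bias must be replicated.
    """

    shard_count: int
    weight_split_dim: int = 0
    bias_split: bool = True

    def theta_sharding(self) -> Dict[str, Union[Split, Replicated]]:
        return {
            "weight": Split(self.shard_count, shard_dim=self.weight_split_dim),
            "bias": (
                Split(self.shard_count, shard_dim=0)
                if self.bias_split
                else Replicated(self.shard_count)
            ),
        }


# The two existing classes would then reduce to thin wrappers:
def linear_split_parallel_weight_and_bias(shard_count: int) -> LinearLayerSharding:
    return LinearLayerSharding(shard_count, weight_split_dim=0, bias_split=True)


def linear_split_reduction_dim(shard_count: int) -> LinearLayerSharding:
    return LinearLayerSharding(shard_count, weight_split_dim=1, bias_split=False)
```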

expected_cache_state = cache_state[0]
actual_cache_state = ops.unshard(sharded_cache_state[0])
# TODO: debug this numerical issue
# torch.testing.assert_close(actual_cache_state, expected_cache_state)
Contributor:
Instead of commenting it out to force the test to pass, can you xfail it instead, with a comment about the numerical issue?

sogartar (Contributor, Author):
Done.
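
For reference, a minimal sketch of the xfail approach the reviewer is asking for. The test name and the placeholder tensors are not from the actual test suite; only the pytest.mark.xfail marker and the torch.testing.assert_close pattern are the point.

```python
import pytest
import torch


@pytest.mark.xfail(
    reason="Sharded KV cache numerically diverges from the unsharded cache; "
    "see TODO about the numerical issue.",
    strict=False,
)
def test_sharded_paged_kv_cache_matches_unsharded():
    # Placeholder tensors standing in for cache_state[0] and
    # ops.unshard(sharded_cache_state[0]) in the real test.
    expected_cache_state = torch.randn(2, 8, 16)
    actual_cache_state = expected_cache_state + 1e-3  # simulated mismatch

    torch.testing.assert_close(actual_cache_state, expected_cache_state)
```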

IanNod (Contributor) left a review:
Can you remove the WIP in the title before we merge? Otherwise just a nit comment, but it looks good to me overall.


Commits:

This adds one test that checks the sharded variant against the unsharded variant.

Make `sharktank.examples.paged_llm_v1` support a tensor parallelism
CLI option.
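
As a rough illustration of what such a CLI option typically looks like (the actual flag name and wiring in `paged_llm_v1` may differ; `--tensor-parallelism-size` and the argument handling here are assumptions, not the module's real interface):

```python
import argparse


def parse_args() -> argparse.Namespace:
    # Hypothetical sketch of a tensor parallelism CLI option.
    parser = argparse.ArgumentParser(description="Paged LLaMA example")
    parser.add_argument(
        "--tensor-parallelism-size",
        type=int,
        default=1,
        help="Number of devices to shard the model weights across. "
        "1 disables tensor parallelism.",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    if args.tensor_parallelism_size > 1:
        # The model theta would be sharded across this many devices here.
        print(f"Sharding across {args.tensor_parallelism_size} devices")
```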

This change adds many sharded variants of PyTorch API-equivalent
ops, but some of them lack automated tests.
index_copy_, index_put_, slicing, flatten, unflatten, and reshape have tests.

Check that replication and splitting of an unsharded tensor is not an
actual copy. It is probably unintuitive that, when run through PyTorch,
the sharded result shares the same memory.
It may be better to change the semantics and require an actual copy.
During export this would insert copies that the compiler
would need to optimize out.
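
A minimal PyTorch-only sketch of the storage-sharing check described above. The replication and splitting stand-ins below are hypothetical (plain tensor references and `Tensor.split` views rather than the sharktank sharding ops); the data-pointer comparison is the point.

```python
import torch


def shares_storage(a: torch.Tensor, b: torch.Tensor) -> bool:
    # Two tensors share memory when their untyped storages point to the same data.
    return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()


unsharded = torch.arange(16, dtype=torch.float32).reshape(4, 4)

# Hypothetical stand-ins: replication keeps references to the same tensor,
# splitting produces views into the same storage.
replicated_shards = [unsharded for _ in range(2)]
split_shards = list(unsharded.split(2, dim=0))

assert all(shares_storage(s, unsharded) for s in replicated_shards)
assert all(shares_storage(s, unsharded) for s in split_shards)
```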

Add test for sharded paged KV cache.
@sogartar changed the title from "[WIP] add tensor parallelism to the paged llama model" to "Add tensor parallelism to the paged llama model" on Sep 25, 2024.
@sogartar merged commit a9d3d41 into nod-ai:main on Sep 25, 2024. 7 checks passed.
Labels: none yet
Projects: none yet
3 participants