Commit

Address Rob's PR comments
sogartar committed Oct 14, 2024
1 parent 7fac89d commit d6f1cfd
Showing 3 changed files with 56 additions and 65 deletions.
38 changes: 0 additions & 38 deletions sharktank/sharktank/layers/sharding.py

This file was deleted.

30 changes: 29 additions & 1 deletion sharktank/sharktank/models/llama/sharding.py
@@ -9,7 +9,35 @@
from ...types.sharding import *
from ...types import Theta
from ... import ops
from ...layers.sharding import PagedLlamaAttentionBlockSharding


class PagedLlamaAttentionBlockSharding(ThetaLayerSharding):
def __init__(self, shard_count: int):
super().__init__()
self.shard_count = shard_count

def theta_sharding(self) -> ThetaSharding:
return ThetaSharding(
{
# The size of this is the token embedding length, which is not a memory
# space concern if replicated even for all attention blocks.
"attn_norm": RmsNormReplicatedSharding(
self.shard_count
).theta_sharding(),
"attn_q": LinearSplitParallelWeightAndBiasSharding(
shard_count=self.shard_count
).theta_sharding(),
"attn_k": LinearSplitParallelWeightAndBiasSharding(
shard_count=self.shard_count
).theta_sharding(),
"attn_v": LinearSplitParallelWeightAndBiasSharding(
shard_count=self.shard_count
).theta_sharding(),
"attn_output": LinearSplitReductionDimSharding(
shard_count=self.shard_count
).theta_sharding(),
}
)


class AttentionFFNBlockSharding(ThetaLayerSharding):
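
For context, a minimal usage sketch of the relocated PagedLlamaAttentionBlockSharding, mirroring how the test below applies it. The `theta` value and the shard count here are placeholders, not part of this commit:

    from sharktank.models.llama.sharding import PagedLlamaAttentionBlockSharding
    from sharktank import ops

    # `theta` is assumed to hold the unsharded attention-block weights,
    # e.g. built via make_llama_attention_block_theta(...) as in the test below.
    theta = ...
    # Build the per-tensor sharding spec and reshard the weights accordingly.
    theta_sharding = PagedLlamaAttentionBlockSharding(shard_count=2)
    sharded_theta = ops.reshard(theta, theta_sharding)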
53 changes: 27 additions & 26 deletions sharktank/tests/layers/sharded_paged_llama_attention_block.py
@@ -11,7 +11,7 @@
RotaryEmbeddingLayer,
)
from sharktank.layers.testing import make_llama_attention_block_theta, make_rand_torch
from sharktank.layers.sharding import PagedLlamaAttentionBlockSharding
from sharktank.models.llama.sharding import PagedLlamaAttentionBlockSharding
from sharktank.types import SplitPrimitiveTensor, unbox_tensor
import torch
from sharktank import ops
@@ -84,6 +84,23 @@ def make_unsharded_and_sharded_equal_cache_states() -> tuple[
sharded_cache_state,
) = make_unsharded_and_sharded_equal_cache_states()

input_tensor = make_rand_torch(
(
self.batch_size,
self.max_seqlen,
self.attention_head_count * self.attention_head_dim,
),
dtype=dtype,
)
seq_block_ids = torch.arange(self.batch_size * self.block_seqlen).view(
self.batch_size, -1
)
embedding_module = RotaryEmbeddingLayer(
rope_dimension_count=self.rope_dimension_count,
max_seqlen=self.max_seqlen,
rope_freq_base=self.rope_freq_base,
)

theta = make_llama_attention_block_theta(
head_count=self.attention_head_count,
head_count_kv=self.head_count_kv,
@@ -99,23 +116,6 @@ def make_unsharded_and_sharded_equal_cache_states() -> tuple[
head_count_kv=self.head_count_kv,
rms_epsilon=self.rms_epsilon,
)
embedding_module = RotaryEmbeddingLayer(
rope_dimension_count=self.rope_dimension_count,
max_seqlen=self.max_seqlen,
rope_freq_base=self.rope_freq_base,
)

input_tensor = make_rand_torch(
(
self.batch_size,
self.max_seqlen,
self.attention_head_count * self.attention_head_dim,
),
dtype=dtype,
)
seq_block_ids = torch.arange(self.batch_size * self.block_seqlen).view(
self.batch_size, -1
)
expected_result = attention_block(
input_tensor,
embedding=embedding_module,
@@ -124,6 +124,15 @@ def make_unsharded_and_sharded_equal_cache_states() -> tuple[
cache_state=cache_state,
)

sharded_input_tensor = ops.replicate(input_tensor, count=self.shard_count)
sharded_seq_block_ids = ops.replicate(seq_block_ids, count=self.shard_count)
sharded_embedding_module = RotaryEmbeddingLayer(
rope_dimension_count=self.rope_dimension_count,
max_seqlen=self.max_seqlen,
rope_freq_base=self.rope_freq_base,
tensor_parallelism_size=self.shard_count,
)

theta_sharding = PagedLlamaAttentionBlockSharding(shard_count=self.shard_count)
sharded_theta = ops.reshard(theta, theta_sharding)
sharded_attention_block = PagedLlamaAttentionBlock(
@@ -135,14 +144,6 @@ def make_unsharded_and_sharded_equal_cache_states() -> tuple[
head_count_kv=self.head_count_kv,
rms_epsilon=self.rms_epsilon,
)
sharded_embedding_module = RotaryEmbeddingLayer(
rope_dimension_count=self.rope_dimension_count,
max_seqlen=self.max_seqlen,
rope_freq_base=self.rope_freq_base,
tensor_parallelism_size=self.shard_count,
)
sharded_input_tensor = ops.replicate(input_tensor, count=self.shard_count)
sharded_seq_block_ids = ops.replicate(seq_block_ids, count=self.shard_count)
sharded_result = sharded_attention_block(
sharded_input_tensor,
embedding=sharded_embedding_module,
