diff --git a/sharktank/sharktank/layers/sharding.py b/sharktank/sharktank/layers/sharding.py
deleted file mode 100644
index 0e2dfe012..000000000
--- a/sharktank/sharktank/layers/sharding.py
+++ /dev/null
@@ -1,38 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-"""Specifications describing how layers are sharded."""
-
-from ..types.sharding import *
-
-
-class PagedLlamaAttentionBlockSharding(ThetaLayerSharding):
-    def __init__(self, shard_count: int):
-        super().__init__()
-        self.shard_count = shard_count
-
-    def theta_sharding(self) -> ThetaSharding:
-        return ThetaSharding(
-            {
-                # The size of this is the token embedding length, which is not a memory
-                # space concern if replicated even for all attention blocks.
-                "attn_norm": RmsNormReplicatedSharding(
-                    self.shard_count
-                ).theta_sharding(),
-                "attn_q": LinearSplitParallelWeightAndBiasSharding(
-                    shard_count=self.shard_count
-                ).theta_sharding(),
-                "attn_k": LinearSplitParallelWeightAndBiasSharding(
-                    shard_count=self.shard_count
-                ).theta_sharding(),
-                "attn_v": LinearSplitParallelWeightAndBiasSharding(
-                    shard_count=self.shard_count
-                ).theta_sharding(),
-                "attn_output": LinearSplitReductionDimSharding(
-                    shard_count=self.shard_count
-                ).theta_sharding(),
-            }
-        )
diff --git a/sharktank/sharktank/models/llama/sharding.py b/sharktank/sharktank/models/llama/sharding.py
index a3559422c..3715a3923 100644
--- a/sharktank/sharktank/models/llama/sharding.py
+++ b/sharktank/sharktank/models/llama/sharding.py
@@ -9,7 +9,35 @@
 from ...types.sharding import *
 from ...types import Theta
 from ... import ops
-from ...layers.sharding import PagedLlamaAttentionBlockSharding
+
+
+class PagedLlamaAttentionBlockSharding(ThetaLayerSharding):
+    def __init__(self, shard_count: int):
+        super().__init__()
+        self.shard_count = shard_count
+
+    def theta_sharding(self) -> ThetaSharding:
+        return ThetaSharding(
+            {
+                # The size of this is the token embedding length, which is not a memory
+                # space concern if replicated even for all attention blocks.
+                "attn_norm": RmsNormReplicatedSharding(
+                    self.shard_count
+                ).theta_sharding(),
+                "attn_q": LinearSplitParallelWeightAndBiasSharding(
+                    shard_count=self.shard_count
+                ).theta_sharding(),
+                "attn_k": LinearSplitParallelWeightAndBiasSharding(
+                    shard_count=self.shard_count
+                ).theta_sharding(),
+                "attn_v": LinearSplitParallelWeightAndBiasSharding(
+                    shard_count=self.shard_count
+                ).theta_sharding(),
+                "attn_output": LinearSplitReductionDimSharding(
+                    shard_count=self.shard_count
+                ).theta_sharding(),
+            }
+        )
 
 
 class AttentionFFNBlockSharding(ThetaLayerSharding):
diff --git a/sharktank/tests/layers/sharded_paged_llama_attention_block.py b/sharktank/tests/layers/sharded_paged_llama_attention_block.py
index 7f9df87bf..c94fd44ab 100644
--- a/sharktank/tests/layers/sharded_paged_llama_attention_block.py
+++ b/sharktank/tests/layers/sharded_paged_llama_attention_block.py
@@ -11,7 +11,7 @@
     RotaryEmbeddingLayer,
 )
 from sharktank.layers.testing import make_llama_attention_block_theta, make_rand_torch
-from sharktank.layers.sharding import PagedLlamaAttentionBlockSharding
+from sharktank.models.llama.sharding import PagedLlamaAttentionBlockSharding
 from sharktank.types import SplitPrimitiveTensor, unbox_tensor
 import torch
 from sharktank import ops
@@ -84,6 +84,23 @@ def make_unsharded_and_sharded_equal_cache_states() -> tuple[
             sharded_cache_state,
         ) = make_unsharded_and_sharded_equal_cache_states()
 
+        input_tensor = make_rand_torch(
+            (
+                self.batch_size,
+                self.max_seqlen,
+                self.attention_head_count * self.attention_head_dim,
+            ),
+            dtype=dtype,
+        )
+        seq_block_ids = torch.arange(self.batch_size * self.block_seqlen).view(
+            self.batch_size, -1
+        )
+        embedding_module = RotaryEmbeddingLayer(
+            rope_dimension_count=self.rope_dimension_count,
+            max_seqlen=self.max_seqlen,
+            rope_freq_base=self.rope_freq_base,
+        )
+
         theta = make_llama_attention_block_theta(
             head_count=self.attention_head_count,
             head_count_kv=self.head_count_kv,
@@ -99,23 +116,6 @@ def make_unsharded_and_sharded_equal_cache_states() -> tuple[
             head_count_kv=self.head_count_kv,
             rms_epsilon=self.rms_epsilon,
         )
-        embedding_module = RotaryEmbeddingLayer(
-            rope_dimension_count=self.rope_dimension_count,
-            max_seqlen=self.max_seqlen,
-            rope_freq_base=self.rope_freq_base,
-        )
-
-        input_tensor = make_rand_torch(
-            (
-                self.batch_size,
-                self.max_seqlen,
-                self.attention_head_count * self.attention_head_dim,
-            ),
-            dtype=dtype,
-        )
-        seq_block_ids = torch.arange(self.batch_size * self.block_seqlen).view(
-            self.batch_size, -1
-        )
         expected_result = attention_block(
             input_tensor,
             embedding=embedding_module,
@@ -124,6 +124,15 @@ def make_unsharded_and_sharded_equal_cache_states() -> tuple[
             cache_state=cache_state,
         )
 
+        sharded_input_tensor = ops.replicate(input_tensor, count=self.shard_count)
+        sharded_seq_block_ids = ops.replicate(seq_block_ids, count=self.shard_count)
+        sharded_embedding_module = RotaryEmbeddingLayer(
+            rope_dimension_count=self.rope_dimension_count,
+            max_seqlen=self.max_seqlen,
+            rope_freq_base=self.rope_freq_base,
+            tensor_parallelism_size=self.shard_count,
+        )
+
         theta_sharding = PagedLlamaAttentionBlockSharding(shard_count=self.shard_count)
         sharded_theta = ops.reshard(theta, theta_sharding)
         sharded_attention_block = PagedLlamaAttentionBlock(
@@ -135,14 +144,6 @@ def make_unsharded_and_sharded_equal_cache_states() -> tuple[
             head_count_kv=self.head_count_kv,
             rms_epsilon=self.rms_epsilon,
         )
-        sharded_embedding_module = RotaryEmbeddingLayer(
-            rope_dimension_count=self.rope_dimension_count,
-            max_seqlen=self.max_seqlen,
-            rope_freq_base=self.rope_freq_base,
-            tensor_parallelism_size=self.shard_count,
-        )
-        sharded_input_tensor = ops.replicate(input_tensor, count=self.shard_count)
-        sharded_seq_block_ids = ops.replicate(seq_block_ids, count=self.shard_count)
         sharded_result = sharded_attention_block(
             sharded_input_tensor,
             embedding=sharded_embedding_module,
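
Note (not part of the diff): after this change the sharding spec is imported from sharktank.models.llama.sharding instead of sharktank.layers.sharding. A minimal usage sketch follows, mirroring the updated test; shard_count=2 is an illustrative value, and the ops.reshard call shown in the trailing comment is taken from the test hunk above rather than verified here.

from sharktank.models.llama.sharding import PagedLlamaAttentionBlockSharding

# Sharding spec for one paged Llama attention block across two devices
# (illustrative count). Per the class above, attn_norm is replicated,
# attn_q/attn_k/attn_v are split along the parallel weight dimension, and
# attn_output is split along the reduction dimension.
spec = PagedLlamaAttentionBlockSharding(shard_count=2)
theta_sharding = spec.theta_sharding()

# An unsharded attention-block theta can then be resharded against the spec,
# as the updated test does:
#     sharded_theta = ops.reshard(theta, spec)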