diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py
index 6ac91c704..9c8057bf3 100644
--- a/sharktank/sharktank/models/llama/llama.py
+++ b/sharktank/sharktank/models/llama/llama.py
@@ -204,6 +204,7 @@ def prefill(
         self._assert_device(attention_mask, dtype=self.activation_dtype)
         self._assert_device(seq_block_ids)
         self._assert_device(*cache_state, dtype=self.activation_dtype)
+
         if self.config.tensor_parallelism_size > 1:
             if not isinstance(tokens, ReplicatedTensor):
                 tokens = ops.replicate(
@@ -217,6 +218,12 @@ def prefill(
                 seq_block_ids = ops.replicate(
                     seq_block_ids, count=self.config.tensor_parallelism_size
                 )
+            # If the user provided unsharded arguments they probably want
+            # an unsharded result as well.
+            unshard_result = True
+        else:
+            unshard_result = False
+
         h = self.token_embedding(tokens)
         self.trace_tensor("llama.token_embedding", h)
 
@@ -236,6 +243,8 @@ def prefill(
 
         h = self.output_norm(h)
         logits = self.output_lm_head(h)
+        if unshard_result:
+            logits = ops.unshard(logits)
         return logits
 
     def decode(
@@ -268,6 +277,7 @@ def decode(
         self._assert_device(attention_mask, dtype=self.activation_dtype)
         self._assert_device(start_positions)
         self._assert_device(*cache_state, dtype=self.activation_dtype)
+
         if self.config.tensor_parallelism_size > 1:
             if not isinstance(tokens, ReplicatedTensor):
                 tokens = ops.replicate(
@@ -285,6 +295,12 @@ def decode(
                 seq_block_ids = ops.replicate(
                     seq_block_ids, count=self.config.tensor_parallelism_size
                 )
+            # If the user provided unsharded arguments they probably want
+            # an unsharded result as well.
+            unshard_result = True
+        else:
+            unshard_result = False
+
         bs, _ = tokens.shape
         # Precompute a position based mask for computing rope embeddings
         # as it is the same for all blocks.
@@ -360,6 +376,8 @@ def decode(
 
         h = self.output_norm(h)
         logits = self.output_lm_head(h)
+        if unshard_result:
+            logits = ops.unshard(logits)
         return logits
 
 
diff --git a/sharktank/tests/layers/sharded_paged_kv_cache_test.py b/sharktank/tests/layers/sharded_paged_kv_cache_test.py
index 219fb4642..c835a9682 100644
--- a/sharktank/tests/layers/sharded_paged_kv_cache_test.py
+++ b/sharktank/tests/layers/sharded_paged_kv_cache_test.py
@@ -114,7 +114,6 @@ def testSmallCache(self):
             transformer_block_index=transformer_block_index,
             page_ids=sharded_page_ids,
         )
-        # TODO: debug this failure
         for unsharded, sharded in zip(
             read_into_partitions, sharded_read_into_partitions
         ):
diff --git a/sharktank/tests/models/llama/sharded_llama_test.py b/sharktank/tests/models/llama/sharded_llama_test.py
index 1c25a316e..41671208e 100644
--- a/sharktank/tests/models/llama/sharded_llama_test.py
+++ b/sharktank/tests/models/llama/sharded_llama_test.py
@@ -16,7 +16,6 @@
 
 
 class AttentionBlockTest(unittest.TestCase):
-    @unittest.expectedFailure
     def testToyModelCompareToUnsharded(self):
         """Run a sharded variant of a toy model size and compare it against the
         unsharded variant."""
@@ -75,6 +74,7 @@ def testToyModelCompareToUnsharded(self):
             batch_size * batch_seq_len // config.block_seq_stride
         ).view(batch_size, -1)
 
+        # Verify prefill step.
         sharded_config = deepcopy(config)
         sharded_config.tensor_parallelism_size = 2
         sharded_theta = shard_theta(theta, sharded_config)
@@ -82,9 +82,9 @@ def testToyModelCompareToUnsharded(self):
         sharded_cache_state = sharded_model.cache.paged.allocate(
             page_count=cache_page_count
         )
-        sharded_cache_state = [
-            ops.reshard_like(cache_state_snapshot[0].clone(), sharded_cache_state[0])
-        ]
+        sharded_cache_state = sharded_model.cache.paged.shard_state(
+            deepcopy(cache_state)
+        )
 
         expected_prefill_result = model.prefill(
             token_ids,
@@ -98,17 +98,20 @@ def testToyModelCompareToUnsharded(self):
             seq_block_ids=seq_block_ids,
             cache_state=sharded_cache_state,
         )
-        actual_prefill_result = ops.unshard(sharded_prefill_result)
         # The errors are quite high, but for float64 both errors drop to < 1e-12.
         # The numerics are probably correct.
         torch.testing.assert_close(
-            actual_prefill_result, expected_prefill_result, atol=1e-3, rtol=1e-2
+            sharded_prefill_result, expected_prefill_result, atol=1e-3, rtol=1e-2
         )
         expected_cache_state = cache_state[0]
-        actual_cache_state = ops.unshard(sharded_cache_state[0])
-        # TODO: debug this numerical issue
-        torch.testing.assert_close(actual_cache_state, expected_cache_state)
+        actual_cache_state = ops.unshard(
+            sharded_model.cache.paged.unflatten_page_table(sharded_cache_state)
+        ).flatten(start_dim=1)
+        torch.testing.assert_close(
+            actual_cache_state, expected_cache_state, atol=1e-4, rtol=1e-1
+        )
 
+        # Verify decode step.
         decode_token_ids = torch.randint(
             low=0,
             high=vocabulary_size,
@@ -126,6 +129,9 @@ def testToyModelCompareToUnsharded(self):
             model.input_mask(decode_seq_lens, decode_batch_seq_len)
         )
         decode_cache_state = deepcopy(cache_state_snapshot)
+        decode_sharded_cache_state = sharded_model.cache.paged.shard_state(
+            deepcopy(decode_cache_state)
+        )
         expected_decode_result = model.decode(
             decode_token_ids,
             attention_mask=decode_attention_mask,
@@ -133,9 +139,6 @@ def testToyModelCompareToUnsharded(self):
             seq_block_ids=seq_block_ids,
             cache_state=decode_cache_state,
         )
-        decode_sharded_cache_state = [
-            ops.reshard_like(cache_state_snapshot[0].clone(), sharded_cache_state[0])
-        ]
         sharded_decode_result = sharded_model.decode(
             decode_token_ids,
             attention_mask=decode_attention_mask,
@@ -143,13 +146,11 @@ def testToyModelCompareToUnsharded(self):
             seq_block_ids=seq_block_ids,
             cache_state=decode_sharded_cache_state,
         )
-        actual_decode_result = ops.unshard(sharded_decode_result)
-        # TODO: debug this numerical issue
-        torch.testing.assert_close(actual_decode_result, expected_decode_result)
-
+        torch.testing.assert_close(sharded_decode_result, expected_decode_result)
         expected_decode_cache_state = decode_cache_state[0]
-        actual_decode_cache_state = ops.unshard(decode_sharded_cache_state[0])
-        # TODO: debug this numerical issue
+        actual_decode_cache_state = ops.unshard(
+            sharded_model.cache.paged.unflatten_page_table(decode_sharded_cache_state)
+        ).flatten(start_dim=1)
         torch.testing.assert_close(
             actual_decode_cache_state, expected_decode_cache_state
         )
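
A minimal caller-side sketch of the behavior the llama.py change introduces, not part of the diff itself. It assumes the same toy-model objects built in sharded_llama_test.py (sharded_model with tensor_parallelism_size = 2, plain torch tensors token_ids / attention_mask / seq_block_ids, and sharded_cache_state from sharded_model.cache.paged.shard_state); with those assumptions, passing unsharded inputs now yields an already-unsharded logits tensor, so the caller no longer needs ops.unshard on the result:

# Sketch only; `sharded_model`, `token_ids`, `attention_mask`, `seq_block_ids`
# and `sharded_cache_state` are assumed to be set up as in sharded_llama_test.py.
from sharktank.types import ReplicatedTensor

# Inputs are plain (unsharded) torch tensors.
logits = sharded_model.prefill(
    token_ids,
    attention_mask=attention_mask,
    seq_block_ids=seq_block_ids,
    cache_state=sharded_cache_state,
)

# Because the inputs were unsharded, prefill() replicates them internally,
# sets unshard_result = True, and unshards the logits before returning,
# so the test can compare against the unsharded model directly.
assert not isinstance(logits, ReplicatedTensor)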