Fix the sharded llama test
With this, running `sharktank.examples.paged_llm_v1` with
`llama3_8B_fp16` produces sane results.
sogartar committed Sep 25, 2024
1 parent 29d47f6 commit fb04099
Showing 3 changed files with 37 additions and 19 deletions.
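
The fix leans on the replicate/unshard primitives from sharktank's ops module. As a rough, standalone illustration of that round trip (the import paths and the toy shape here are assumptions, not part of this commit):

import torch

from sharktank import ops  # import paths assumed
from sharktank.types import ReplicatedTensor

# A plain, unsharded tensor standing in for token ids, a mask, or logits.
x = torch.rand(4, 16, dtype=torch.float32)

# Replicate it across a tensor-parallel group of size 2.
replicated = ops.replicate(x, count=2)
assert isinstance(replicated, ReplicatedTensor)

# unshard collapses the replicated tensor back to a single unsharded tensor,
# which is what prefill/decode now return when they were given unsharded inputs.
y = ops.unshard(replicated)
torch.testing.assert_close(y, x)
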
18 changes: 18 additions & 0 deletions sharktank/sharktank/models/llama/llama.py
@@ -204,6 +204,7 @@ def prefill(
self._assert_device(attention_mask, dtype=self.activation_dtype)
self._assert_device(seq_block_ids)
self._assert_device(*cache_state, dtype=self.activation_dtype)

if self.config.tensor_parallelism_size > 1:
if not isinstance(tokens, ReplicatedTensor):
tokens = ops.replicate(
@@ -217,6 +218,12 @@ def prefill(
seq_block_ids = ops.replicate(
seq_block_ids, count=self.config.tensor_parallelism_size
)
# If the user provided unsharded arguments they probably want
# an unsharded result as well.
unshard_result = True
else:
unshard_result = False

h = self.token_embedding(tokens)
self.trace_tensor("llama.token_embedding", h)

@@ -236,6 +243,8 @@ def prefill(

h = self.output_norm(h)
logits = self.output_lm_head(h)
if unshard_result:
logits = ops.unshard(logits)
return logits

def decode(
@@ -268,6 +277,7 @@ def decode(
self._assert_device(attention_mask, dtype=self.activation_dtype)
self._assert_device(start_positions)
self._assert_device(*cache_state, dtype=self.activation_dtype)

if self.config.tensor_parallelism_size > 1:
if not isinstance(tokens, ReplicatedTensor):
tokens = ops.replicate(
@@ -285,6 +295,12 @@ def decode(
seq_block_ids = ops.replicate(
seq_block_ids, count=self.config.tensor_parallelism_size
)
# If the user provided unsharded arguments they probably want
# an unsharded result as well.
unshard_result = True
else:
unshard_result = False

bs, _ = tokens.shape
# Precompute a position based mask for computing rope embeddings
# as it is the same for all blocks.
@@ -360,6 +376,8 @@ def decode(

h = self.output_norm(h)
logits = self.output_lm_head(h)
if unshard_result:
logits = ops.unshard(logits)
return logits


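
With the prefill/decode changes above, a tensor-parallel PagedLlamaModelV1 can be driven with plain, unsharded per-step arguments and hands back unsharded logits. A minimal sketch of that calling convention, assuming `theta`, `config`, `cache_state`, `token_ids`, `attention_mask`, and `seq_block_ids` are built the way the toy test below builds them (none of that setup is part of this diff), and assuming the `shard_theta` import location:

from copy import deepcopy

from sharktank.models.llama.llama import PagedLlamaModelV1
from sharktank.types.sharding import shard_theta  # location assumed

# Shard the parameters and build the tensor-parallel model.
sharded_config = deepcopy(config)
sharded_config.tensor_parallelism_size = 2
sharded_theta = shard_theta(theta, sharded_config)
sharded_model = PagedLlamaModelV1(sharded_theta, sharded_config)

# The paged KV cache is sharded explicitly from the unsharded state...
sharded_cache_state = sharded_model.cache.paged.shard_state(deepcopy(cache_state))

# ...but the per-step arguments stay plain torch tensors. prefill() replicates
# them internally and, because they arrived unsharded, unshards the logits too.
logits = sharded_model.prefill(
    token_ids,
    attention_mask=attention_mask,
    seq_block_ids=seq_block_ids,
    cache_state=sharded_cache_state,
)
# `logits` can now be compared directly against an unsharded model's output,
# which is exactly what the updated test below does.
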
1 change: 0 additions & 1 deletion sharktank/tests/layers/sharded_paged_kv_cache_test.py
@@ -114,7 +114,6 @@ def testSmallCache(self):
transformer_block_index=transformer_block_index,
page_ids=sharded_page_ids,
)
# TODO: debug this failure
for unsharded, sharded in zip(
read_into_partitions, sharded_read_into_partitions
):
37 changes: 19 additions & 18 deletions sharktank/tests/models/llama/sharded_llama_test.py
@@ -16,7 +16,6 @@


class AttentionBlockTest(unittest.TestCase):
@unittest.expectedFailure
def testToyModelCompareToUnsharded(self):
"""Run a sharded variant of a toy model size and compare it against the
unsharded variant."""
@@ -75,16 +74,17 @@ def testToyModelCompareToUnsharded(self):
batch_size * batch_seq_len // config.block_seq_stride
).view(batch_size, -1)

# Verify prefill step.
sharded_config = deepcopy(config)
sharded_config.tensor_parallelism_size = 2
sharded_theta = shard_theta(theta, sharded_config)
sharded_model = PagedLlamaModelV1(sharded_theta, sharded_config)
sharded_cache_state = sharded_model.cache.paged.allocate(
page_count=cache_page_count
)
sharded_cache_state = [
ops.reshard_like(cache_state_snapshot[0].clone(), sharded_cache_state[0])
]
sharded_cache_state = sharded_model.cache.paged.shard_state(
deepcopy(cache_state)
)

expected_prefill_result = model.prefill(
token_ids,
@@ -98,17 +98,20 @@ def testToyModelCompareToUnsharded(self):
seq_block_ids=seq_block_ids,
cache_state=sharded_cache_state,
)
actual_prefill_result = ops.unshard(sharded_prefill_result)
# The errors are quite high, but for float64 both errors drop to < 1e-12.
# The numerics are probably correct.
torch.testing.assert_close(
actual_prefill_result, expected_prefill_result, atol=1e-3, rtol=1e-2
sharded_prefill_result, expected_prefill_result, atol=1e-3, rtol=1e-2
)
expected_cache_state = cache_state[0]
actual_cache_state = ops.unshard(sharded_cache_state[0])
# TODO: debug this numerical issue
torch.testing.assert_close(actual_cache_state, expected_cache_state)
actual_cache_state = ops.unshard(
sharded_model.cache.paged.unflatten_page_table(sharded_cache_state)
).flatten(start_dim=1)
torch.testing.assert_close(
actual_cache_state, expected_cache_state, atol=1e-4, rtol=1e-1
)

# Verify decode step.
decode_token_ids = torch.randint(
low=0,
high=vocabulary_size,
@@ -126,30 +129,28 @@ def testToyModelCompareToUnsharded(self):
model.input_mask(decode_seq_lens, decode_batch_seq_len)
)
decode_cache_state = deepcopy(cache_state_snapshot)
decode_sharded_cache_state = sharded_model.cache.paged.shard_state(
deepcopy(decode_cache_state)
)
expected_decode_result = model.decode(
decode_token_ids,
attention_mask=decode_attention_mask,
start_positions=start_positions,
seq_block_ids=seq_block_ids,
cache_state=decode_cache_state,
)
decode_sharded_cache_state = [
ops.reshard_like(cache_state_snapshot[0].clone(), sharded_cache_state[0])
]
sharded_decode_result = sharded_model.decode(
decode_token_ids,
attention_mask=decode_attention_mask,
start_positions=start_positions,
seq_block_ids=seq_block_ids,
cache_state=decode_sharded_cache_state,
)
actual_decode_result = ops.unshard(sharded_decode_result)
# TODO: debug this numerical issue
torch.testing.assert_close(actual_decode_result, expected_decode_result)

torch.testing.assert_close(sharded_decode_result, expected_decode_result)
expected_decode_cache_state = decode_cache_state[0]
actual_decode_cache_state = ops.unshard(decode_sharded_cache_state[0])
# TODO: debug this numerical issue
actual_decode_cache_state = ops.unshard(
sharded_model.cache.paged.unflatten_page_table(decode_sharded_cache_state)
).flatten(start_dim=1)
torch.testing.assert_close(
actual_decode_cache_state, expected_decode_cache_state
)
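
For the KV cache itself, the comparison pattern the test now uses is: derive the sharded page table from the unsharded one with `shard_state`, then bring it back with `unflatten_page_table` plus `unshard` before flattening for the element-wise check. Roughly, reusing the imports and names from the sketches above and assuming `model` and `cache_page_count` from the test setup that is elided in this diff:

# Allocate the reference (unsharded) page table, then derive the sharded copy.
cache_state = model.cache.paged.allocate(page_count=cache_page_count)
sharded_cache_state = sharded_model.cache.paged.shard_state(deepcopy(cache_state))

# ... run prefill/decode on both the unsharded and the sharded model ...

# Undo the flattening and the sharding before comparing against the reference.
actual = ops.unshard(
    sharded_model.cache.paged.unflatten_page_table(sharded_cache_state)
).flatten(start_dim=1)
torch.testing.assert_close(actual, cache_state[0], atol=1e-4, rtol=1e-1)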
