diff --git a/sharktank/tests/layers/sharded_paged_kv_cache_test.py b/sharktank/tests/layers/sharded_paged_kv_cache_test.py index c835a9682..621f3ac13 100644 --- a/sharktank/tests/layers/sharded_paged_kv_cache_test.py +++ b/sharktank/tests/layers/sharded_paged_kv_cache_test.py @@ -9,93 +9,116 @@ import torch from sharktank.utils import iterables_equal from copy import deepcopy +from typing import List, Tuple from sharktank import ops +from sharktank.types import SplitPrimitiveTensor class ShardedPagedKVCacheTest(unittest.TestCase): + """Verify that the sharded paged KV cache behaves as the unsharded variant.""" + def setUp(self): torch.manual_seed(12345) - torch.set_default_dtype(torch.float32) - - def testSmallCache(self): - """Verify that the sharded paged KV cache behaves as the unsharded variant.""" - shard_count = 3 - transformer_block_count = 5 - attn_head_count = shard_count * 7 - block_seq_stride = 19 - attn_head_dim = 17 - cache_partition_count = 2 - page_count = 23 - dtype = torch.float32 - batch_size = 11 - block_seq_len = 13 - max_seq_len = block_seq_len * block_seq_stride - - cache = PagedKVCache( - transformer_block_count=transformer_block_count, - attn_head_count=attn_head_count, - block_seq_stride=block_seq_stride, - attn_head_dim=attn_head_dim, - cache_partition_count=cache_partition_count, - dtype=dtype, - ) - sharded_cache = PagedKVCache( - shard_count=shard_count, - transformer_block_count=transformer_block_count, - attn_head_count=attn_head_count, - block_seq_stride=block_seq_stride, - attn_head_dim=attn_head_dim, - cache_partition_count=cache_partition_count, - dtype=dtype, - ) - - # Test allocate. - cache_state = cache.allocate(page_count) - sharded_cache_state = sharded_cache.allocate(page_count) + self.dtype = torch.float32 + torch.set_default_dtype(self.dtype) + self.shard_count = 3 + self.transformer_block_count = 5 + self.attn_head_count = self.shard_count * 7 + self.block_seq_stride = 19 + self.attn_head_dim = 17 + self.cache_partition_count = 2 + self.page_count = 23 + self.batch_size = 11 + self.block_seq_len = 2 + self.max_seq_len = self.block_seq_len * self.block_seq_stride + + self.cache = PagedKVCache( + transformer_block_count=self.transformer_block_count, + attn_head_count=self.attn_head_count, + block_seq_stride=self.block_seq_stride, + attn_head_dim=self.attn_head_dim, + cache_partition_count=self.cache_partition_count, + dtype=self.dtype, + ) + self.sharded_cache = PagedKVCache( + shard_count=self.shard_count, + transformer_block_count=self.transformer_block_count, + attn_head_count=self.attn_head_count, + block_seq_stride=self.block_seq_stride, + attn_head_dim=self.attn_head_dim, + cache_partition_count=self.cache_partition_count, + dtype=self.dtype, + ) + + def make_unsharded_and_sharded_equal_cache_states( + self, + ) -> Tuple[List[torch.Tensor], List[SplitPrimitiveTensor]]: + cache_state = self.cache.allocate(self.page_count) + cache_state[0] = torch.rand_like(cache_state[0]) + sharded_cache_state = self.sharded_cache.shard_state(deepcopy(cache_state)) + self.assert_equal_unsharded_and_sharded_cache_states( + cache_state, sharded_cache_state + ) + return cache_state, sharded_cache_state + + def assert_equal_unsharded_and_sharded_cache_states( + self, + cache_state: List[torch.Tensor], + sharded_cache_state: List[SplitPrimitiveTensor], + ): + sharded_state_as_unsharded = ops.unshard( + self.sharded_cache.unflatten_page_table(sharded_cache_state) + ).flatten(start_dim=1) + assert ops.equal( + cache_state[0], + sharded_state_as_unsharded, + ) + + def testAllocate(self): + cache_state = self.cache.allocate(self.page_count) + sharded_cache_state = self.sharded_cache.allocate(self.page_count) assert len(cache_state) == 1 assert len(sharded_cache_state) == 1 assert iterables_equal(cache_state[0].shape, sharded_cache_state[0].shape) assert sharded_cache_state[0].shard_dim == 1 - assert sharded_cache_state[0].shard_count == shard_count + assert sharded_cache_state[0].shard_count == self.shard_count + + def testUnflattenPageTable(self): + cache_state = self.cache.allocate(self.page_count) + sharded_cache_state = self.sharded_cache.allocate(self.page_count) - # Test unflatten_page_table. - unflattened_cache_state = cache.unflatten_page_table(cache_state) - sharded_unflattened_cache_state = sharded_cache.unflatten_page_table( + unflattened_cache_state = self.cache.unflatten_page_table(cache_state) + sharded_unflattened_cache_state = self.sharded_cache.unflatten_page_table( sharded_cache_state ) assert iterables_equal( unflattened_cache_state.shape, sharded_unflattened_cache_state.shape ) assert sharded_unflattened_cache_state.shard_dim == 4 - assert sharded_unflattened_cache_state.shard_count == shard_count - assert sharded_unflattened_cache_state.shape[0] == page_count + assert sharded_unflattened_cache_state.shard_count == self.shard_count + assert sharded_unflattened_cache_state.shape[0] == self.page_count - # Make the sharded cache state have the same elements as the unsharded. - cache_state[0] = torch.rand_like(cache_state[0]) - sharded_cache_state = sharded_cache.shard_state(deepcopy(cache_state)) - assert ops.equal( - cache_state[0], - ops.unshard( - sharded_cache.unflatten_page_table(sharded_cache_state) - ).flatten(start_dim=1), - ) + def testRead(self): + ( + cache_state, + sharded_cache_state, + ) = self.make_unsharded_and_sharded_equal_cache_states() - # Test reading. read_into_partitions_snapshot = [ torch.rand( - batch_size, - block_seq_len * block_seq_stride, - attn_head_count, - attn_head_dim, + self.batch_size, + self.block_seq_len * self.block_seq_stride, + self.attn_head_count, + self.attn_head_dim, ) - for _ in range(cache_partition_count) + for _ in range(self.cache_partition_count) ] read_into_partitions = deepcopy(read_into_partitions_snapshot) transformer_block_index = 1 page_ids = torch.randint( - low=0, high=page_count, size=[batch_size, block_seq_len] - ).reshape([batch_size, block_seq_len]) - cache.read( + low=0, high=self.page_count, size=[self.batch_size, self.block_seq_len] + ).reshape([self.batch_size, self.block_seq_len]) + self.cache.read( state=cache_state, read_into_partitions=read_into_partitions, transformer_block_index=transformer_block_index, @@ -103,12 +126,12 @@ def testSmallCache(self): ) sharded_read_into_partitions = deepcopy( [ - ops.reshard_split(t, dim=2, count=shard_count) + ops.reshard_split(t, dim=2, count=self.shard_count) for t in read_into_partitions_snapshot ] ) - sharded_page_ids = ops.replicate(page_ids, count=shard_count) - sharded_cache.read( + sharded_page_ids = ops.replicate(page_ids, count=self.shard_count) + self.sharded_cache.read( state=sharded_cache_state, read_into_partitions=sharded_read_into_partitions, transformer_block_index=transformer_block_index, @@ -118,3 +141,94 @@ def testSmallCache(self): read_into_partitions, sharded_read_into_partitions ): assert ops.equal(unsharded, ops.unshard(sharded)) + + def testWriteTimestep(self): + ( + cache_state, + sharded_cache_state, + ) = self.make_unsharded_and_sharded_equal_cache_states() + + cache_partitions = [ + torch.rand( + self.batch_size, + self.block_seq_len * self.block_seq_stride, + self.attn_head_count, + self.attn_head_dim, + ) + for _ in range(self.cache_partition_count) + ] + transformer_block_index = 1 + seq_positions = torch.randint( + low=0, high=self.max_seq_len, size=[self.batch_size] + ) + page_ids = torch.randperm(self.batch_size * self.block_seq_len).reshape( + [self.batch_size, self.block_seq_len] + ) + self.cache.write_timestep( + state=cache_state, + cache_partitions=cache_partitions, + transformer_block_index=transformer_block_index, + seq_positions=seq_positions, + page_ids=page_ids, + ) + sharded_cache_partitions = deepcopy( + [ + ops.reshard_split(t, dim=2, count=self.shard_count) + for t in cache_partitions + ] + ) + sharded_seq_positions = ops.replicate(seq_positions, count=self.shard_count) + sharded_page_ids = ops.replicate(page_ids, count=self.shard_count) + self.sharded_cache.write_timestep( + state=sharded_cache_state, + cache_partitions=sharded_cache_partitions, + transformer_block_index=transformer_block_index, + seq_positions=sharded_seq_positions, + page_ids=sharded_page_ids, + ) + self.assert_equal_unsharded_and_sharded_cache_states( + cache_state, sharded_cache_state + ) + + def testWrite(self): + ( + cache_state, + sharded_cache_state, + ) = self.make_unsharded_and_sharded_equal_cache_states() + + cache_partitions = [ + torch.rand( + self.batch_size, + self.block_seq_len * self.block_seq_stride, + self.attn_head_count, + self.attn_head_dim, + ) + for _ in range(self.cache_partition_count) + ] + transformer_block_index = 1 + assert self.batch_size * self.block_seq_len <= self.page_count + page_ids = torch.randperm(self.batch_size * self.block_seq_len).reshape( + [self.batch_size, self.block_seq_len] + ) + self.cache.write( + state=cache_state, + cache_partitions=cache_partitions, + transformer_block_index=transformer_block_index, + page_ids=page_ids, + ) + sharded_cache_partitions = deepcopy( + [ + ops.reshard_split(t, dim=2, count=self.shard_count) + for t in cache_partitions + ] + ) + sharded_page_ids = ops.replicate(page_ids, count=self.shard_count) + self.sharded_cache.write( + state=sharded_cache_state, + cache_partitions=sharded_cache_partitions, + transformer_block_index=transformer_block_index, + page_ids=sharded_page_ids, + ) + self.assert_equal_unsharded_and_sharded_cache_states( + cache_state, sharded_cache_state + )