In sharded paged KV cache test, also test write and write_timestep
sogartar committed Sep 25, 2024
1 parent fb04099 commit fdb35ab
Showing 1 changed file with 178 additions and 64 deletions.
242 changes: 178 additions & 64 deletions sharktank/tests/layers/sharded_paged_kv_cache_test.py
import torch
from sharktank.utils import iterables_equal
from copy import deepcopy
from typing import List, Tuple
from sharktank import ops
from sharktank.types import SplitPrimitiveTensor


class ShardedPagedKVCacheTest(unittest.TestCase):
"""Verify that the sharded paged KV cache behaves as the unsharded variant."""

def setUp(self):
torch.manual_seed(12345)
self.dtype = torch.float32
torch.set_default_dtype(self.dtype)
self.shard_count = 3
self.transformer_block_count = 5
self.attn_head_count = self.shard_count * 7
self.block_seq_stride = 19
self.attn_head_dim = 17
self.cache_partition_count = 2
self.page_count = 23
self.batch_size = 11
self.block_seq_len = 2
self.max_seq_len = self.block_seq_len * self.block_seq_stride

self.cache = PagedKVCache(
transformer_block_count=self.transformer_block_count,
attn_head_count=self.attn_head_count,
block_seq_stride=self.block_seq_stride,
attn_head_dim=self.attn_head_dim,
cache_partition_count=self.cache_partition_count,
dtype=self.dtype,
)
self.sharded_cache = PagedKVCache(
shard_count=self.shard_count,
transformer_block_count=self.transformer_block_count,
attn_head_count=self.attn_head_count,
block_seq_stride=self.block_seq_stride,
attn_head_dim=self.attn_head_dim,
cache_partition_count=self.cache_partition_count,
dtype=self.dtype,
)

def make_unsharded_and_sharded_equal_cache_states(
self,
) -> Tuple[List[torch.Tensor], List[SplitPrimitiveTensor]]:
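# Allocate an unsharded cache state, fill it with random values, and shard a
# deep copy of it so both representations start out with identical contents.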
cache_state = self.cache.allocate(self.page_count)
cache_state[0] = torch.rand_like(cache_state[0])
sharded_cache_state = self.sharded_cache.shard_state(deepcopy(cache_state))
self.assert_equal_unsharded_and_sharded_cache_states(
cache_state, sharded_cache_state
)
return cache_state, sharded_cache_state

def assert_equal_unsharded_and_sharded_cache_states(
self,
cache_state: List[torch.Tensor],
sharded_cache_state: List[SplitPrimitiveTensor],
):
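# Unshard the page table, flatten it back to the unsharded layout, and compare
# it element-wise against the reference state.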
sharded_state_as_unsharded = ops.unshard(
self.sharded_cache.unflatten_page_table(sharded_cache_state)
).flatten(start_dim=1)
assert ops.equal(
cache_state[0],
sharded_state_as_unsharded,
)

def testAllocate(self):
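# allocate() returns a single flat page table; the sharded variant must have
# the same overall shape, split along dimension 1 into shard_count shards.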
cache_state = self.cache.allocate(self.page_count)
sharded_cache_state = self.sharded_cache.allocate(self.page_count)
assert len(cache_state) == 1
assert len(sharded_cache_state) == 1
assert iterables_equal(cache_state[0].shape, sharded_cache_state[0].shape)
assert sharded_cache_state[0].shard_dim == 1
assert sharded_cache_state[0].shard_count == self.shard_count

def testUnflattenPageTable(self):
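# The unflattened page table should keep its logical shape on both variants,
# with the shards split along dimension 4 and page_count as the leading dimension.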
cache_state = self.cache.allocate(self.page_count)
sharded_cache_state = self.sharded_cache.allocate(self.page_count)

# Test unflatten_page_table.
unflattened_cache_state = self.cache.unflatten_page_table(cache_state)
sharded_unflattened_cache_state = self.sharded_cache.unflatten_page_table(
sharded_cache_state
)
assert iterables_equal(
unflattened_cache_state.shape, sharded_unflattened_cache_state.shape
)
assert sharded_unflattened_cache_state.shard_dim == 4
assert sharded_unflattened_cache_state.shard_count == self.shard_count
assert sharded_unflattened_cache_state.shape[0] == self.page_count

def testRead(self):
(
cache_state,
sharded_cache_state,
) = self.make_unsharded_and_sharded_equal_cache_states()

# Test reading.
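# Read the same transformer block from both caches into identical, randomly
# initialized destination partitions; after unsharding, the results must match.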
read_into_partitions_snapshot = [
torch.rand(
self.batch_size,
self.block_seq_len * self.block_seq_stride,
self.attn_head_count,
self.attn_head_dim,
)
for _ in range(self.cache_partition_count)
]
read_into_partitions = deepcopy(read_into_partitions_snapshot)
transformer_block_index = 1
page_ids = torch.randint(
low=0, high=self.page_count, size=[self.batch_size, self.block_seq_len]
).reshape([self.batch_size, self.block_seq_len])
self.cache.read(
state=cache_state,
read_into_partitions=read_into_partitions,
transformer_block_index=transformer_block_index,
page_ids=page_ids,
)
sharded_read_into_partitions = deepcopy(
[
ops.reshard_split(t, dim=2, count=self.shard_count)
for t in read_into_partitions_snapshot
]
)
sharded_page_ids = ops.replicate(page_ids, count=self.shard_count)
self.sharded_cache.read(
state=sharded_cache_state,
read_into_partitions=sharded_read_into_partitions,
transformer_block_index=transformer_block_index,
page_ids=sharded_page_ids,
)
for unsharded, sharded in zip(
read_into_partitions, sharded_read_into_partitions
):
assert ops.equal(unsharded, ops.unshard(sharded))

def testWriteTimestep(self):
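# Write the same data into both caches with write_timestep, using replicated
# seq_positions/page_ids and split cache partitions, then check that the two
# cache states still agree.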
(
cache_state,
sharded_cache_state,
) = self.make_unsharded_and_sharded_equal_cache_states()

cache_partitions = [
torch.rand(
self.batch_size,
self.block_seq_len * self.block_seq_stride,
self.attn_head_count,
self.attn_head_dim,
)
for _ in range(self.cache_partition_count)
]
transformer_block_index = 1
seq_positions = torch.randint(
low=0, high=self.max_seq_len, size=[self.batch_size]
)
page_ids = torch.randperm(self.batch_size * self.block_seq_len).reshape(
[self.batch_size, self.block_seq_len]
)
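# randperm yields distinct page ids, presumably so that no two (sequence, block)
# slots write into the same page.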
self.cache.write_timestep(
state=cache_state,
cache_partitions=cache_partitions,
transformer_block_index=transformer_block_index,
seq_positions=seq_positions,
page_ids=page_ids,
)
sharded_cache_partitions = deepcopy(
[
ops.reshard_split(t, dim=2, count=self.shard_count)
for t in cache_partitions
]
)
sharded_seq_positions = ops.replicate(seq_positions, count=self.shard_count)
sharded_page_ids = ops.replicate(page_ids, count=self.shard_count)
self.sharded_cache.write_timestep(
state=sharded_cache_state,
cache_partitions=sharded_cache_partitions,
transformer_block_index=transformer_block_index,
seq_positions=sharded_seq_positions,
page_ids=sharded_page_ids,
)
self.assert_equal_unsharded_and_sharded_cache_states(
cache_state, sharded_cache_state
)

def testWrite(self):
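# Same equality check as testWriteTimestep, but exercising write(), which takes
# page_ids without seq_positions.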
(
cache_state,
sharded_cache_state,
) = self.make_unsharded_and_sharded_equal_cache_states()

cache_partitions = [
torch.rand(
self.batch_size,
self.block_seq_len * self.block_seq_stride,
self.attn_head_count,
self.attn_head_dim,
)
for _ in range(self.cache_partition_count)
]
transformer_block_index = 1
assert self.batch_size * self.block_seq_len <= self.page_count
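# Guard: randperm below can only produce valid page indices if the page table
# holds at least batch_size * block_seq_len pages.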
page_ids = torch.randperm(self.batch_size * self.block_seq_len).reshape(
[self.batch_size, self.block_seq_len]
)
self.cache.write(
state=cache_state,
cache_partitions=cache_partitions,
transformer_block_index=transformer_block_index,
page_ids=page_ids,
)
sharded_cache_partitions = deepcopy(
[
ops.reshard_split(t, dim=2, count=self.shard_count)
for t in cache_partitions
]
)
sharded_page_ids = ops.replicate(page_ids, count=self.shard_count)
self.sharded_cache.write(
state=sharded_cache_state,
cache_partitions=sharded_cache_partitions,
transformer_block_index=transformer_block_index,
page_ids=sharded_page_ids,
)
self.assert_equal_unsharded_and_sharded_cache_states(
cache_state, sharded_cache_state
)
