
[sharktank] Export Attention IRs for LLMs #175

Merged: 25 commits, merged on Oct 4, 2024
Changes from 24 commits

Commits
1b4444e
Add Direct cache attention export
archana-ramalingam Aug 22, 2024
7869c91
Add direct cache attention
archana-ramalingam Aug 23, 2024
8dc2dbd
Update paged and direct cache attention exports
archana-ramalingam Aug 28, 2024
37b28f7
Update attention export script
archana-ramalingam Sep 4, 2024
2f5dd66
Fuse decode scatters
archana-ramalingam Sep 6, 2024
42d9fb8
Add scatter fusion for prefill and decode
archana-ramalingam Sep 9, 2024
48ee669
Cleanup debug statements
archana-ramalingam Sep 10, 2024
4418ea2
Cleanup
archana-ramalingam Sep 10, 2024
6fcded5
Rename
archana-ramalingam Sep 10, 2024
9ccf41c
Add pre-commit hooks updated files
archana-ramalingam Sep 10, 2024
affebd2
Update .pre-commit-config.yaml
archana-ramalingam Sep 10, 2024
ca73207
Merge branch 'main' into attention_microbenchmark
archana-ramalingam Sep 10, 2024
79d4d2c
Delete unpaged attention export script
archana-ramalingam Sep 10, 2024
9c51aa7
Move attention export script to export_layer folder
archana-ramalingam Sep 11, 2024
28f9515
Update conflicting MOE params
archana-ramalingam Sep 11, 2024
6ca1a7f
Set attention_mask to None
archana-ramalingam Sep 11, 2024
b0d9fe2
Merge branch 'main' into attention_microbenchmark
archana-ramalingam Sep 13, 2024
aba713d
Merge branch 'main' into attention_microbenchmark
archana-ramalingam Oct 2, 2024
0605ddf
Handle sharded case and fix sharded KV cache test
sogartar Oct 2, 2024
5375807
In ShardedLlamaTest change seq_lens type to torch.int64
sogartar Oct 2, 2024
9dde0de
Remove unused cache_partitions_list
archana-ramalingam Oct 3, 2024
bbaaf1d
Add new line at end of file
archana-ramalingam Oct 3, 2024
ace92fc
Merge branch 'attention_microbenchmark' of https://github.com/nod-ai/…
archana-ramalingam Oct 3, 2024
8413de5
Add missing dim
archana-ramalingam Oct 3, 2024
3018608
Merge branch 'main' into attention_microbenchmark
archana-ramalingam Oct 4, 2024
417 changes: 417 additions & 0 deletions sharktank/sharktank/export_layer/export_paged_attention.py

Large diffs are not rendered by default.

71 changes: 39 additions & 32 deletions sharktank/sharktank/layers/kv_cache.py
@@ -363,21 +363,33 @@ def write_timestep(
page_table = self.unflatten_page_table(state) # 6D
bs, *_ = seq_positions.shape
assert len(cache_partitions) == self.cache_partition_count
for i in range(bs):
position = seq_positions[i]
# TODO: Let's clamp to the allowable range so that we don't need
# an assert.
page_id = page_ids[i, :].index_select(0, position // self.block_seq_stride)
page_offset = position % self.block_seq_stride
for partition_index in range(self.cache_partition_count):
cache_partition = cache_partitions[partition_index]
indices = (
page_id,
torch.tensor([transformer_block_index], device=device),
torch.tensor([partition_index], device=device),
page_offset.unsqueeze(0),
)
page_table.index_put_(indices=indices, values=cache_partition[i, 0])

partition_count = len(cache_partitions)

# [bs, partitions, atten_head_count, attn_head_dim]
cache_partitions = ops.cat(cache_partitions, dim=1)

# [bs, 1]
page_index = seq_positions // self.block_seq_stride

page_id = ops.gather(page_ids, dim=1, index=page_index.unsqueeze(1))
page_offset = (seq_positions % self.block_seq_stride).unsqueeze(1)

# [1, partitions]
partitions = torch.arange(0, self.cache_partition_count).unsqueeze(0)

# [bs, partitions]
page_id = page_id.repeat(1, partition_count)
transformer_block = torch.full(
(bs, partition_count), transformer_block_index, device=device
)
page_offset = page_offset.repeat(1, partition_count)
partitions = partitions.repeat(bs, 1)

indices = (page_id, transformer_block, partitions, page_offset)
page_table.index_put_(indices=indices, values=cache_partitions)

return

def write(
self,
@@ -418,23 +430,18 @@ def write(
transformer_block_index * transformer_block_stride
)

def write_cache_partition(
index: int, part: Union[torch.Tensor, SplitPrimitiveTensor]
):
part_block_view = part.reshape(blocked_shape)
part_block_views = []
subblock_ids_kv = []
for index, partition in enumerate(cache_partitions):
part_block_view = partition.reshape(blocked_shape).flatten(0, 1)
part_block_views.append(part_block_view)

subblock_ids = (
(base_subblock_ids + index) if index > 0 else base_subblock_ids
)
# TODO: Potentially clamp all page 0 indices to the mask value.
# Or even better, require that the ids are replicated such that access is
# legal.
# Now for each of the k/v attn_block_ids, which have been adjusted to
# index into the sub-pages, we flatten to do a linear index_select
# copy of the sub-blocks by collapsing the first two dims so we have
# a linear list.
subblock_table.index_copy_(
0, subblock_ids.flatten(0, 1), part_block_view.flatten(0, 1)
)
).flatten(0, 1)
subblock_ids_kv.append(subblock_ids)

for index, partition in enumerate(cache_partitions):
write_cache_partition(index, partition)
subblock_ids = ops.cat(subblock_ids_kv)
part_block_view = ops.cat(part_block_views, dim=0)

subblock_table.index_copy_(0, subblock_ids, part_block_view)
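
The rewritten `write_timestep` above replaces the per-sequence Python loop with a single batched `index_put_`. The following sketch is plain PyTorch with made-up sizes (not the sharktank code itself); it only illustrates how the four index tensors are broadcast to `[bs, partition_count]` so that every (sequence, K/V partition) pair writes one `[attn_head_count, attn_head_dim]` slice into the 6-D page table in one call.

```python
import torch

# Illustrative sizes, not taken from the PR.
bs, partition_count = 2, 2                  # batch size; K and V partitions
block_seq_stride = 4                        # tokens per page
attn_head_count, attn_head_dim = 8, 16
transformer_block_index = 3
page_count, transformer_block_count = 10, 6

# 6-D page table: [page, transformer_block, partition, block_pos, head, head_dim]
page_table = torch.zeros(
    page_count, transformer_block_count, partition_count,
    block_seq_stride, attn_head_count, attn_head_dim,
)

seq_positions = torch.tensor([5, 9])              # current token position per sequence
page_ids = torch.tensor([[0, 1, 2], [3, 4, 5]])   # pages allocated to each sequence

# K and V slices for this timestep, already concatenated on dim 1:
# [bs, partition_count, attn_head_count, attn_head_dim]
cache_partitions = torch.rand(bs, partition_count, attn_head_count, attn_head_dim)

# Which logical page each position falls into, and the offset within that page.
page_index = seq_positions // block_seq_stride                           # [bs]
page_id = torch.gather(page_ids, dim=1, index=page_index.unsqueeze(1))   # [bs, 1]
page_offset = (seq_positions % block_seq_stride).unsqueeze(1)            # [bs, 1]

# Broadcast every index to [bs, partition_count] so one index_put_ covers
# all sequences and both K/V partitions at once.
page_id = page_id.repeat(1, partition_count)
transformer_block = torch.full((bs, partition_count), transformer_block_index)
partitions = torch.arange(partition_count).unsqueeze(0).repeat(bs, 1)
page_offset = page_offset.repeat(1, partition_count)

indices = (page_id, transformer_block, partitions, page_offset)
page_table.index_put_(indices=indices, values=cache_partitions)
```
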
14 changes: 14 additions & 0 deletions sharktank/sharktank/ops/default_impls.py
@@ -157,6 +157,15 @@ def flatten_default(
return torch.flatten(unbox_tensor(input), start_dim, end_dim)


@gather.override(Tensor, Tensor)
def gather_default(
input: Union[Tensor, PrimitiveTensor],
dim: int,
index: Union[Tensor, PrimitiveTensor],
) -> Tensor:
return torch.gather(unbox_tensor(input), dim, unbox_tensor(index))


@get_index.override(AllOfType(Tensor, PrimitiveTensor))
def get_index_default(tensor, key):
return unbox_tensor(tensor).__get_item__(key)
@@ -333,6 +342,11 @@ def module_register_buffer_default(
return module.register_buffer(name, unbox_tensor(tensor))


@repeat.override(Tensor)
def repeat_default(input: Union[Tensor, PrimitiveTensor], *sizes: List[int]) -> Tensor:
return unbox_tensor(input).repeat(*sizes)


@reshape.override(Tensor)
def reshape_default(input: Union[PrimitiveTensor, Tensor], shape: List[int]) -> Tensor:
return torch.reshape(unbox_tensor(input), shape)
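
Both new default overrides delegate straight to the corresponding torch primitives. As a plain-PyTorch reminder of the semantics they expose (this snippet does not go through the sharktank dispatcher):

```python
import torch

values = torch.tensor([[10, 20, 30],
                       [40, 50, 60]])
index = torch.tensor([[2, 0],
                      [1, 1]])

# What gather_default forwards to: pick elements along dim 1, per row of `index`.
gathered = torch.gather(values, 1, index)   # tensor([[30, 10], [50, 50]])

# What repeat_default forwards to: tile 2x along dim 0 and 1x along dim 1.
tiled = values.repeat(2, 1)                 # shape [4, 3]
```
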
30 changes: 29 additions & 1 deletion sharktank/sharktank/ops/sharded_impls.py
@@ -61,8 +61,18 @@ def all_reduce_split_or_unreduced(
return ReplicatedTensor(ts=shards)


@cat.override(AllOfType(ReplicatedTensor))
def cat_replicated(tensors: Sequence[ReplicatedTensor], dim: int) -> ReplicatedTensor:
assert len(tensors) > 0
shard_count = tensors[0].shard_count
assert all([t.shard_count == shard_count for t in tensors])

shards = [cat(shards, dim) for shards in zip(*[t.shards for t in tensors])]
return ReplicatedTensor(ts=shards)


@cat.override(AllOfType(SplitPrimitiveTensor))
def cat_sharded(
def cat_split(
tensors: Sequence[SplitPrimitiveTensor], dim: int
) -> SplitPrimitiveTensor:
assert len(tensors) > 0
@@ -456,6 +466,18 @@ def flatten_split(
return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim)


@gather.override(ReplicatedTensor, ReplicatedTensor)
def gather_replicated(
input: ReplicatedTensor, dim: int, index: ReplicatedTensor
) -> Tensor:
assert input.shard_count == index.shard_count
shards = [
gather(input_shard, dim, index_shard)
for input_shard, index_shard in zip(input.shards, index.shards)
]
return ReplicatedTensor(ts=shards)


@group_norm_affine.override(
SplitPrimitiveTensor, SplitPrimitiveTensor, SplitPrimitiveTensor
)
@@ -802,6 +824,12 @@ def permute_replicated(tensor: ReplicatedTensor, dims: List[int]):
return ReplicatedTensor(ts=permuted_shards)


@repeat.override(ReplicatedTensor)
def repeat_replicated(input: ReplicatedTensor, *sizes: List[int]) -> ReplicatedTensor:
shards = [repeat(shard, *sizes) for shard in input.shards]
return ReplicatedTensor(ts=shards)


@replicate.override(ReplicatedTensor)
def replicate_replicated(input: ReplicatedTensor, *, count: int) -> ReplicatedTensor:
if input.shard_count != count:
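
A rough usage sketch of the three new replicated overrides (`cat_replicated`, `gather_replicated`, `repeat_replicated`). It assumes `ReplicatedTensor` can be constructed directly from a list of per-device shards via `ts=` (the same construction the overrides themselves use) and that the modules are importable as `sharktank.ops` and `sharktank.types`; treat it as illustrative rather than as a test from this PR.

```python
import torch
from sharktank import ops
from sharktank.types import ReplicatedTensor

t = torch.arange(6, dtype=torch.float32).reshape(2, 3)
idx = torch.tensor([[2, 0], [1, 1]])

# Two "devices", each holding identical data.
x = ReplicatedTensor(ts=[t.clone(), t.clone()])
i = ReplicatedTensor(ts=[idx.clone(), idx.clone()])

g = ops.gather(x, 1, i)        # gather_replicated: torch.gather per shard
c = ops.cat([x, x], dim=0)     # cat_replicated: concatenate shard-by-shard
r = ops.repeat(x, 2, 1)        # repeat_replicated: repeat shard-by-shard

assert g.shard_count == c.shard_count == r.shard_count == 2
```
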
43 changes: 43 additions & 0 deletions sharktank/sharktank/ops/signatures.py
@@ -27,6 +27,7 @@
"equal",
"expand",
"flatten",
"gather",
"get_index",
"gemm",
"group_norm_affine",
@@ -41,6 +42,7 @@
"module_register_buffer",
"permute",
"rms_norm",
"repeat",
"replicate",
"reshape",
"reshard",
@@ -348,6 +350,28 @@ def _flatten_trampoline(
d.fail(dispatch_args)


@overridable
def gather(input: AnyTensor, dim: int, index: AnyTensor) -> AnyTensor:
"""See torch.gather"""
...


@gather.trampoline
def _gather_trampoline(
d: SignatureDispatcher, input: AnyTensor, dim: int, index: AnyTensor
) -> AnyTensor:
dispatch_args = (
input,
index,
)
for override in d.find_overrides(dispatch_args):
result = override(input, dim, index)
if result is not NotImplemented:
return override, result
else:
d.fail(dispatch_args)


@overridable
def gemm(
a: AnyTensor,
@@ -718,6 +742,25 @@ def _rms_norm_trampoline(
d.fail(tensors)


@overridable
def repeat(input: AnyTensor, *sizes: List[int]) -> AnyTensor:
"""See torch.Tensor.repeat"""
...


@repeat.trampoline
def _repeat_trampoline(
d: SignatureDispatcher, input: AnyTensor, *sizes: List[int]
) -> AnyTensor:
dispatch_args = (input,)
for override in d.find_overrides(dispatch_args):
result = override(input, *sizes)
if result is not NotImplemented:
return override, result
else:
d.fail(dispatch_args)


@overridable
def replicate(input: AnyTensor, count: int) -> ShardedTensor:
"""Replicate across devices.
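
Call sites such as the paged KV cache above go through these signatures rather than torch directly, so the same code works for plain, primitive, and sharded tensors. Note that for `gather` only `(input, index)` participate in dispatch; `dim` is passed through to the selected override. A small hedged example on plain torch tensors, which resolves to the default overrides:

```python
import torch
from sharktank import ops

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
idx = torch.tensor([[0, 2],
                    [1, 0]])

print(ops.gather(x, 1, idx))   # dispatches to gather_default -> torch.gather
print(ops.repeat(x, 2, 1))     # dispatches to repeat_default -> Tensor.repeat
```
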
5 changes: 5 additions & 0 deletions sharktank/sharktank/types/tensors.py
@@ -343,6 +343,11 @@ def pow(self, exponent: Union["AnyTensor", Number]) -> "AnyTensor":

return elementwise(torch.pow, self, exponent)

def repeat(self, *sizes: List[int]) -> "AnyTensor":
from ..ops import repeat

return repeat(self, *sizes)

def reshape(self, *args: Union[List[List[int]], List[int]]) -> "AnyTensor":
from ..ops import reshape

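
The new `repeat` method on the shared tensor base class simply defers to `ops.repeat`, so any tensor type with a registered override picks it up through normal method syntax. A hedged illustration, reusing the `ReplicatedTensor(ts=...)` construction assumed earlier:

```python
import torch
from sharktank.types import ReplicatedTensor

x = ReplicatedTensor(ts=[torch.ones(2, 3), torch.ones(2, 3)])
y = x.repeat(4, 1)    # routes through ops.repeat -> repeat_replicated
# Each shard of y is now 8 x 3.
```
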
2 changes: 1 addition & 1 deletion sharktank/tests/layers/sharded_paged_kv_cache_test.py
@@ -151,7 +151,7 @@ def testWriteTimestep(self):
cache_partitions = [
torch.rand(
self.batch_size,
self.block_seq_len * self.block_seq_stride,
1,
self.attn_head_count,
self.attn_head_dim,
)
2 changes: 1 addition & 1 deletion sharktank/tests/models/llama/sharded_llama_test.py
@@ -58,7 +58,7 @@ def setUp(self):
vocab_size=self.vocabulary_size,
)
self.prefill_seq_lens = torch.tensor(
[14, 9, self.block_seq_stride - 1], dtype=torch.int32
[14, 9, self.block_seq_stride - 1], dtype=torch.int64
)

def make_prefill_args(self, model: PagedLlamaModelV1) -> Dict[str, Any]: