diff --git a/sharktank/sharktank/export_layer/export_paged_attention.py b/sharktank/sharktank/export_layer/export_paged_attention.py
new file mode 100644
index 000000000..aa0cdf961
--- /dev/null
+++ b/sharktank/sharktank/export_layer/export_paged_attention.py
@@ -0,0 +1,416 @@
+# Copyright 2024 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+"""Export support for the PagedLLMV1 protocol of models."""
+
+import json
+import torch
+
+from typing import Optional
+
+import torch.nn.functional as F
+
+from shark_turbine.aot import *
+
+from sharktank.layers import *
+from sharktank.types import *
+
+from sharktank.models.llama.testing import *
+from sharktank.layers import causal_llm
+
+from sharktank.utils.create_cache import *
+
+# TODO: Should be using a base class with the protocol supported.
+from ..models.llama.llama import LlamaModelConfig, PagedLlamaAttentionBlock
+
+
+def paged_attention(
+    attention_block: PagedLlamaAttentionBlock,
+    xq: torch.Tensor,
+    xk: torch.Tensor,
+    xv: torch.Tensor,
+    is_causal: bool,
+    seq_block_ids: torch.Tensor,
+    start_positions: Optional[torch.Tensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    cache_state: list[torch.Tensor] = None,
+    xk_temp: Optional[torch.Tensor] = None,
+    xv_temp: Optional[torch.Tensor] = None,
+):
+
+    bs, batch_seq_len, _, _ = xq.shape
+
+    # Full sequence length.
+    kv_seq_len = seq_block_ids.shape[1] * attention_block.cache.block_seq_stride
+
+    if attention_block.cache.is_paged:
+        xk, xv = attention_block.transact_cache_paged(
+            xk_cache_update=xk,
+            xv_cache_update=xv,
+            seq_block_ids=seq_block_ids,
+            kv_seq_len=kv_seq_len,
+            start_positions=start_positions,
+            cache_state=cache_state,
+            xk_temp=xk_temp,
+            xv_temp=xv_temp,
+        )
+    elif attention_block.cache.is_direct:
+        xk, xv = attention_block.transact_cache_direct(
+            xk_cache_update=xk,
+            xv_cache_update=xv,
+            start_positions=start_positions,
+            kv_seq_len=kv_seq_len,
+            cache_state=cache_state,
+        )
+    else:
+        raise NotImplementedError(f"Unsupported KV cache type: {type(attention_block.cache)}")
+
+    # Expand kv heads for GQA.
+    gqa_n_rep = attention_block.head_count // attention_block.head_count_kv
+    assert gqa_n_rep > 0
+    if gqa_n_rep > 1:
+
+        def repeat_kv(x: torch.Tensor) -> torch.Tensor:
+            bs, slen, n_kv_heads, head_dim = x.shape
+            return (
+                x.unsqueeze(-2)
+                .expand(bs, slen, n_kv_heads, gqa_n_rep, head_dim)
+                .reshape(bs, slen, n_kv_heads * gqa_n_rep, head_dim)
+            )
+
+        xk = repeat_kv(xk)
+        xv = repeat_kv(xv)
+
+    # Transpose into [bs, heads, sl, dim]
+    xq = xq.transpose(1, 2)
+    keys = xk.transpose(1, 2)
+    values = xv.transpose(1, 2)
+    attn_output = F.scaled_dot_product_attention(
+        xq, keys, values, attn_mask=attention_mask, is_causal=is_causal
+    )
+    attn_output = attn_output.transpose(1, 2).reshape(bs, batch_seq_len, -1)
+    return attn_output
+
+
+def run_llama(
+    model: PagedLlamaAttentionBlock,
+    config: LlamaModelConfig,
+    phase: str,
+    xq: torch.Tensor,
+    xk: torch.Tensor,
+    xv: torch.Tensor,
+    # [1, 1, batch_seq_len, batch_seq_len]
+    attention_mask: torch.Tensor,
+    # [bs, batch_seq_len // block_seq_stride]
+    seq_block_ids: torch.Tensor,
+    cache_state: list[torch.Tensor],
+    # [bs] of starting positions
+    start_positions: Optional[torch.Tensor] = None,
+):
+
+    if phase == "decode":
+        bs, _, _, _ = xq.shape
+
+        # Allocate per-block temporary K/V tensors.
These temporaries hold + # one block's K/V state for the maximum context length. + xk_temp = torch.empty( + [ + bs, + config.hp.context_length, + config.hp.attention_head_count_kv, + config.hp.attn_head_dim, + ], + dtype=config.activation_dtype, + device=config.device, + ) + xv_temp = torch.empty( + [ + bs, + config.hp.context_length, + config.hp.attention_head_count_kv, + config.hp.attn_head_dim, + ], + dtype=config.activation_dtype, + device=config.device, + ) + elif phase == "prefill": + xk_temp = None + xv_temp = None + else: + raise ValueError("'phase' argument needs to be either 'prefill' or 'decode'") + + h = paged_attention( + model, + xq=xq, + xk=xk, + xv=xv, + is_causal=config.is_causal, + start_positions=start_positions, + attention_mask=attention_mask, + cache_state=cache_state, + seq_block_ids=seq_block_ids, + xk_temp=xk_temp, + xv_temp=xv_temp, + ) + + return h + + +def main(): + from ..utils import cli + + parser = cli.create_parser() + # cli.add_input_dataset_options(parser) + parser.add_argument( + "--output-mlir", + help="Output file path for exported MLIR file", + default="/home/aramalin/sharktank/artifacts/paged_llama.mlir", + ) + parser.add_argument( + "--output-config", + help="Output file path for exported config file", + default="/home/aramalin/sharktank/artifacts/paged_llama.json", + ) + parser.add_argument( + "--bs", + help="Comma-separated batch size(s) to generate, e.g. `4` or `2,4`", + type=lambda arg: [int(bs) for bs in arg.split(",")], + default="4", + ) + parser.add_argument( + "--verbose", + help="Include verbose logging", + action="store_true", + ) + + parser.add_argument( + "--is-causal", + help="Enable Causal attention", + action="store_true", + ) + + args = cli.parse(parser) + + # dataset = cli.get_input_dataset(args) + # hp = configs.LlamaHParams.from_gguf_props(dataset.properties) + + hp = configs.LlamaHParams( + context_length=4096, + embedding_length=4096, + block_count=1, + feed_forward_length=11008, + attn_head_dim=128, + rope_dimension_count=128, + attention_head_count=32, + attention_layer_norm_rms_epsilon=9.999999747378752e-06, + attention_head_count_kv=32, + model_arch="llama", + ) + + llama_config = LlamaModelConfig(hp) + llama_config.kv_cache_type = "direct" if args.bs == [1] else "paged" + llama_config.bs = args.bs + llama_config.is_causal = args.is_causal + + attention_block_theta = make_attention_block_theta( + feature_dim=llama_config.hp.attention_head_count + * llama_config.hp.attn_head_dim, + ffn_dim=llama_config.hp.feed_forward_length, + dtype=llama_config.attention_dtype, + ) + + causal_model = causal_llm.BaseCausalLMModel( + attention_block_theta, context_length=llama_config.hp.context_length + ) + + model = PagedLlamaAttentionBlock( + theta=attention_block_theta, + block_index=0, + cache=create_kv_cache(llama_config), + head_count=llama_config.hp.attention_head_count, + head_dim=llama_config.hp.attn_head_dim, + head_count_kv=llama_config.hp.attention_head_count_kv, + rms_epsilon=llama_config.hp.attention_layer_norm_rms_epsilon, + ) + + def generate_params_json(hp, prefill_bs: list[int], decode_bs: list[int]): + return { + "module_name": "module", + "module_abi_version": 1, + "max_seq_len": hp.context_length, + "attn_head_count": hp.attention_head_count, + "attn_head_dim": hp.attn_head_dim, + "prefill_batch_sizes": prefill_bs, + "decode_batch_sizes": decode_bs, + "transformer_block_count": hp.block_count, + "block_seq_stride": llama_config.block_seq_stride, + } + + fxb = FxProgramsBuilder(model) + + def generate_batch_prefill(bs: 
int): + tokens = torch.empty(bs, 64, dtype=torch.int64) + seq_lens = torch.empty(bs, dtype=torch.int64) + seq_block_ids = torch.empty(bs, 4, dtype=torch.int64) + block_dim = torch.export.Dim( + "block", max=(hp.context_length - 1) // llama_config.block_seq_stride + ) + sl_dim = llama_config.block_seq_stride * block_dim + + if llama_config.kv_cache_type == "paged": + cache_state = model.cache.allocate( + page_count=hp.context_length // llama_config.block_seq_stride + ) + page_dim = torch.export.Dim("page") + cache_state_dynamic_shapes = [{0: page_dim}] + elif llama_config.kv_cache_type == "direct": + cache_state = model.cache.allocate(bs=1) + # Direct cache dimensions: + # 2 * transformer_block_count of... + # [bs, seq_length, attn_head_count, attn_head_dim] + cache_state_dynamic_shapes = (2 * hp.block_count) * [{}] + else: + raise NotImplementedError(f"Unsupported KV cache type: {type(model.cache)}") + + dynamic_shapes = { + "tokens": {1: sl_dim}, + "seq_lens": {}, + "seq_block_ids": {1: block_dim}, + "cache_state": cache_state_dynamic_shapes, + } + + q = torch.zeros((bs, 64, 32, 128), dtype=torch.float16) + k = torch.zeros((bs, 64, 32, 128), dtype=torch.float16) + v = torch.zeros((bs, 64, 32, 128), dtype=torch.float16) + + print(f"Exporting prefill_bs{bs}") + example_args = (q, k, v, seq_lens, seq_block_ids, cache_state) + + @fxb.export_program( + name=f"prefill_bs{bs}", + args=example_args, + ) + def _(model, q, k, v, seq_lens, seq_block_ids, cache_state): + + if llama_config.is_causal: + attention_mask = None + else: + sl = tokens.shape[1] + input_mask = causal_model.input_mask(seq_lens, sl) + attention_mask = causal_model.attention_mask(input_mask) + + h = run_llama( + model=model, + config=llama_config, + phase="prefill", + xq=q, + xk=k, + xv=v, + attention_mask=attention_mask, + seq_block_ids=seq_block_ids, + cache_state=cache_state, + ) + return h + + def generate_batch_decode(bs: int): + tokens = torch.ones(bs, 1, dtype=torch.int64) + seq_lens = torch.ones(bs, dtype=torch.int64) + start_positions = torch.ones(bs, dtype=torch.int64) + seq_block_ids = torch.zeros(bs, 4, dtype=torch.int64) + block_dim = torch.export.Dim( + "block", max=(hp.context_length - 1) // llama_config.block_seq_stride + ) + + if llama_config.kv_cache_type == "paged": + cache_state = model.cache.allocate( + page_count=hp.context_length // llama_config.block_seq_stride + ) + page_dim = torch.export.Dim("page") + cache_state_dynamic_shapes = [{0: page_dim}] + elif llama_config.kv_cache_type == "direct": + cache_state = model.cache.allocate(bs=1) + # Direct cache dimensions: + # 2 * transformer_block_count of... 
+ # [bs, seq_length, attn_head_count, attn_head_dim] + cache_state_dynamic_shapes = (2 * hp.block_count) * [{}] + else: + raise NotImplementedError(f"Unsupported KV cache type: {type(model.cache)}") + + dynamic_shapes = { + "tokens": {}, + "seq_lens": {}, + "start_positions": {}, + "seq_block_ids": {1: block_dim}, + "cache_state": cache_state_dynamic_shapes, + } + + q = torch.zeros((bs, 1, 32, 128), dtype=torch.float16) + k = torch.zeros((bs, 1, 32, 128), dtype=torch.float16) + v = torch.zeros((bs, 1, 32, 128), dtype=torch.float16) + + print(f"Exporting decode_bs{bs}") + example_args = (q, k, v, seq_lens, start_positions, seq_block_ids, cache_state) + + @fxb.export_program( + name=f"decode_bs{bs}", + args=example_args, + ) + def _( + model, + q, + k, + v, + seq_lens, + start_positions, + seq_block_ids, + cache_state, + ): + + if llama_config.is_causal: + attention_mask = None + else: + input_mask = causal_model.input_mask( + seq_lens, seq_block_ids.shape[1] * model.cache.block_seq_stride + ) + attention_mask = causal_model.decode_attention_mask(input_mask) + + h = run_llama( + model=model, + config=llama_config, + phase="decode", + xq=q, + xk=k, + xv=v, + attention_mask=attention_mask, + start_positions=start_positions, + seq_block_ids=seq_block_ids, + cache_state=cache_state, + ) + + return h + + bsizes = [] + for bs in llama_config.bs: + generate_batch_prefill(bs) + generate_batch_decode(bs) + bsizes.append(bs) + + if args.verbose: + for name, ep in fxb.programs.items(): + print(f"EXPORT {name}:\n{ep}") + + config = generate_params_json(hp, bsizes, bsizes) + print("GENERATED!") + + print("Exporting") + output = export(fxb) + print(f"Saving to '{args.output_mlir}'") + output.save_mlir(args.output_mlir) + json.dump(config, open(args.output_config, "w")) + + +if __name__ == "__main__": + main() diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py index 4cd21ae8e..309fe322a 100644 --- a/sharktank/sharktank/layers/kv_cache.py +++ b/sharktank/sharktank/layers/kv_cache.py @@ -363,21 +363,33 @@ def write_timestep( page_table = self.unflatten_page_table(state) # 6D bs, *_ = seq_positions.shape assert len(cache_partitions) == self.cache_partition_count - for i in range(bs): - position = seq_positions[i] - # TODO: Let's clamp to the allowable range so that we don't need - # an assert. 
-            page_id = page_ids[i, :].index_select(0, position // self.block_seq_stride)
-            page_offset = position % self.block_seq_stride
-            for partition_index in range(self.cache_partition_count):
-                cache_partition = cache_partitions[partition_index]
-                indices = (
-                    page_id,
-                    torch.tensor([transformer_block_index], device=device),
-                    torch.tensor([partition_index], device=device),
-                    page_offset.unsqueeze(0),
-                )
-                page_table.index_put_(indices=indices, values=cache_partition[i, 0])
+
+        partition_count = len(cache_partitions)
+
+        # [bs, partitions, attn_head_count, attn_head_dim]
+        cache_partitions = ops.cat(cache_partitions, dim=1)
+
+        # [bs]
+        page_index = seq_positions // self.block_seq_stride
+
+        page_id = ops.gather(page_ids, dim=1, index=page_index.unsqueeze(1))
+        page_offset = (seq_positions % self.block_seq_stride).unsqueeze(1)
+
+        # [1, partitions]
+        partitions = torch.arange(0, self.cache_partition_count, device=device).unsqueeze(0)
+
+        # [bs, partitions]
+        page_id = page_id.repeat(1, partition_count)
+        transformer_block = torch.full(
+            (bs, partition_count), transformer_block_index, device=device
+        )
+        page_offset = page_offset.repeat(1, partition_count)
+        partitions = partitions.repeat(bs, 1)
+
+        indices = (page_id, transformer_block, partitions, page_offset)
+        page_table.index_put_(indices=indices, values=cache_partitions)
+
+        return

     def write(
         self,
@@ -418,23 +430,18 @@ def write(
             transformer_block_index * transformer_block_stride
         )

-        def write_cache_partition(
-            index: int, part: Union[torch.Tensor, SplitPrimitiveTensor]
-        ):
-            part_block_view = part.reshape(blocked_shape)
+        part_block_views = []
+        subblock_ids_kv = []
+        for index, partition in enumerate(cache_partitions):
+            part_block_view = partition.reshape(blocked_shape).flatten(0, 1)
+            part_block_views.append(part_block_view)
+
             subblock_ids = (
                 (base_subblock_ids + index) if index > 0 else base_subblock_ids
-            )
-            # TODO: Potentially clamp all page 0 indices to the mask value.
-            # Or even better, require that the ids are replicated such that access is
-            # legal.
-            # Now for each of the k/v attn_block_ids, which have been adjusted to
-            # index into the sub-pages, we flatten to do a linear index_select
-            # copy of the sub-blocks by collapsing the first two dims so we have
-            # a linear list.
- subblock_table.index_copy_( - 0, subblock_ids.flatten(0, 1), part_block_view.flatten(0, 1) - ) + ).flatten(0, 1) + subblock_ids_kv.append(subblock_ids) - for index, partition in enumerate(cache_partitions): - write_cache_partition(index, partition) + subblock_ids = ops.cat(subblock_ids_kv) + part_block_view = ops.cat(part_block_views, dim=0) + + subblock_table.index_copy_(0, subblock_ids, part_block_view) diff --git a/sharktank/sharktank/ops/default_impls.py b/sharktank/sharktank/ops/default_impls.py index a2fcd2813..ed6e6c730 100644 --- a/sharktank/sharktank/ops/default_impls.py +++ b/sharktank/sharktank/ops/default_impls.py @@ -157,6 +157,15 @@ def flatten_default( return torch.flatten(unbox_tensor(input), start_dim, end_dim) +@gather.override(Tensor, Tensor) +def gather_default( + input: Union[Tensor, PrimitiveTensor], + dim: int, + index: Union[Tensor, PrimitiveTensor], +) -> Tensor: + return torch.gather(unbox_tensor(input), dim, unbox_tensor(index)) + + @get_index.override(AllOfType(Tensor, PrimitiveTensor)) def get_index_default(tensor, key): return unbox_tensor(tensor).__get_item__(key) @@ -333,6 +342,11 @@ def module_register_buffer_default( return module.register_buffer(name, unbox_tensor(tensor)) +@repeat.override(Tensor) +def repeat_default(input: Union[Tensor, PrimitiveTensor], *sizes: List[int]) -> Tensor: + return unbox_tensor(input).repeat(*sizes) + + @reshape.override(Tensor) def reshape_default(input: Union[PrimitiveTensor, Tensor], shape: List[int]) -> Tensor: return torch.reshape(unbox_tensor(input), shape) diff --git a/sharktank/sharktank/ops/sharded_impls.py b/sharktank/sharktank/ops/sharded_impls.py index 69cce73b8..a667669f4 100644 --- a/sharktank/sharktank/ops/sharded_impls.py +++ b/sharktank/sharktank/ops/sharded_impls.py @@ -61,8 +61,18 @@ def all_reduce_split_or_unreduced( return ReplicatedTensor(ts=shards) +@cat.override(AllOfType(ReplicatedTensor)) +def cat_replicated(tensors: Sequence[ReplicatedTensor], dim: int) -> ReplicatedTensor: + assert len(tensors) > 0 + shard_count = tensors[0].shard_count + assert all([t.shard_count == shard_count for t in tensors]) + + shards = [cat(shards, dim) for shards in zip(*[t.shards for t in tensors])] + return ReplicatedTensor(ts=shards) + + @cat.override(AllOfType(SplitPrimitiveTensor)) -def cat_sharded( +def cat_split( tensors: Sequence[SplitPrimitiveTensor], dim: int ) -> SplitPrimitiveTensor: assert len(tensors) > 0 @@ -456,6 +466,18 @@ def flatten_split( return SplitPrimitiveTensor(ts=shards, shard_dim=shard_dim) +@gather.override(ReplicatedTensor, ReplicatedTensor) +def gather_replicated( + input: ReplicatedTensor, dim: int, index: ReplicatedTensor +) -> Tensor: + assert input.shard_count == index.shard_count + shards = [ + gather(input_shard, dim, index_shard) + for input_shard, index_shard in zip(input.shards, index.shards) + ] + return ReplicatedTensor(ts=shards) + + @group_norm_affine.override( SplitPrimitiveTensor, SplitPrimitiveTensor, SplitPrimitiveTensor ) @@ -802,6 +824,12 @@ def permute_replicated(tensor: ReplicatedTensor, dims: List[int]): return ReplicatedTensor(ts=permuted_shards) +@repeat.override(ReplicatedTensor) +def repeat_replicated(input: ReplicatedTensor, *sizes: List[int]) -> ReplicatedTensor: + shards = [repeat(shard, *sizes) for shard in input.shards] + return ReplicatedTensor(ts=shards) + + @replicate.override(ReplicatedTensor) def replicate_replicated(input: ReplicatedTensor, *, count: int) -> ReplicatedTensor: if input.shard_count != count: diff --git 
a/sharktank/sharktank/ops/signatures.py b/sharktank/sharktank/ops/signatures.py index 59f7672c7..89d4309ee 100644 --- a/sharktank/sharktank/ops/signatures.py +++ b/sharktank/sharktank/ops/signatures.py @@ -27,6 +27,7 @@ "equal", "expand", "flatten", + "gather", "get_index", "gemm", "group_norm_affine", @@ -41,6 +42,7 @@ "module_register_buffer", "permute", "rms_norm", + "repeat", "replicate", "reshape", "reshard", @@ -348,6 +350,28 @@ def _flatten_trampoline( d.fail(dispatch_args) +@overridable +def gather(input: AnyTensor, dim: int, index: AnyTensor) -> AnyTensor: + """See torch.gather""" + ... + + +@gather.trampoline +def _gather_trampoline( + d: SignatureDispatcher, input: AnyTensor, dim: int, index: AnyTensor +) -> AnyTensor: + dispatch_args = ( + input, + index, + ) + for override in d.find_overrides(dispatch_args): + result = override(input, dim, index) + if result is not NotImplemented: + return override, result + else: + d.fail(dispatch_args) + + @overridable def gemm( a: AnyTensor, @@ -718,6 +742,25 @@ def _rms_norm_trampoline( d.fail(tensors) +@overridable +def repeat(input: AnyTensor, *sizes: List[int]) -> AnyTensor: + """See torch.Tensor.repeat""" + ... + + +@repeat.trampoline +def _repeat_trampoline( + d: SignatureDispatcher, input: AnyTensor, *sizes: List[int] +) -> AnyTensor: + dispatch_args = (input,) + for override in d.find_overrides(dispatch_args): + result = override(input, *sizes) + if result is not NotImplemented: + return override, result + else: + d.fail(dispatch_args) + + @overridable def replicate(input: AnyTensor, count: int) -> ShardedTensor: """Replicate across devices. diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py index 534538f7e..93aac9e34 100644 --- a/sharktank/sharktank/types/tensors.py +++ b/sharktank/sharktank/types/tensors.py @@ -343,6 +343,11 @@ def pow(self, exponent: Union["AnyTensor", Number]) -> "AnyTensor": return elementwise(torch.pow, self, exponent) + def repeat(self, *sizes: List[int]) -> "AnyTensor": + from ..ops import repeat + + return repeat(self, *sizes) + def reshape(self, *args: Union[List[List[int]], List[int]]) -> "AnyTensor": from ..ops import reshape diff --git a/sharktank/tests/layers/sharded_paged_kv_cache_test.py b/sharktank/tests/layers/sharded_paged_kv_cache_test.py index 621f3ac13..d58874f25 100644 --- a/sharktank/tests/layers/sharded_paged_kv_cache_test.py +++ b/sharktank/tests/layers/sharded_paged_kv_cache_test.py @@ -151,7 +151,7 @@ def testWriteTimestep(self): cache_partitions = [ torch.rand( self.batch_size, - self.block_seq_len * self.block_seq_stride, + 1, self.attn_head_count, self.attn_head_dim, ) diff --git a/sharktank/tests/models/llama/sharded_llama_test.py b/sharktank/tests/models/llama/sharded_llama_test.py index 6d6769731..4638df312 100644 --- a/sharktank/tests/models/llama/sharded_llama_test.py +++ b/sharktank/tests/models/llama/sharded_llama_test.py @@ -154,7 +154,7 @@ def setUp(self): vocab_size=self.vocabulary_size, ) self.prefill_seq_lens = torch.tensor( - [14, 9, self.block_seq_stride - 1], dtype=torch.int32 + [14, 9, self.block_seq_stride - 1], dtype=torch.int64 ) def make_prefill_args(self, model: PagedLlamaModelV1) -> OrderedDict[str, Any]:
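Note on the write_timestep change in kv_cache.py above: the per-sequence Python loop is replaced by a single index_put_ whose index tensors are all broadcast to [bs, partitions]. Below is a minimal, self-contained sketch of that indexing in plain PyTorch with made-up toy sizes and an assumed [page, transformer_block, partition, block_pos, head, dim] page-table layout; the names and dimensions here are illustrative assumptions, not the sharktank API.

import torch

# Toy sizes (assumptions for illustration only).
bs, block_seq_stride, partition_count = 2, 4, 2
attn_head_count, attn_head_dim = 8, 16
page_count, transformer_block_count = 8, 1
transformer_block_index = 0

# Page table assumed to be laid out as
# [page, transformer_block, partition, block_pos, head, dim].
page_table = torch.zeros(
    page_count,
    transformer_block_count,
    partition_count,
    block_seq_stride,
    attn_head_count,
    attn_head_dim,
)

# One new K and one new V timestep per sequence: [bs, 1, heads, dim] each.
cache_partitions = [
    torch.rand(bs, 1, attn_head_count, attn_head_dim) for _ in range(partition_count)
]
seq_positions = torch.tensor([5, 2])       # current write position per sequence
page_ids = torch.tensor([[0, 1], [2, 3]])  # pages assigned to each sequence

# Concatenate K/V along dim=1 so a single index_put_ writes both partitions.
values = torch.cat(cache_partitions, dim=1)  # [bs, partitions, heads, dim]

# Which page and which offset inside that page each sequence writes to.
page_index = seq_positions // block_seq_stride                          # [bs]
page_id = torch.gather(page_ids, dim=1, index=page_index.unsqueeze(1))  # [bs, 1]
page_offset = (seq_positions % block_seq_stride).unsqueeze(1)           # [bs, 1]

# Broadcast every index tensor to [bs, partitions] so the indices zip elementwise.
page_id = page_id.repeat(1, partition_count)
page_offset = page_offset.repeat(1, partition_count)
partitions = torch.arange(partition_count).unsqueeze(0).repeat(bs, 1)
transformer_block = torch.full((bs, partition_count), transformer_block_index)

page_table.index_put_(
    indices=(page_id, transformer_block, partitions, page_offset), values=values
)

# Spot-check: sequence 0 (position 5) wrote its K partition at page 1, offset 1.
assert torch.equal(page_table[1, 0, 0, 1], cache_partitions[0][0, 0])

Running the sketch and checking the assertion shows each sequence's new K/V timestep landing in its own page at the expected block offset, which is the same gather/repeat/index_put_ plumbing the patch introduces, expressed without a per-batch loop so it stays traceable through export.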