diff --git a/sharktank/sharktank/export_layer/export_paged_attention.py b/sharktank/sharktank/export_layer/export_paged_attention.py
index f87fae774..edd20ecf3 100644
--- a/sharktank/sharktank/export_layer/export_paged_attention.py
+++ b/sharktank/sharktank/export_layer/export_paged_attention.py
@@ -241,7 +241,7 @@ def main():
         head_dim=llama_config.hp.attn_head_dim,
         head_count_kv=llama_config.hp.attention_head_count_kv,
         rms_epsilon=llama_config.hp.attention_layer_norm_rms_epsilon,
-        attention_kernel=args.attention_kernel
+        attention_kernel=args.attention_kernel,
     )
 
     def generate_params_json(hp, prefill_bs: list[int], decode_bs: list[int]):
diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py
index 774504501..ee767cf3e 100644
--- a/sharktank/sharktank/layers/paged_llama_attention_block.py
+++ b/sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -113,12 +113,13 @@ def forward(
         # Full sequence length.
         kv_seq_len = seq_block_ids.shape[1] * self.cache.block_seq_stride
 
-        print("all",
-            start_positions,
-            cache_state,
-            xk_temp,
-            xv_temp,
-        )
+        print(
+            "all",
+            start_positions,
+            cache_state,
+            xk_temp,
+            xv_temp,
+        )
         if self.cache.is_paged:
             xk, xv = self.transact_cache_paged(
                 xk_cache_update=xk,
diff --git a/sharktank/tests/layers/paged_llama_attention_block_test.py b/sharktank/tests/layers/paged_llama_attention_block_test.py
index 7e859be90..6e737eccc 100644
--- a/sharktank/tests/layers/paged_llama_attention_block_test.py
+++ b/sharktank/tests/layers/paged_llama_attention_block_test.py
@@ -13,10 +13,15 @@
 import torch
 
 from iree.turbine import aot
-from sharktank.layers import PagedLlamaAttentionBlock, PagedKVCache, RotaryEmbeddingLayer
+from sharktank.layers import (
+    PagedLlamaAttentionBlock,
+    PagedKVCache,
+    RotaryEmbeddingLayer,
+)
 from sharktank.layers.testing import make_llama_attention_block_theta
 from sharktank.types.tensors import DefaultPrimitiveTensor
 
+
 class PagedLlamaAttentionBlockTest(unittest.TestCase):
     def setUp(self):
         torch.manual_seed(12345)
@@ -38,7 +43,6 @@ def setUp(self):
         self.batch_size = 3
         self.start_index = 0
 
-
     def testExportDecomposed(self):
         dtype = torch.float32
 
@@ -74,22 +78,32 @@ def testExportDecomposed(self):
         seq_block_ids = torch.arange(self.batch_size * self.block_seqlen).view(
             self.batch_size, -1
         )
-        
+
         embedding_module = RotaryEmbeddingLayer(
             rope_dimension_count=self.rope_dimension_count,
             max_seqlen=self.max_seqlen,
             rope_freq_base=self.rope_freq_base,
         )
-        
+
         class MyModule(torch.nn.Module):
             def forward(self, h, seq_block_ids, cache_state):
-                return attn.forward(h, seq_block_ids=seq_block_ids, embedding=embedding_module, start_index=0, cache_state=cache_state)
-        
+                return attn.forward(
+                    h,
+                    seq_block_ids=seq_block_ids,
+                    embedding=embedding_module,
+                    start_index=0,
+                    cache_state=cache_state,
+                )
+
         mod = MyModule()
-        h = torch.rand([self.batch_size, self.max_seqlen, self.attention_head_count * self.attention_head_dim])
-        mod.forward(h,
-            seq_block_ids,
-            cache_state)
+        h = torch.rand(
+            [
+                self.batch_size,
+                self.max_seqlen,
+                self.attention_head_count * self.attention_head_dim,
+            ]
+        )
+        mod.forward(h, seq_block_ids, cache_state)
         ep = torch.export.export(
             mod,
             args=(
@@ -104,8 +118,7 @@ def forward(self, h, seq_block_ids, cache_state):
         output.save_mlir("temp.mlir")
         self.assertNotIn("scaled_dot_product_attention", asm)
 
-
-    def testExportDecomposed(self):
+    def testExportNondecomposed(self):
         dtype = torch.float32
 
         cache = PagedKVCache(
@@ -140,22 +153,32 @@ def testExportDecomposed(self):
         seq_block_ids = torch.arange(self.batch_size * self.block_seqlen).view(
             self.batch_size, -1
         )
-        
+
         embedding_module = RotaryEmbeddingLayer(
             rope_dimension_count=self.rope_dimension_count,
             max_seqlen=self.max_seqlen,
             rope_freq_base=self.rope_freq_base,
         )
-        
+
         class MyModule(torch.nn.Module):
             def forward(self, h, seq_block_ids, cache_state):
-                return attn.forward(h, seq_block_ids=seq_block_ids, embedding=embedding_module, start_index=0, cache_state=cache_state)
-        
+                return attn.forward(
+                    h,
+                    seq_block_ids=seq_block_ids,
+                    embedding=embedding_module,
+                    start_index=0,
+                    cache_state=cache_state,
+                )
+
         mod = MyModule()
-        h = torch.rand([self.batch_size, self.max_seqlen, self.attention_head_count * self.attention_head_dim])
-        mod.forward(h,
-            seq_block_ids,
-            cache_state)
+        h = torch.rand(
+            [
+                self.batch_size,
+                self.max_seqlen,
+                self.attention_head_count * self.attention_head_dim,
+            ]
+        )
+        mod.forward(h, seq_block_ids, cache_state)
         ep = torch.export.export(
             mod,
             args=(