
Commit

Run pre-commit
KyleHerndon committed Oct 20, 2024
1 parent d107b30 commit 9e0deb1
Showing 3 changed files with 51 additions and 27 deletions.
2 changes: 1 addition & 1 deletion sharktank/sharktank/export_layer/export_paged_attention.py
@@ -241,7 +241,7 @@ def main():
         head_dim=llama_config.hp.attn_head_dim,
         head_count_kv=llama_config.hp.attention_head_count_kv,
         rms_epsilon=llama_config.hp.attention_layer_norm_rms_epsilon,
-        attention_kernel=args.attention_kernel
+        attention_kernel=args.attention_kernel,
     )
 
     def generate_params_json(hp, prefill_bs: list[int], decode_bs: list[int]):
13 changes: 7 additions & 6 deletions sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -113,12 +113,13 @@ def forward(
         # Full sequence length.
         kv_seq_len = seq_block_ids.shape[1] * self.cache.block_seq_stride
 
-        print("all",
-            start_positions,
-            cache_state,
-            xk_temp,
-            xv_temp,
-        )
+        print(
+            "all",
+            start_positions,
+            cache_state,
+            xk_temp,
+            xv_temp,
+        )
         if self.cache.is_paged:
             xk, xv = self.transact_cache_paged(
                 xk_cache_update=xk,
63 changes: 43 additions & 20 deletions sharktank/tests/layers/paged_llama_attention_block_test.py
@@ -13,10 +13,15 @@
 import torch
 
 from iree.turbine import aot
-from sharktank.layers import PagedLlamaAttentionBlock, PagedKVCache, RotaryEmbeddingLayer
+from sharktank.layers import (
+    PagedLlamaAttentionBlock,
+    PagedKVCache,
+    RotaryEmbeddingLayer,
+)
 from sharktank.layers.testing import make_llama_attention_block_theta
 from sharktank.types.tensors import DefaultPrimitiveTensor
 
+
 class PagedLlamaAttentionBlockTest(unittest.TestCase):
     def setUp(self):
         torch.manual_seed(12345)
@@ -38,7 +43,6 @@ def setUp(self):
         self.batch_size = 3
         self.start_index = 0
 
-
     def testExportDecomposed(self):
         dtype = torch.float32
 
@@ -74,22 +78,32 @@
         seq_block_ids = torch.arange(self.batch_size * self.block_seqlen).view(
             self.batch_size, -1
         )
 
         embedding_module = RotaryEmbeddingLayer(
             rope_dimension_count=self.rope_dimension_count,
             max_seqlen=self.max_seqlen,
             rope_freq_base=self.rope_freq_base,
         )
 
         class MyModule(torch.nn.Module):
             def forward(self, h, seq_block_ids, cache_state):
-                return attn.forward(h, seq_block_ids=seq_block_ids, embedding=embedding_module, start_index=0, cache_state=cache_state)
+                return attn.forward(
+                    h,
+                    seq_block_ids=seq_block_ids,
+                    embedding=embedding_module,
+                    start_index=0,
+                    cache_state=cache_state,
+                )
 
         mod = MyModule()
-        h = torch.rand([self.batch_size, self.max_seqlen, self.attention_head_count * self.attention_head_dim])
-        mod.forward(h,
-            seq_block_ids,
-            cache_state)
+        h = torch.rand(
+            [
+                self.batch_size,
+                self.max_seqlen,
+                self.attention_head_count * self.attention_head_dim,
+            ]
+        )
+        mod.forward(h, seq_block_ids, cache_state)
         ep = torch.export.export(
             mod,
             args=(
@@ -104,8 +118,7 @@ def forward(self, h, seq_block_ids, cache_state):
         output.save_mlir("temp.mlir")
         self.assertNotIn("scaled_dot_product_attention", asm)
 
-
-    def testExportDecomposed(self):
+    def testExportNondecomposed(self):
         dtype = torch.float32
 
         cache = PagedKVCache(
@@ -140,22 +153,32 @@ def testExportDecomposed(self):
         seq_block_ids = torch.arange(self.batch_size * self.block_seqlen).view(
             self.batch_size, -1
         )
 
         embedding_module = RotaryEmbeddingLayer(
             rope_dimension_count=self.rope_dimension_count,
             max_seqlen=self.max_seqlen,
             rope_freq_base=self.rope_freq_base,
         )
 
         class MyModule(torch.nn.Module):
             def forward(self, h, seq_block_ids, cache_state):
-                return attn.forward(h, seq_block_ids=seq_block_ids, embedding=embedding_module, start_index=0, cache_state=cache_state)
+                return attn.forward(
+                    h,
+                    seq_block_ids=seq_block_ids,
+                    embedding=embedding_module,
+                    start_index=0,
+                    cache_state=cache_state,
+                )
 
         mod = MyModule()
-        h = torch.rand([self.batch_size, self.max_seqlen, self.attention_head_count * self.attention_head_dim])
-        mod.forward(h,
-            seq_block_ids,
-            cache_state)
+        h = torch.rand(
+            [
+                self.batch_size,
+                self.max_seqlen,
+                self.attention_head_count * self.attention_head_dim,
+            ]
+        )
+        mod.forward(h, seq_block_ids, cache_state)
         ep = torch.export.export(
             mod,
             args=(
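Note on the test file: beyond the mechanical formatting changes, the diff also renames the second testExportDecomposed to testExportNondecomposed. Before this commit the class defined two methods with the same name, and in Python the later definition silently replaces the earlier one in the class namespace, so only one of the two tests could ever be collected and run. The snippet below is a minimal, hypothetical sketch (not taken from the repository) illustrating that shadowing behavior:

import unittest


class Example(unittest.TestCase):
    # Hypothetical class, only to show Python's method-name shadowing: the
    # second definition of test_feature replaces the first one on the class.
    def test_feature(self):
        raise AssertionError("never runs: shadowed by the method below")

    def test_feature(self):  # noqa: F811 (intentional redefinition)
        self.assertTrue(True)


if __name__ == "__main__":
    # Exactly one test is discovered and it passes, because the first
    # test_feature no longer exists on the class.
    unittest.main()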
