fix direct args (#330)
dan-garvey authored Oct 25, 2024
1 parent 8993725 commit d3778ed
Showing 1 changed file with 3 additions and 3 deletions.
sharktank/sharktank/examples/export_paged_llm_v1.py (6 changes: 3 additions & 3 deletions)
@@ -127,7 +127,7 @@ def setup_cache(model, shard_count):
# Direct cache dimensions:
# 2 * transformer_block_count of...
# [bs, seq_length, attn_head_count, attn_head_dim]
- dynamic_shapes = (2 * hp.block_count) * [{}]
+ dynamic_shapes = [None]
else:
raise NotImplementedError(f"Unsupported KV cache type: {type(model.cache)}")

@@ -148,7 +148,7 @@ def setup_cache(model, shard_count):
for i in range(llama_config.tensor_parallelism_size):
arg_affinities[i] = DeviceAffinity(str(i))

- return unpacked, shard_dim, dynamic_shapes, arg_affinities
+ return torch.stack(unpacked), shard_dim, dynamic_shapes, arg_affinities

def repack_cache(cache, shard_dim):
return [SplitPrimitiveTensor(ts=c, shard_dim=shard_dim) for c in cache]
@@ -189,7 +189,7 @@ def generate_batch_prefill(bs: int):
arg_device=arg_affinities,
)
def _(model, tokens, seq_lens, seq_block_ids, cs):
- cache_tensors = cs
+ cache_tensors = torch.unbind(cs)

sl = tokens.shape[1]
input_mask = model.input_mask(seq_lens, sl)
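
The three changes apply one pattern: the per-layer direct KV cache tensors returned by setup_cache are packed into a single tensor with torch.stack so the cache can be passed as one direct argument to the exported prefill function, which then splits that argument back into per-layer tensors with torch.unbind. A minimal sketch of that round trip follows, with illustrative block counts and tensor shapes (not the actual sharktank export code):

import torch

# Illustrative sizes only; the real values come from the model hyperparameters.
transformer_block_count = 4
bs, seq_length, attn_head_count, attn_head_dim = 2, 16, 8, 64

# Direct cache: one tensor per K and per V for each transformer block,
# each shaped [bs, seq_length, attn_head_count, attn_head_dim].
unpacked = [
    torch.zeros(bs, seq_length, attn_head_count, attn_head_dim)
    for _ in range(2 * transformer_block_count)
]

# At export time the list is packed into a single tensor so it can travel
# as one argument (what the new torch.stack(unpacked) in the diff does).
packed = torch.stack(unpacked)

# Inside the exported function the argument is split back into per-block
# tensors (what the new torch.unbind(cs) in the diff does).
cache_tensors = torch.unbind(packed)

assert len(cache_tensors) == 2 * transformer_block_count
assert torch.equal(cache_tensors[0], unpacked[0])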
