[sharktank] Fix write_timestep for direct cache (nod-ai#354)
The batch index needs to be materialized as an explicit dimension so that
the timestep write has a batch dimension to write along.
rsuderman authored Oct 29, 2024
1 parent fb4be27 commit 8c9d454
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions sharktank/sharktank/layers/kv_cache.py
@@ -191,8 +191,8 @@ def write_timestep(
         update_count = len(cache_partitions)

         for b in range(bs):
-            row_index = torch.tensor(b, dtype=torch.int64)
-            row_start_pos = seq_positions[row_index]
+            row_index = torch.tensor([b], dtype=torch.int64)
+            row_start_pos = seq_positions[row_index].unsqueeze(0)

             for i, update in enumerate(cache_partitions):
                 cache = state[transformer_block_index * update_count + i]
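
The effect of the two changed lines on tensor shapes can be illustrated with a minimal sketch. The helper names `timestep_indices_before` and `timestep_indices_after` and the sample `seq_positions` values are hypothetical and exist only for this illustration; the rest of `write_timestep` is not reproduced here.

```python
import torch


def timestep_indices_before(seq_positions: torch.Tensor, b: int):
    # Old form: a 0-d index tensor, so indexing yields a 0-d position.
    row_index = torch.tensor(b, dtype=torch.int64)
    row_start_pos = seq_positions[row_index]
    return row_index, row_start_pos


def timestep_indices_after(seq_positions: torch.Tensor, b: int):
    # New form: a 1-d index tensor of length 1 keeps the batch dimension,
    # and unsqueeze(0) adds a leading dimension to the position.
    row_index = torch.tensor([b], dtype=torch.int64)
    row_start_pos = seq_positions[row_index].unsqueeze(0)
    return row_index, row_start_pos


# Hypothetical per-row start positions for a batch of three sequences.
seq_positions = torch.tensor([5, 9, 2], dtype=torch.int64)

old_index, old_pos = timestep_indices_before(seq_positions, 1)
new_index, new_pos = timestep_indices_after(seq_positions, 1)

print(tuple(old_index.shape), tuple(old_pos.shape))  # () ()
print(tuple(new_index.shape), tuple(new_pos.shape))  # (1,) (1, 1)
```

Both forms select the same position value (9 for row 1 in this sample); only the rank of the index and position tensors differs, which is what the commit message refers to as materializing the batch index dimension.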
