
Fix difference of LLM export for the direct vs paged cache #347

Merged
sogartar merged 3 commits into nod-ai:main from fix-cache-in-sharded-llama-export on Oct 28, 2024

Conversation

sogartar (Contributor)

Before work on unifying the cache interfaces, there are some differences between the sharded, direct, and paged caches.
The direct cache uses a list of tensors, one per transformer block, while the paged cache has one slab and the sharded paged cache expects a list of shards.
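
For context, here is a minimal sketch of the three cache-state layouts described above. The shapes, sizes, and variable names are made up for illustration; the real layouts come from sharktank's cache classes and the model config.

import torch

# Hypothetical sizes, for illustration only.
num_blocks = 26              # transformer blocks
page_count, page_size = 128, 2048
shard_count = 2

# Direct cache: a list with one tensor per transformer block.
direct_state = [torch.empty(1, 2048, 32, 100) for _ in range(num_blocks)]

# Paged cache: a single flat slab shared by all blocks and pages.
paged_state = [torch.empty(page_count, page_size)]

# Sharded paged cache: a list of slab shards, one per device.
sharded_paged_state = [
    torch.empty(page_count, page_size // shard_count) for _ in range(shard_count)
]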


renxida commented Oct 28, 2024

Thanks for doing this!

Is this ready to merge? I would love to have it in main ASAP; I'm blocked by this and currently using some hacky workarounds.

(Please make sure this works for
bs=1
bs=1,4
bs=4
)


renxida commented Oct 28, 2024

Right now when I try to run this I get:

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/xidaren2/SHARK-Platform/sharktank/sharktank/examples/export_paged_llm_v1.py", line 325, in <module>
    main()
  File "/home/xidaren2/SHARK-Platform/sharktank/sharktank/examples/export_paged_llm_v1.py", line 307, in main
    generate_batch_prefill(bs)
  File "/home/xidaren2/SHARK-Platform/sharktank/sharktank/examples/export_paged_llm_v1.py", line 163, in generate_batch_prefill
    cache, cache_shard_dim, cache_dynamic_shapes, arg_affinities = setup_cache(
                                                                   ^^^^^^^^^^^^
  File "/home/xidaren2/SHARK-Platform/sharktank/sharktank/examples/export_paged_llm_v1.py", line 147, in setup_cache
    return torch.stack(cache_state), shard_dim, dynamic_shapes, arg_affinities
                                     ^^^^^^^^^
UnboundLocalError: cannot access local variable 'shard_dim' where it is not associated with a value
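
For reference, this traceback matches the usual pattern where a variable is only assigned on one branch. A hypothetical reduction of the failure mode (not the actual setup_cache code):

import torch

def setup_cache_sketch(shard_count: int):
    # Hypothetical reduction: shard_dim is only bound on the sharded branch,
    # so the direct-cache path reaches the return with the name unassigned.
    cache_state = [torch.empty(4, 8), torch.empty(4, 8)]
    if shard_count > 1:
        shard_dim = 1
    # Fix: assign shard_dim (e.g. None) on every path before returning.
    return torch.stack(cache_state), shard_dim

setup_cache_sketch(shard_count=1)  # raises UnboundLocalError on shard_dim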


renxida commented Oct 28, 2024

We might need to test this manually because this file isn't exercised by the CI.

sogartar (Contributor, Author)

@renxida thank you for catching that. No matter how small the change, I can always make a mistake. After the fix I also tested the direct cache path.

Before work on unifying the cache interfaces, there are some differences
between the sharded, direct, and paged caches.
The direct cache uses a list of tensors, one per transformer block,
while the paged cache has one slab and the sharded paged cache expects a list of shards.
sogartar force-pushed the fix-cache-in-sharded-llama-export branch from e982607 to 62a037d on October 28, 2024 at 18:07
sogartar merged commit 98392d0 into nod-ai:main on Oct 28, 2024 (3 checks passed)

renxida commented Oct 28, 2024

Ack, export works but now compile doesn't:

Saving to '/home/xidaren2/xshortfin/goldens/exported_llama_model/model.mlir'

  • iree-compile /home/xidaren2/xshortfin/goldens/exported_llama_model/model.mlir --iree-hal-target-backends=rocm --iree-hip-target=gfx1100 -o /home/xidaren2/xshortfin/goldens/exported_llama_model/model.vmfb
    /home/xidaren2/xshortfin/goldens/exported_llama_model/model.mlir:12158:12: error: 'tm_tensor.scatter' op mismatch in shape of indices and update value at dim#0
    %357 = torch.aten.index_put %348, %356, %355, %false_114 : !torch.vtensor<[1,2048,32,100],f16>, !torch.list<optional>, !torch.vtensor<[32,100],f16>, !torch.bool -> !torch.vtensor<[1,2048,32,100],f16>
    ^
    /home/xidaren2/xshortfin/goldens/exported_llama_model/model.mlir:12158:12: note: see current operation:
    %677 = "tm_tensor.scatter"(%675, %676, %674) <{dimension_map = array<i64: 0, 1>, operandSegmentSizes = array<i32: 2, 1>, unique_indices = false}> ({
    ^bb0(%arg107: f16, %arg108: f16):
    "tm_tensor.yield"(%arg107) : (f16) -> ()
    }) : (tensor<32x1x1x32x100xf16>, tensor<1x2xi32>, tensor<1x2048x32x100xf16>) -> tensor<1x2048x32x100xf16>
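
For reference, a minimal eager-PyTorch sketch of the shapes in the failing op: the [1,2048,32,100] cache and the [32,100] update come from the error message, while the index values are made up. Eager mode accepts this write, so presumably the mismatch only shows up in the tm_tensor.scatter lowering.

import torch

# Shapes lifted from the error message; the index values are hypothetical.
cache = torch.zeros(1, 2048, 32, 100, dtype=torch.float16)
update = torch.ones(32, 100, dtype=torch.float16)

batch_idx = torch.tensor([0])
seq_idx = torch.tensor([5])

# Eager equivalent of the torch.aten.index_put above: write a [32, 100]
# update into one (batch, position) slot of the [1, 2048, 32, 100] cache.
cache.index_put_((batch_idx, seq_idx), update)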

dan-garvey (Member)

@renxida this is the issue I'm working on
