
[draft] use flash_attention from cuda-free #12

Draft

wants to merge 1 commit into main
Conversation

@sijiac (Contributor) commented Sep 24, 2024

By switching to the attention implementation from the cuda-free repo, the Triton attention path now works correctly in the kernels repo.

The attention kernel currently in the kernels repo is missing support for the decoding case, where the length of Q and the length of K differ within the same batch.

python3 -m main llama_chat_completion --profile=False --benchmark=False --ckpt_dir="/home/sijiac/models/Meta-Llama-3-8B-Instruct/" --tokenizer_path="/home/sijiac/models/Meta-Llama-3-8B-Instruct/tokenizer.model" --use_triton=True
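To illustrate the decode case above, here is a minimal sketch using PyTorch's `scaled_dot_product_attention` as a stand-in reference (not the Triton kernel from this PR): during decoding a single new query token attends over the full KV cache, so the Q length and K length differ. The shapes and tensor names are illustrative, and it assumes a CUDA device is available.

```python
import torch
import torch.nn.functional as F

batch, heads, head_dim = 2, 8, 128
kv_len = 1024   # tokens already in the KV cache
q_len = 1       # one new token per decoding step

q = torch.randn(batch, heads, q_len, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(batch, heads, kv_len, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(batch, heads, kv_len, head_dim, device="cuda", dtype=torch.float16)

# q_len != kv_len here; a kernel that assumes equal Q/K lengths cannot serve
# this path, which is why the attention from the cuda-free repo is swapped in.
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([2, 8, 1, 128])
```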

@adamomainz (Collaborator)

Can you run with benchmarking turned on and see the difference? I'd be curious to see the attention-specific latency here :) In that case you don't need to specify use_triton, since it will run both cases.
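For the attention-specific latency, a rough sketch of one way to measure it is below, using `triton.testing.do_bench` (which reports time in milliseconds). `triton_attention` is a hypothetical placeholder for the kernel this PR switches to; PyTorch SDPA serves as the baseline. Shapes match the decode example above and assume a CUDA device.

```python
import torch
import torch.nn.functional as F
import triton.testing

q = torch.randn(2, 8, 1, 128, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 1024, 128, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 1024, 128, device="cuda", dtype=torch.float16)

# Baseline: PyTorch scaled_dot_product_attention
ms_eager = triton.testing.do_bench(lambda: F.scaled_dot_product_attention(q, k, v))
print(f"eager/SDPA: {ms_eager:.3f} ms")

# Hypothetical Triton kernel under test (placeholder name):
# ms_triton = triton.testing.do_bench(lambda: triton_attention(q, k, v))
```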
