[Draft] Support Flash Attention #501
base: dev
Conversation
Is there any update on this, and any idea of what can be done to get the tolerance to an acceptable level?
I haven't taken a further look yet; is this currently blocking other features?
Nope, I have just been going through PRs, closing out anything that can be closed and helping move everything else along. If you need some help with this, let me know; if I have the time, I would be happy to help.
Fails with llama-3. This is likely a grouped-query attention issue, since llama-2 doesn't produce the same error.
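For context, one plausible source of the llama-3 failure is that grouped-query attention has fewer key/value heads than query heads, so a naive call to scaled_dot_product_attention sees mismatched head counts. Below is a minimal sketch of one common workaround (expanding the key/value heads before the fused call); the shapes and head counts are illustrative assumptions, not the actual TransformerLens internals.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only: 32 query heads sharing 8 key/value heads (group size 4).
batch, seq_len, d_head = 2, 16, 64
n_q_heads, n_kv_heads = 32, 8

q = torch.randn(batch, n_q_heads, seq_len, d_head)
k = torch.randn(batch, n_kv_heads, seq_len, d_head)
v = torch.randn(batch, n_kv_heads, seq_len, d_head)

# Expand each key/value head so every query head has a matching K/V head.
group_size = n_q_heads // n_kv_heads
k = k.repeat_interleave(group_size, dim=1)  # (batch, n_q_heads, seq_len, d_head)
v = v.repeat_interleave(group_size, dim=1)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 32, 16, 64])
```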
Description
Resolves #378 by adding support for torch.nn.functional.scaled_dot_product_attention, as found here. This implementation includes FlashAttention-2 as well as two other alternative (and potentially faster) attention implementations; PyTorch attempts to automatically select the most optimal implementation based on the inputs. Thank you @alan-cooney for recommending this implementation!

Currently still in draft because the tolerances between a model that uses a fast attention implementation and one that does not are a bit high. This is likely because scaled_dot_product_attention requires casting to float16 (which we then cast back to float32 after doing the fused attention). I will look into this further and see whether there is an improvement to be made here; a rough sketch of this cast-and-restore pattern is shown below.
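As a rough illustration of the pattern described above (not the exact PR code), the fused call might cast the inputs to float16, let PyTorch pick a backend, and cast the result back to float32; the small numerical differences introduced by this round trip are one likely source of the elevated tolerances. The helper name and shapes here are hypothetical.

```python
import torch
import torch.nn.functional as F

def fused_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: cast to float16 for the fused kernel, then cast back."""
    orig_dtype = q.dtype
    q16, k16, v16 = q.half(), k.half(), v.half()
    # PyTorch chooses among the FlashAttention-2, memory-efficient, and math
    # backends here; the flash path only runs on CUDA, so CPU falls back to math.
    out = F.scaled_dot_product_attention(q16, k16, v16, is_causal=True)
    return out.to(orig_dtype)

# Shapes: (batch, n_heads, seq_len, d_head)
device = "cuda" if torch.cuda.is_available() else "cpu"
q = torch.randn(1, 12, 128, 64, device=device)
k = torch.randn(1, 12, 128, 64, device=device)
v = torch.randn(1, 12, 128, 64, device=device)
print(fused_attention(q, k, v).dtype)  # torch.float32
```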
Type of change
Checklist: