Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Draft] Support Flash Attention #501

Draft
wants to merge 5 commits into
base: dev
Choose a base branch
from

Conversation

cmathw
Copy link
Contributor

@cmathw cmathw commented Jan 30, 2024

Description

Resolves #378 by adding support for torch.nn.functional.scaled_dot_product_attention as found here. This implementation includes FlashAttention-2, as well as, two other alternative (potentially faster) attention implementations. PyTorch attempts to automatically select the most optimal implementation based on inputs. Thank you @alan-cooney for recommending this implementation!

Currently still in draft because the tolerances between model that uses a fast attention implementation and one that does not are a bit high. This is likely due to the fact scaled_dot_product_attention requires casting to float16 (which we then cast back to float32 after doing the fused attention). I will look into this further though and see if there is an improvement to be made here.

Type of change

Please delete options that are not relevant.

  • New feature (non-breaking change which adds functionality)

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

@cmathw cmathw marked this pull request as draft January 30, 2024 21:46
@bryce13950
Copy link
Collaborator

Is there any update on this, and any idea on what can be done to get the tolerance to an acceptable level?

@cmathw
Copy link
Contributor Author

cmathw commented Apr 25, 2024

Is there any update on this, and any idea on what can be done to get the tolerance to an acceptable level?

I haven't taken a further look yet, is this something currently blocking other features?

@bryce13950
Copy link
Collaborator

Nope, I have just been going through PRs and closing out anything that can be closed, and helping get anything else closed out. If you need some help with this, let me know. If I have time to help you out, I would be happy to.

@winglian
Copy link

winglian commented May 3, 2024

fails with llama-3

  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/transformer_lens/components.py", line 591, in forward
    z = self.calculate_z_with_sdpa(q, k, v)  # [batch, pos, head_index, d_head]
  File "/root/miniconda3/envs/py3.10/lib/python3.10/site-packages/transformer_lens/components.py", line 785, in calculate_z_with_sdpa
    z = F.scaled_dot_product_attention(query, key, value, is_causal=True)
RuntimeError: The size of tensor a (32) must match the size of tensor b (8) at non-singleton dimension 1

likely a group query attention issue, since llama-2 doesn't have the same error.

@bryce13950 bryce13950 changed the base branch from main to dev May 23, 2024 00:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Proposal] Optionally use flash attention.
3 participants