
memory_efficient_attention faster than flash attention 2 backend? #1180

Open

asahni04 opened this issue Dec 19, 2024 · 5 comments

asahni04 commented Dec 19, 2024

❓ Questions and Help

I expected it to be the other way around. What is the fastest kernel I can use here?

    q = q.to(dtype)  # xFormers needs manual casting of the operators
    k = k.to(dtype)
    v = v.to(dtype)
    x = memory_efficient_attention(
        q,
        k,
        v,
        p=self.attn_drop.p if self.training else 0.0,
        op=self.efficient_attention_ops,
    )

vs.

    q, k, v = map(lambda t: rearrange(t, "b n h d -> b h n d", d=self.head_dim), (q, k, v))
    with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
        x = scaled_dot_product_attention(
            q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0
        )  # scale is computed automatically by the torch implementation
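For a direct comparison, a minimal standalone timing sketch along these lines can be used; the shapes, dtype, and the use of torch.utils.benchmark here are illustrative assumptions, not the issue author's setup:

    import torch
    from torch.nn.attention import SDPBackend, sdpa_kernel
    from torch.nn.functional import scaled_dot_product_attention
    from torch.utils import benchmark
    from xformers.ops import memory_efficient_attention

    B, S, H, D = 4, 4096, 16, 64  # assumed shapes for illustration
    # xFormers expects [B, S, H, D]; SDPA expects [B, H, S, D].
    q = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
    k, v = torch.randn_like(q), torch.randn_like(q)
    q_t, k_t, v_t = (t.transpose(1, 2) for t in (q, k, v))

    def run_xformers():
        return memory_efficient_attention(q, k, v)

    def run_sdpa_flash():
        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
            return scaled_dot_product_attention(q_t, k_t, v_t)

    for name, fn in [("xformers", run_xformers), ("sdpa-flash", run_sdpa_flash)]:
        # Timer handles CUDA synchronization and warmup for us.
        t = benchmark.Timer(stmt="fn()", globals={"fn": fn})
        print(name, t.timeit(100))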


            
danthe3rd (Contributor) commented Dec 27, 2024

Hi,
We've now added support for Flash v3, which is not yet supported in PyTorch (cc @drisspg @tridao). Once PyTorch supports Flash v3, I assume both will take the same time.
I'm assuming you are running on an H100, in bf16 or fp16.
Also, what's the value of self.efficient_attention_ops?

drisspg (Contributor) commented Dec 27, 2024

Started to work on the prerequisites: pytorch/pytorch#143515

But yeah, as of right now the most performant kernel we have in PyTorch is the cuDNN backend on H100.
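For reference, the cuDNN path can be requested the same way as the flash backend in the snippet above; a minimal sketch, assuming a recent PyTorch build that exposes SDPBackend.CUDNN_ATTENTION and with dummy shapes chosen for illustration:

    import torch
    from torch.nn.attention import SDPBackend, sdpa_kernel
    from torch.nn.functional import scaled_dot_product_attention

    # Dummy [B, H, S, D] inputs for illustration.
    q = torch.randn(4, 16, 4096, 64, device="cuda", dtype=torch.bfloat16)
    k, v = torch.randn_like(q), torch.randn_like(q)

    # Restrict SDPA to the cuDNN fused-attention kernel (H100-class GPUs).
    with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]):
        out = scaled_dot_product_attention(q, k, v)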

asahni04 (Author) commented Dec 27, 2024

@danthe3rd yes, I'm using bf16 on H100. I tried replacing memory-efficient attention with the FlashAttention-2 op but couldn't see the expected speedup. FlashAttention-3 from the official repo https://github.com/Dao-AILab/flash-attention is much faster but not an out-of-the-box replacement and requires finetuning.

    self.efficient_attention_ops = (
        xformers_efficient_attention_fw,
        xformers_efficient_attention_bw,
    )

@drisspg I see, so SDPBackend.CUDNN_ATTENTION is the fastest? Even faster than FA-2? What about A100 and A10s?

Is there any other way to speed this up?
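For context, the aliases in the tuple above presumably point at specific xFormers kernels; a minimal sketch of selecting kernels explicitly through the op argument (the dummy shapes, and the guess that those aliases map to the CUTLASS ops, are assumptions, not the issue author's actual code):

    import torch
    from xformers.ops import fmha

    # Dummy [B, S, H, D] inputs for illustration.
    q = torch.randn(4, 4096, 16, 64, device="cuda", dtype=torch.bfloat16)
    k, v = torch.randn_like(q), torch.randn_like(q)

    cutlass_ops = (fmha.cutlass.FwOp, fmha.cutlass.BwOp)  # "memory-efficient" CUTLASS kernels
    flash2_ops = (fmha.flash.FwOp, fmha.flash.BwOp)       # FlashAttention-2 kernels
    # FlashAttention-3 ops, if the installed xFormers build ships them
    # (module path is an assumption about recent builds):
    # flash3_ops = (fmha.flash3.FwOp, fmha.flash3.BwOp)

    x = fmha.memory_efficient_attention(q, k, v, op=flash2_ops)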

drisspg (Contributor) commented Dec 27, 2024

SDPBackend.CUDNN_ATTENTION is the fastest implementation currently supported for SDPA and is meant for H100+ GPUs. For A100 and A10s, FAv2 is still your best bet.

"is much faster but not an out-of-the-box replacement and requires finetuning"

Curious to learn more about this line.

danthe3rd (Contributor) commented
For xFormers, just set self.efficient_attention_ops = None.
For SDPA, you should probably try without specifying the backend, and let PyTorch select the best one for you.
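A minimal sketch of both suggestions, with dummy shapes chosen for illustration:

    import torch
    from torch.nn.functional import scaled_dot_product_attention
    from xformers.ops import memory_efficient_attention

    q = torch.randn(4, 4096, 16, 64, device="cuda", dtype=torch.bfloat16)  # [B, S, H, D]
    k, v = torch.randn_like(q), torch.randn_like(q)

    # xFormers: op=None lets the dispatcher pick the fastest available kernel.
    x_xf = memory_efficient_attention(q, k, v, op=None)

    # SDPA: no sdpa_kernel() context, so PyTorch chooses the backend itself.
    q_t, k_t, v_t = (t.transpose(1, 2) for t in (q, k, v))  # SDPA wants [B, H, S, D]
    x_sdpa = scaled_dot_product_attention(q_t, k_t, v_t)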
