
[RFC] LowBit Fused Attention #1071

Open
drisspg opened this issue Oct 14, 2024 · 2 comments
drisspg commented Oct 14, 2024

Current State of OSS FP8 Operators

So far, all examples of fp8 ops (compute in fp8) are scaled matmuls that accumulate in a higher-precision type. In fact, there are really only two classes of fp8 instructions supported in PTX:

  • Matmul
  • Casting

The complexity of FP8 training is that we need to efficiently calculate scales that map the current distribution of values in a high-precision tensor onto what is representable in fp8.

This is easier for inference because the weight is frozen and we can pre-calculate its scale.
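For concreteness, here is a minimal sketch of tensor-wise scale computation in plain PyTorch (this is not the torchao API; the helper name and example shapes are illustrative): the scale maps the tensor's amax onto the largest representable e4m3 value, and its reciprocal dequantizes.

import torch

def tensor_wise_scale(x: torch.Tensor, float8_dtype=torch.float8_e4m3fn) -> torch.Tensor:
    # Map the largest observed magnitude onto the largest representable fp8 value.
    amax = x.abs().max().to(torch.float32).clamp(min=1e-12)
    return torch.finfo(float8_dtype).max / amax

x = torch.randn(4, 16, 8192, 128, device="cuda", dtype=torch.float16)
scale = tensor_wise_scale(x)
fp8_max = torch.finfo(torch.float8_e4m3fn).max
x_fp8 = (x.to(torch.float32) * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
x_hp = x_fp8.to(torch.float16) / scale  # approximate reconstruction of the original values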

Inference

Before we can walk, we must crawl. Let's look at what's available for inference, which is a strictly easier problem.

All of these kernels use TensorWise scaling.

Kernels

1. FAv3

  • GitHub Link
  • Still actively being developed.
  • Originally did not appear to support any scaling format; as of Dao-AILab/flash-attention@c92ca63, q, k, and v scales have been added to the kernel and interface.
  • Before that commit, the same sm_scale trick plus an out * v_scale trick could be used to get a high-precision output; this workaround is no longer needed.

2. FlashInfer

Prefill
  • GitHub Link
  • It appears that although inputs can be fp8, they are cast up under the hood rather than consumed with a scale.
BatchedPrefill with KVCache
  • GitHub Link
  • More important for large-scale inference workloads.
  • They accept a key scale and a value scale, but the query does not allow for scaling.
  • Interestingly, they roll the k scale into sm_scale for pre-softmax scaling.
  • They do the V scaling by multiplying the output.
  • The ragged impl casts fp8 up.
Decode
  • GitHub Link
  • Explicit support for q, k, v scaling.
  • q and k scaling are rolled into sm_scale.
  • v scale is done by multiplying the final output.

TLDR: Uses a neat strategy for fusing the scaling into existing kernels; a reference sketch of the algebra is below.
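To make that strategy explicit, here is a minimal reference sketch in plain PyTorch (not the FlashInfer API; the function name and the fp32 reference compute are assumptions) showing why the q/k dequant scales can be folded into sm_scale and the v scale into the output:

import math
import torch

def fused_scale_attention(q8, k8, v8, s_q, s_k, s_v):
    # q8/k8/v8 are fp8 data tensors; s_q/s_k/s_v are their dequant factors (hp ≈ data * s).
    # softmax(((q8 * s_q) @ (k8 * s_k)^T) / sqrt(d)) @ (v8 * s_v)
    #   == softmax((q8 @ k8^T) * (s_q * s_k / sqrt(d))) @ v8 * s_v
    d = q8.shape[-1]
    sm_scale = s_q * s_k / math.sqrt(d)  # q and k scales rolled into sm_scale
    q, k, v = (t.to(torch.float32) for t in (q8, k8, v8))
    attn = torch.softmax(q @ k.transpose(-2, -1) * sm_scale, dim=-1)
    return (attn @ v) * s_v  # v scale applied by multiplying the final output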

3. vLLM

  • Static kv cache generator script: GitHub Link
  • Uses NVIDIA's ammo toolkit for the tensor configs; they have an atq config.
  • Hardcodes e4m3.
  • Looks like they adopt the FlashInfer strategy: GitHub Link
  • Looks like it's only supported for the blocksparse (paged) impl: GitHub Link

4. FlexAttention

  • This is actually pretty straightforward to support if we use the same technique as listed above.
  • We can update our sm_scale to roll in the q and k scales:
    sm_scale *= q_fp8._scale.reciprocal() * k_fp8._scale.reciprocal()
    Note: This currently fails since we expect the scale to be on host, but we can fix that, or use score_mod (fixing is better; a score_mod sketch follows this section).
  • This only works if the output is in HP and not float8; otherwise, we would lose precision in the cast of softmax(qk) @ v, since the scale would be applied afterwards.
  • Question: Can we epilogue-fuse this?
  • Interesting performance results from a first implementation (fp8_bench.py):
float16_time=6.860068321228027
fp8_time=10.829721450805664

This is also idealized, since it does not account for the casting overhead or an epilogue kernel.
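For reference, a hypothetical score_mod sketch (assumed names; the combined scale is assumed to be a 0-dim CUDA tensor) that applies the q/k dequant factor on the scores instead of folding it into sm_scale:

import torch
from torch.nn.attention.flex_attention import flex_attention

def make_dequant_score_mod(qk_dequant_scale: torch.Tensor):
    def score_mod(score, b, h, q_idx, kv_idx):
        # score is a q @ k^T entry; multiply by the combined dequant factor
        return score * qk_dequant_scale
    return score_mod

flex = torch.compile(flex_attention)
# q8, k8, v8: fp8 data tensors; s_q, s_k: reciprocals of their torchao scales
# out = flex(q8, k8, v8, score_mod=make_dequant_score_mod(s_q * s_k))
# The v dequant scale still has to be applied to `out` afterwards.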

5. Transformer Engine

  • They pass all of the scales through to the cuDNN call. This path is used for training and saves all of the intermediate scales for backward.
  • Bakes in delayed scaling by nature of the algorithm (it appends directly to amax_history); a sketch of delayed scaling follows.
  • Thin wrapper around cuDNN's API: API
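As background on the delayed-scaling mechanism, here is a minimal sketch (not the Transformer Engine API; the class name and history length are assumptions): the scale used at the current step is derived from a rolling window of amaxes recorded on previous steps, so no extra reduction over the live tensor is needed.

import torch

class DelayedScale:
    def __init__(self, history_len: int = 16, float8_dtype=torch.float8_e4m3fn):
        self.amax_history = torch.zeros(history_len, device="cuda")
        self.fp8_max = torch.finfo(float8_dtype).max

    def scale(self) -> torch.Tensor:
        # Derive the scale from the largest amax seen in the window.
        amax = self.amax_history.max().clamp(min=1e-12)
        return self.fp8_max / amax

    def update(self, x: torch.Tensor) -> None:
        # "Directly appends to amax_history": record the current tensor's amax for future steps.
        self.amax_history = torch.roll(self.amax_history, shifts=-1)
        self.amax_history[-1] = x.abs().max()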

6. TensorRT

Some Code Runs

Flex Experiments

from functools import partial
from typing import Optional
import torch
import torch.nn.functional as F
import math
from tabulate import tabulate
from torch.nn.attention.flex_attention import flex_attention
from triton.testing import do_bench

torch.set_default_device("cuda")
torch.manual_seed(0)
torch._dynamo.config.cache_size_limit = 1000
torch._inductor.config.triton.unique_kernel_names = True

# Default high-precision dtype (redefined inside main below)
data_type = torch.float16

def main(do_fp16=False, max_autotune: bool = False):
    try:
        from torchao.float8.float8_tensor import Float8Tensor, hp_tensor_and_scale_to_float8
        from torchao.float8.float8_utils import tensor_to_scale
    except ImportError:
        raise ImportError("Fp8 example needs torchao to run!")

    data_type = torch.float16
    make_tensor = partial(torch.rand, device="cuda", dtype=data_type)
    input_size = (4, 16, 8192, 128)
    q, k, v = make_tensor(input_size), make_tensor(input_size), make_tensor(input_size)

    if max_autotune:
        flex = torch.compile(
            flex_attention, dynamic=True, mode="max-autotune-no-cudagraphs"
        )
    else:
        flex = torch.compile(flex_attention, dynamic=False)

    if do_fp16:
        float16_time = do_bench(lambda: flex(q, k, v))
        print(f"{float16_time=}")

    # FP8 path: quantize q/k/v tensor-wise and fold the q/k dequant scales into sm_scale
    q_fp8 = hp_tensor_and_scale_to_float8(q, tensor_to_scale(q, torch.float8_e4m3fn))
    k_fp8 = hp_tensor_and_scale_to_float8(k, tensor_to_scale(k, torch.float8_e4m3fn))
    v_fp8 = hp_tensor_and_scale_to_float8(v, tensor_to_scale(v, torch.float8_e4m3fn))
    sm_scale = 1.0 / math.sqrt(q.size(-1))  # match flex_attention's default 1/sqrt(head_dim)
    sm_scale *= q_fp8._scale.reciprocal() * k_fp8._scale.reciprocal()
    # Workaround for now: flex_attention expects a Python float for `scale`
    sm_scale = sm_scale.item()
    q_fp8_data = q_fp8._data
    k_fp8_data = k_fp8._data
    v_fp8_data = v_fp8._data
    
    # Warm up / compile once before benchmarking
    flex(q_fp8_data, k_fp8_data, v_fp8_data, scale=sm_scale)
    fp8_time = do_bench(lambda: flex(q_fp8_data, k_fp8_data, v_fp8_data, scale=sm_scale))
    print(f"{fp8_time=}")

if __name__ == "__main__":
    try:
        from jsonargparse import CLI
    except ImportError:
        raise ImportError("Be sure to run: pip install -e .'[viz]'")
    CLI(main)
@jainapurva jainapurva added the rfc label Oct 14, 2024
gau-nernst (Collaborator) commented
Would you be interested in considering INT8 attention too? #952 (https://github.com/INT-FlashAttention2024/INT-FlashAttention)

There are also other Triton/CUDA kernels for int8 attention floating around, but I haven't looked into them closely.

drisspg commented Oct 16, 2024

@gau-nernst Still working through this RFC, it's not nearly complete yet, but yes, I'm going to add a section on int8 attention.
