Current State of OSS FP8 Operators
So far, all examples of fp8 ops (compute in fp8) are scaled matmuls that accumulate in a higher-precision type. In fact, there are really only two classes of fp8 instructions supported in PTX:
- Matmul
- Casting
The complexity of FP8 training is that we need to efficiently calculate scales that align the current distribution of values in a high-precision tensor with what is representable in fp8.
This is easier for inference because the weight is frozen and we can pre-calculate the scale.
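To make that concrete, here is a rough sketch of a pre-computed, tensor-wise scale for a frozen inference weight, with torch._scaled_mm (a private API; exact signature assumed from recent PyTorch releases) as the scaled matmul that accumulates in bf16:

```python
import torch

# Sketch assumptions: a CUDA device with fp8 support, and a recent PyTorch where
# torch._scaled_mm (a private API) takes tensor scales and returns a single tensor.
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448 for e4m3fn

def tensorwise_scale(t: torch.Tensor) -> torch.Tensor:
    # Map the tensor's current amax onto the fp8 representable range.
    return t.abs().max().float() / FP8_MAX

# Frozen inference weight: its scale can be computed once, offline.
w = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
w_scale = tensorwise_scale(w)
w_fp8 = (w / w_scale).to(torch.float8_e4m3fn)

# Activations still need a scale at runtime.
x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)
x_scale = tensorwise_scale(x)
x_fp8 = (x / x_scale).to(torch.float8_e4m3fn)

# Scaled matmul: compute in fp8, accumulate/output in a higher-precision type.
# _scaled_mm wants the second operand column-major, hence the transpose.
y = torch._scaled_mm(x_fp8, w_fp8.t(), scale_a=x_scale, scale_b=w_scale,
                     out_dtype=torch.bfloat16)
```

For training, x_scale would have to be recomputed (or delayed/amortized) every step, which is exactly the complexity called out above.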
Inference
Before we can walk, we must crawl. Let's look at what's available for inference, which is a strictly easier problem.
Note: This currently fails since we expect the input to be on the host, but we can fix that, or use score_mod (fixing is better).
This only works if the output is in HP and not float8; otherwise, we would lose precision when casting the result of softmax(qk) @ v, since the scale would be applied after.
Question: Can we epilogue fuse this?
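For reference, a rough sketch of where the scales would land in the score_mod variant mentioned above. This is not a working fp8 path: the tensors are upcast back to bf16 just so it runs, the q/k/v scales are illustrative, and it assumes PyTorch 2.5+'s flex_attention API.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def quantize(t):
    scale = t.abs().max().float() / FP8_MAX
    return (t / scale).to(torch.float8_e4m3fn), scale

B, H, S, D = 2, 8, 1024, 64
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16) for _ in range(3))
q_fp8, q_scale = quantize(q)
k_fp8, k_scale = quantize(k)
v_fp8, v_scale = quantize(v)

def dequant_score_mod(score, b, h, q_idx, kv_idx):
    # qk is computed on quantized q and k, so fold both dequant scales into the
    # logits before the softmax sees them.
    return score * (q_scale * k_scale)

# A real fp8 kernel would consume q_fp8/k_fp8/v_fp8 directly; the upcast below is
# only there so this sketch actually runs with today's kernels.
out = flex_attention(
    q_fp8.to(torch.bfloat16),
    k_fp8.to(torch.bfloat16),
    v_fp8.to(torch.bfloat16),
    score_mod=dequant_score_mod,
)
# v's dequant scale can only be applied after softmax(qk) @ v, which is why the
# output has to stay in HP rather than float8 (see the note above).
out = out * v_scale
```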
Interesting performance results from the first implementation: fp8_bench.py
All of the kernels below use TensorWise scaling.
Kernels
1. FAv3
Does not appear to support any scaling format. Update: as of Dao-AILab/flash-attention@c92ca63, q/k/v scales have been added to the kernel and interface.
2. FlashInfer
- Prefill
- BatchedPrefill with KVCache
- Decode
TL;DR: Uses a neat strategy for fusing scaling into existing kernels (a generic sketch of the idea follows the kernel list below).
3. vLLM
4. FlexAttention
This is idealized too, since it does not account for casting overhead or the epilogue kernel.
5. Transformer Engine
6. TensorRT
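Regarding the FlashInfer TL;DR above: the general trick of fusing dequantization into an existing kernel can be sketched by folding the q/k scales into the softmax scale the kernel already exposes, and applying the v scale as an epilogue. Below is a generic sketch with PyTorch SDPA standing in for the real kernel; it is not FlashInfer's actual API, and again the upcast is only so the snippet runs.

```python
import math
import torch
import torch.nn.functional as F

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def quantize(t):
    scale = t.abs().max().float() / FP8_MAX
    return (t / scale).to(torch.float8_e4m3fn), scale

B, H, S, D = 2, 8, 1024, 64
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16) for _ in range(3))
q_fp8, q_scale = quantize(q)
k_fp8, k_scale = quantize(k)
v_fp8, v_scale = quantize(v)

# Fold the q/k dequant scales into the softmax scale the kernel already takes;
# a real fp8 kernel would consume the fp8 tensors directly.
sm_scale = (1.0 / math.sqrt(D)) * float(q_scale) * float(k_scale)
out = F.scaled_dot_product_attention(
    q_fp8.to(torch.bfloat16),
    k_fp8.to(torch.bfloat16),
    v_fp8.to(torch.bfloat16),
    scale=sm_scale,
)
# v's scale is applied after the PV matmul, i.e. in the epilogue, in high precision.
out = out * v_scale
```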
Some Code Runs
Flex Experiments