
[ROCm] torchao.float8 should work properly on ROCm #1066

Open
OrenLeung opened this issue Oct 12, 2024 · 2 comments

Hi @hongxiayang @hliuca,

It seems that float8 training via torchao.float8 is not supported on ROCm at the moment. Is there a different library or code path I should be using for float8 training, and what are the timelines for ROCm support in torchao.float8?

Attempting Install From Nightly

Using the ROCm nightly torchao wheel, the torchao.float8 module is not present:

pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.2
pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/rocm6.2
python -c "import torchao; print(dir(torchao))"
['__all__', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', 'apply_dynamic_quant', 'apply_weight_only_int8_quant', 'dtypes', 'kernel', 'quantization']
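
A quick way to confirm whether a given torchao build ships the module at all (a minimal check; nothing is assumed beyond the module name above):

# Probe the installed torchao build for the float8 module.
try:
    from torchao import float8  # noqa: F401
    print("torchao.float8 is available")
except ImportError:
    print("torchao.float8 is missing from this build")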

Attempting Install From Source

Installing from source, I run into a Triton datatype issue under torch.compile. If I disable torch.compile, I instead hit an eager-mode error because the fp8 dtype being used is the Nvidia format rather than the AMD format.

pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.2
pip install git+https://github.com/pytorch/ao.git

Eager Mode Error

   tensor_out = addmm_float8_unwrapped(
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torchao/float8/float8_python_api.py", line 55, in addmm_float8_unwrapped
    output = torch._scaled_mm(
RuntimeError: false INTERNAL ASSERT FAILED at "../aten/src/ATen/hip/HIPDataType.h":102, please report a bug to PyTorch. Cannot convert ScalarType Float8_e4m3fn to hipDataType.
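
For context, PyTorch exposes two fp8 families, and the assert above comes from handing the Nvidia variant to the ROCm path. A minimal illustration (the dtype names are standard PyTorch; the comments describe properties of the formats):

# The two fp8 families in PyTorch. AMD GPUs under ROCm natively support
# the "fnuz" variants (no inf, no negative zero), but torchao is passing
# the Nvidia float8_e4m3fn dtype into torch._scaled_mm here.
import torch

print(torch.float8_e4m3fn)    # Nvidia e4m3: max ±448
print(torch.float8_e4m3fnuz)  # AMD e4m3 fnuz: max ±240
print(torch.float8_e5m2)      # Nvidia e5m2
print(torch.float8_e5m2fnuz)  # AMD e5m2 fnuz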

Compile Mode Error

    tmp15 = 448.0
    tmp16 = triton_helpers.minimum(tmp14, tmp15)
    tmp17 = tmp16.to(tl.float8e4nv)
            ^

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
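
The failing tmp16.to(tl.float8e4nv) cast is Inductor emitting Triton's Nvidia e4m3 type, which has no ROCm lowering. A hedged sketch of the equivalent cast that should lower on ROCm, assuming the AMD fnuz type is exposed as tl.float8e4b8 (the name varies across Triton versions) and clamping to the e4m3fnuz max of 240.0 rather than 448.0:

import torch
import triton
import triton.language as tl

@triton.jit
def cast_fp8_kernel(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    x = tl.minimum(x, 240.0)  # e4m3fnuz max, vs 448.0 for e4m3fn
    tl.store(out_ptr + offs, x.to(tl.float8e4b8), mask=mask)

x = torch.randn(1024, device="cuda")
out = torch.empty(1024, device="cuda", dtype=torch.float8_e4m3fnuz)
cast_fp8_kernel[(triton.cdiv(1024, 256),)](x, out, 1024, BLOCK=256)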

Repro Script (From the torchao.float8 README Example)

import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

# create model and sample input
m = nn.Sequential(
    nn.Linear(2048, 4096),
    nn.Linear(4096, 128),
).bfloat16().cuda()
x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)

# optional: filter modules from being eligible for float8 conversion
def module_filter_fn(mod: torch.nn.Module, fqn: str):
    # don't convert the last module
    if fqn == "1":
        return False
    # don't convert linear modules with weight dimensions not divisible by 16
    if isinstance(mod, torch.nn.Linear):
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
    return True

# convert specified `torch.nn.Linear` modules to `Float8Linear`
convert_to_float8_training(m, module_filter_fn=module_filter_fn)

# enable torch.compile for competitive performance
m = torch.compile(m)
# toy training loop
for _ in range(10):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()
    optimizer.step()

drisspg (Contributor) commented Oct 14, 2024

Can you make sure to change

use_fnuz_dtype = False

to True?
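
For anyone else hitting this, a minimal sketch of that change, assuming the flag is the module-level use_fnuz_dtype in torchao.float8.config (its exact location may differ across versions):

# Switch torchao's float8 flavor to the AMD fnuz formats before
# converting the model; assumes a module-level flag in
# torchao.float8.config, per the snippet referenced above.
import torchao.float8.config as float8_config

float8_config.use_fnuz_dtype = True  # use float8_e4m3fnuz / float8_e5m2fnuz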

vkuzo (Contributor) commented Oct 16, 2024

Here are my thoughts on what we need to do to enable ROCm support for float8:

  1. ensure torch._scaled_mm's path for ROCm is fast and accurate (see the sketch after this list)
  2. there is a config setting for the fnuz float8 flavors (use_fnuz_dtype = False), but it's not tested at the moment. We should enable testing across all of our test suite, first locally and then in CI.
  3. get e2e performance/accuracy to be good, measured by benchmarks on real workloads
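
On item 1, a hedged sketch of exercising torch._scaled_mm directly with the fnuz format (it is a private API whose signature moves between nightlies; the per-tensor scale arguments and the column-major requirement on the second operand reflect current builds):

# Exercise torch._scaled_mm directly with AMD's e4m3fnuz format.
# _scaled_mm is private and its signature varies across nightlies; the
# second operand must be column-major for the fp8 path.
import torch

a = torch.randn(16, 32, device="cuda", dtype=torch.bfloat16)
b = torch.randn(32, 16, device="cuda", dtype=torch.bfloat16)

a_fp8 = a.to(torch.float8_e4m3fnuz)
b_fp8 = b.t().contiguous().t().to(torch.float8_e4m3fnuz)  # column-major

scale_a = torch.tensor(1.0, device="cuda")  # per-tensor scales
scale_b = torch.tensor(1.0, device="cuda")

out = torch._scaled_mm(a_fp8, b_fp8, scale_a=scale_a, scale_b=scale_b,
                       out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([16, 16])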

vkuzo changed the title from "[ROCm] float8 does not work" to "[ROCm] torchao.float8 should work properly on ROCm" on Oct 16, 2024