Hi @hongxiayang @hliuca,

It seems that float8 training using torchao.float8 is not supported on ROCm at the moment. Is there a different library or code path I should be using for float8 training, or what are the timelines for ROCm supporting torchao.float8?
Attempting Install From Nightly
When installing from the ROCm nightly torchao wheel, the torchao.float8 module is not present.
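A quick, generic way to confirm whether the installed wheel ships the submodule (not specific to torchao):

import importlib.util

# prints None if torchao is installed but the float8 submodule is missing
print(importlib.util.find_spec("torchao.float8"))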
Attempting Install From Source

When installing from source, I run into a Triton datatype issue. If I disable torch.compile, I then hit an eager mode issue: the fp8 dtype is the NVIDIA format rather than the AMD format.

Eager Mode Error
tensor_out = addmm_float8_unwrapped(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torchao/float8/float8_python_api.py", line 55, in addmm_float8_unwrapped
output = torch._scaled_mm(
RuntimeError: false INTERNAL ASSERT FAILED at "../aten/src/ATen/hip/HIPDataType.h":102, please report a bug to PyTorch. Cannot convert ScalarType Float8_e4m3fn to hipDataType.
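For context, PyTorch exposes both e4m3 variants, and the error above suggests the NVIDIA-style float8_e4m3fn dtype is reaching the HIP backend instead of the fnuz variant that AMD hardware uses (my assumption about which format ROCm expects). A minimal sketch comparing the two:

import torch

# NVIDIA-style e4m3 (max representable value 448.0)
print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0
# AMD-style e4m3 "fnuz" variant (max representable value 240.0)
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0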
Compile Mode Error
tmp15 = 448.0
tmp16 = triton_helpers.minimum(tmp14, tmp15)
tmp17 = tmp16.to(tl.float8e4nv)
^
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
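The generated kernel clamps to 448.0 (the e4m3fn max) and then casts to tl.float8e4nv, the NVIDIA-specific Triton fp8 type. A minimal sketch that should exercise the same lowering under torch.compile, assuming inductor maps this cast to tl.float8e4nv as in the log above:

import torch

@torch.compile
def to_fp8(x: torch.Tensor) -> torch.Tensor:
    # clamp to the e4m3fn max, then cast; inductor lowers this to a
    # Triton cast to tl.float8e4nv, which is expected to fail on ROCm
    return x.clamp(max=448.0).to(torch.float8_e4m3fn)

y = to_fp8(torch.randn(1024, device="cuda"))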
Reprod Script is From The torchao.float8 README Example

import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

# create model and sample input
m = nn.Sequential(
    nn.Linear(2048, 4096),
    nn.Linear(4096, 128),
).bfloat16().cuda()
x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)

# optional: filter modules from being eligible for float8 conversion
def module_filter_fn(mod: torch.nn.Module, fqn: str):
    # don't convert the last module
    if fqn == "1":
        return False
    # don't convert linear modules with weight dimensions not divisible by 16
    if isinstance(mod, torch.nn.Linear):
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
    return True

# convert specified `torch.nn.Linear` modules to `Float8Linear`
convert_to_float8_training(m, module_filter_fn=module_filter_fn)

# enable torch.compile for competitive performance
m = torch.compile(m)

# toy training loop
for _ in range(10):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()
    optimizer.step()
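As an additional data point, it may be worth checking whether torch._scaled_mm works standalone with the fnuz dtype on ROCm. A hedged sketch (the keyword signature of torch._scaled_mm has varied across PyTorch versions, and float8_e4m3fnuz being the supported ROCm format is my assumption):

import torch

# both operands quantized to the AMD fnuz fp8 format (assumption);
# _scaled_mm expects the second operand to be column-major
a = torch.randn(16, 16, device="cuda").to(torch.float8_e4m3fnuz)
b = torch.randn(16, 16, device="cuda").to(torch.float8_e4m3fnuz).t()
scale = torch.tensor(1.0, device="cuda")
out = torch._scaled_mm(a, b, scale_a=scale, scale_b=scale, out_dtype=torch.bfloat16)
print(out.dtype)  # torch.bfloat16 if the HIP path accepts the fnuz dtype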