[rank0]: File "/opt/venv/lib/python3.10/site-packages/torch/distributed/_composable/fsdp/_fsdp_param.py", line 653, in all_gather_inputs
[rank0]: ) = sharded_local_tensor.fsdp_pre_all_gather(self.mesh_info.mesh)
[rank0]: File "/opt/venv/lib/python3.10/site-packages/torchao/float8/fsdp_utils.py", line 218, in fsdp_pre_all_gather
[rank0]: float8_tensor = hp_tensor_to_float8_dynamic(
[rank0]: File "/opt/venv/lib/python3.10/site-packages/torchao/float8/float8_scaling_utils.py", line 62, in hp_tensor_to_float8_dynamic
[rank0]: scale = tensor_to_scale(
[rank0]: File "/opt/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]: return func(*args, **kwargs)
[rank0]: File "/opt/venv/lib/python3.10/site-packages/torchao/float8/float8_utils.py", line 138, in tensor_to_scale
[rank0]: amax = tensor_to_amax(
[rank0]: File "/opt/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]: return func(*args, **kwargs)
[rank0]: File "/opt/venv/lib/python3.10/site-packages/torchao/float8/float8_utils.py", line 123, in tensor_to_amax
[rank0]: pg = device_mesh.get_group() if device_mesh is not None else None
[rank0]: File "/opt/venv/lib/python3.10/site-packages/torch/distributed/device_mesh.py", line 694, in get_group
[rank0]: raise RuntimeError(
[rank0]: RuntimeError: ('Found the DeviceMesh have 2 dimensions', 'Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.', 'If you want to get the list of all the ProcessGroups in the DeviceMesh,please use `get_all_groups()` instead.')
I see this RuntimeError when the DeviceMesh has 2 dimensions. I'm using torch 2.5.
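The last frames of the traceback show where it fails. torchao's `tensor_to_amax` calls `device_mesh.get_group()` without a `mesh_dim`, and `DeviceMesh.get_group` refuses that call when `device_mesh.ndim > 1`. The class below is a minimal, hypothetical stand-in (`ToyMesh` is not the real `DeviceMesh`). It models only that check, to show why the same call succeeds on a 1D mesh and fails on a 2D one. The dimension names `dp_replicate` and `dp_shard` and the returned `"pg[...]"` strings are illustrative assumptions.

```python
from typing import Optional, Sequence, Union


class ToyMesh:
    """Hypothetical stand-in for torch.distributed.DeviceMesh.

    Models only the get_group() argument check seen in the traceback;
    it returns a label string instead of a real ProcessGroup.
    """

    def __init__(self, dim_names: Sequence[str]) -> None:
        self.dim_names = tuple(dim_names)

    @property
    def ndim(self) -> int:
        return len(self.dim_names)

    def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> str:
        # Same rule as the error message: with more than one mesh dimension,
        # the caller must say which dimension's group it wants.
        if self.ndim > 1 and mesh_dim is None:
            raise RuntimeError(
                "mesh_dim needs to be specified when device_mesh.ndim > 1"
            )
        if mesh_dim is None:
            mesh_dim = 0
        if isinstance(mesh_dim, str):
            mesh_dim = self.dim_names.index(mesh_dim)
        return f"pg[{self.dim_names[mesh_dim]}]"


if __name__ == "__main__":
    # 1D mesh: the no-argument call, like the one in tensor_to_amax, works.
    print(ToyMesh(["dp"]).get_group())
    # 2D mesh: the same call raises; naming a dimension works.
    mesh_2d = ToyMesh(["dp_replicate", "dp_shard"])
    try:
        mesh_2d.get_group()
    except RuntimeError as err:
        print("RuntimeError:", err)
    print(mesh_2d.get_group("dp_shard"))
```

With the real API, `get_group` accepts either an integer index or a dimension name for `mesh_dim`. As the error message itself suggests, `get_all_groups()` returns every group in the mesh.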
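I switched from dynamic scaling to delayed scaling and updated Triton to commit 938e388e7dbbe62952ea19229803afe8fa6baefd. After that, the issue no longer reproduces. I'm not sure whether the fix comes from the Triton version or from the scaling change. The failing call does sit on the dynamic-scaling path (`hp_tensor_to_float8_dynamic`), so switching to delayed scaling may be what avoids it.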
If anyone else runs into the same issue, you can work around it the way I did, or add more information here.