Torchao does not work with HSDP #1086

Open
goldhuang opened this issue Oct 15, 2024 · 2 comments

Comments

goldhuang commented Oct 15, 2024

[rank0]:   File "/opt/venv/lib/python3.10/site-packages/torch/distributed/_composable/fsdp/_fsdp_param.py", line 653, in all_gather_inputs
[rank0]:     ) = sharded_local_tensor.fsdp_pre_all_gather(self.mesh_info.mesh)
[rank0]:   File "/opt/venv/lib/python3.10/site-packages/torchao/float8/fsdp_utils.py", line 218, in fsdp_pre_all_gather
[rank0]:     float8_tensor = hp_tensor_to_float8_dynamic(
[rank0]:   File "/opt/venv/lib/python3.10/site-packages/torchao/float8/float8_scaling_utils.py", line 62, in hp_tensor_to_float8_dynamic
[rank0]:     scale = tensor_to_scale(
[rank0]:   File "/opt/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/opt/venv/lib/python3.10/site-packages/torchao/float8/float8_utils.py", line 138, in tensor_to_scale
[rank0]:     amax = tensor_to_amax(
[rank0]:   File "/opt/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/opt/venv/lib/python3.10/site-packages/torchao/float8/float8_utils.py", line 123, in tensor_to_amax
[rank0]:     pg = device_mesh.get_group() if device_mesh is not None else None
[rank0]:   File "/opt/venv/lib/python3.10/site-packages/torch/distributed/device_mesh.py", line 694, in get_group
[rank0]:     raise RuntimeError(
[rank0]: RuntimeError: ('Found the DeviceMesh have 2 dimensions', 'Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.', 'If you want to get the list of all the ProcessGroups in the DeviceMesh,please use `get_all_groups()` instead.')

I see a runtime error when the DeviceMesh has 2 dimensions (as HSDP produces). I'm on torch 2.5.
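The traceback shows torchao's `tensor_to_amax` calling `device_mesh.get_group()` with no `mesh_dim`; that call is only unambiguous on a 1-D mesh, while HSDP builds a 2-D mesh (a replicate dimension plus a shard dimension). A minimal sketch of the guard that fires here, using a hypothetical `MeshSketch` stand-in rather than torch's real `DeviceMesh`:

```python
# Hypothetical stand-in illustrating the check in
# torch/distributed/device_mesh.py's get_group; not the real class.
class MeshSketch:
    def __init__(self, shape):
        self.shape = shape  # e.g. (replicate, shard) sizes for HSDP

    @property
    def ndim(self):
        return len(self.shape)

    def get_group(self, mesh_dim=None):
        # On a 1-D mesh the single process group is unambiguous;
        # on a multi-dim mesh the caller must say which dimension it wants.
        if self.ndim > 1 and mesh_dim is None:
            raise RuntimeError(
                "Optional kwarg `mesh_dim` needs to be specified "
                "when device_mesh.ndim > 1."
            )
        return ("group", mesh_dim if mesh_dim is not None else 0)

hsdp_mesh = MeshSketch((2, 4))  # 2-D mesh, as HSDP creates
try:
    hsdp_mesh.get_group()       # what tensor_to_amax effectively does
except RuntimeError as e:
    print("raises:", e)
print(hsdp_mesh.get_group(mesh_dim=1))  # selecting a dimension works
```

So a fix on the torchao side would presumably need to pass an explicit mesh dimension (or use the shard dimension's group) instead of the bare `get_group()` call.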

@jerryzh168
Contributor

Can you paste the repro code as well?

Author

goldhuang commented Oct 17, 2024

I switched from dynamic scaling to delayed scaling and updated Triton to 938e388e7dbbe62952ea19229803afe8fa6baefd. After that, the issue no longer reproduces. I'm not sure whether it was related to the Triton version or to dynamic scaling.
If anyone else runs into the same issue, you can work around it the same way, or add more information here.
