Remove Option for ProcessGroup and Expose backend Options to reflect the correct code structure (#132931) (#2384)

Summary:

X-link: pytorch/pytorch#135653

We introduced the dispatchable backend for ProcessGroup and collectives in pytorch/pytorch#86225. This PR is a follow-up cleanup that removes the Options argument from ProcessGroup; users should instead set the timeout or backend later, or create the backend directly after creating the PG.

Also, ProcessGroupNCCL was using the Options class from ProcessGroup, but it should use the Options class from the Backend class. This PR aligns the types and names with the C++ implementation. The public API signatures are unchanged, so they still take arguments named "pg_options".
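
As a hedged illustration (not part of this diff), usage after this change might look like the sketch below: the public entry point still takes an argument named "pg_options", but the object passed is the backend-level Options (e.g. ProcessGroupNCCL.Options). The field name is_high_priority_stream and the rendezvous settings are assumptions for illustration and may differ across torch versions.

import datetime
import torch.distributed as dist

# Minimal sketch, assuming NCCL is available and a single-rank TCP rendezvous on localhost.
if dist.is_nccl_available():
    nccl_opts = dist.ProcessGroupNCCL.Options()   # backend-level Options, not ProcessGroup Options
    nccl_opts.is_high_priority_stream = True      # assumed field name; check your torch version
    dist.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:29500",
        rank=0,
        world_size=1,
        timeout=datetime.timedelta(seconds=60),   # timeout is set on the init call, not on Options
        pg_options=nccl_opts,                     # public argument name unchanged by this PR
    )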

The tests are updated accordingly to stay aligned with this change.

This is an attempt to reland D62008954 with the internal errors fixed.
ghstack-source-id: 242088446

Reviewed By: wz337, H-Huang

Differential Revision: D62483294
fduwjj authored and facebook-github-bot committed Sep 12, 2024
1 parent ff6dc0a commit 5528a62
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion torchrec/distributed/comm_ops.py
@@ -559,7 +559,7 @@ def variable_batch_all2all_pooled_sync(
     ]
 
     with record_function("## alltoall_fwd_single ##"):
-        if pg._get_backend_name() == "fake":
+        if pg._get_backend_name() == "custom":
            sharded_output_embeddings = torch.empty(
                sum(output_split_sizes),
                device=sharded_input_embeddings.device,
4 changes: 2 additions & 2 deletions torchrec/distributed/dist_data.py
@@ -239,7 +239,7 @@ def __init__(
         # https://github.com/pytorch/pytorch/issues/122788
         with record_function("## all2all_data:kjt splits ##"):
             input_tensor = torch.stack(input_tensors, dim=1).flatten()
-            if pg._get_backend_name() == "fake":
+            if pg._get_backend_name() == "custom":
                self._output_tensor = torch.empty(
                    [self.num_workers * len(input_tensors)],
                    device=input_tensors[0].device,
@@ -367,7 +367,7 @@ def __init__(
         # TODO(ivankobzarev) Remove this dynamo condition once dynamo functional collectives remapping does not emit copy_
         # https://github.com/pytorch/pytorch/issues/122788
         with record_function(f"## all2all_data:kjt {label} ##"):
-            if self._pg._get_backend_name() == "fake":
+            if self._pg._get_backend_name() == "custom":
                output_tensor = torch.empty(
                    sum(output_split),
                    device=self._device,
