diff --git a/torchrec/distributed/comm_ops.py b/torchrec/distributed/comm_ops.py
index 967d07d9f..34f2091e8 100644
--- a/torchrec/distributed/comm_ops.py
+++ b/torchrec/distributed/comm_ops.py
@@ -559,7 +559,7 @@ def variable_batch_all2all_pooled_sync(
         ]
 
     with record_function("## alltoall_fwd_single ##"):
-        if pg._get_backend_name() == "fake":
+        if pg._get_backend_name() == "custom":
             sharded_output_embeddings = torch.empty(
                 sum(output_split_sizes),
                 device=sharded_input_embeddings.device,
diff --git a/torchrec/distributed/dist_data.py b/torchrec/distributed/dist_data.py
index 62956ed37..4c66511ef 100644
--- a/torchrec/distributed/dist_data.py
+++ b/torchrec/distributed/dist_data.py
@@ -239,7 +239,7 @@ def __init__(
         # https://github.com/pytorch/pytorch/issues/122788
         with record_function("## all2all_data:kjt splits ##"):
             input_tensor = torch.stack(input_tensors, dim=1).flatten()
-            if pg._get_backend_name() == "fake":
+            if pg._get_backend_name() == "custom":
                 self._output_tensor = torch.empty(
                     [self.num_workers * len(input_tensors)],
                     device=input_tensors[0].device,
@@ -367,7 +367,7 @@ def __init__(
            # TODO(ivankobzarev) Remove this dynamo condition once dynamo functional collectives remapping does not emit copy_
            # https://github.com/pytorch/pytorch/issues/122788
            with record_function(f"## all2all_data:kjt {label} ##"):
-               if self._pg._get_backend_name() == "fake":
+               if self._pg._get_backend_name() == "custom":
                    output_tensor = torch.empty(
                        sum(output_split),
                        device=self._device,
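
Not part of the patch, just context on the condition being changed: `ProcessGroup._get_backend_name()` is a private torch.distributed API that returns the name the process group's backend was registered under, so after this change the pre-allocated-output branches fire only for process groups whose backend is registered as "custom" instead of the previously checked "fake" backend. A minimal sketch of the call (single-rank gloo group purely for illustration, not the torchrec code path):

```python
import os
import tempfile

import torch.distributed as dist

# Illustrative single-rank "gloo" group; hitting the patched branches for real
# would require a backend registered under the name "custom", e.g. via
# dist.Backend.register_backend("custom", <creator_fn>).
init_file = os.path.join(tempfile.mkdtemp(), "pg_init")
dist.init_process_group(
    backend="gloo",
    init_method=f"file://{init_file}",
    rank=0,
    world_size=1,
)
pg = dist.group.WORLD
print(pg._get_backend_name())  # prints "gloo" here; the patched checks compare against "custom"
dist.destroy_process_group()
```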