diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py index eba7db8bb..a9160ce90 100644 --- a/torchrec/distributed/train_pipeline/utils.py +++ b/torchrec/distributed/train_pipeline/utils.py @@ -31,9 +31,17 @@ Union, ) +import torch from torch import distributed as dist -from torch.distributed._composable.fsdp.fully_shard import FSDPModule as FSDP2 +if not torch._running_with_deploy(): + from torch.distributed._composable.fsdp.fully_shard import FSDPModule as FSDP2 +else: + + class FSDP2: + pass + + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.fx.immutable_collections import ( immutable_dict as fx_immutable_dict,