From cc3e0eec45978d21f3f9eb93cb0abee3d1002aca Mon Sep 17 00:00:00 2001
From: Shabab Ayub
Date: Fri, 11 Oct 2024 06:19:12 -0700
Subject: [PATCH] Skip fsdp2 import if running with deploy

Summary: As the title says: the unconditional FSDP2 import breaks deploy
(torch::deploy) models, so guard it behind torch._running_with_deploy().

Differential Revision: D64237929
---
 torchrec/distributed/train_pipeline/utils.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py
index eba7db8bb..a9160ce90 100644
--- a/torchrec/distributed/train_pipeline/utils.py
+++ b/torchrec/distributed/train_pipeline/utils.py
@@ -31,9 +31,17 @@
     Union,
 )
 
+import torch
 from torch import distributed as dist
-from torch.distributed._composable.fsdp.fully_shard import FSDPModule as FSDP2
+if not torch._running_with_deploy():
+    from torch.distributed._composable.fsdp.fully_shard import FSDPModule as FSDP2
+else:
+
+    class FSDP2:
+        pass
+
+
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.fx.immutable_collections import (
     immutable_dict as fx_immutable_dict,
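
Note for reviewers (not part of the patch): a minimal sketch of why an empty stub
class is sufficient here, assuming utils.py only ever uses FSDP2 in isinstance
checks. The _is_fsdp2_module helper below is hypothetical and only illustrates
that pattern.

    import torch
    from torch import nn

    if not torch._running_with_deploy():
        from torch.distributed._composable.fsdp.fully_shard import FSDPModule as FSDP2
    else:

        class FSDP2:
            # Empty stub: under torch::deploy the composable FSDP2 module cannot be
            # imported, and isinstance(x, FSDP2) simply evaluates to False.
            pass


    def _is_fsdp2_module(module: nn.Module) -> bool:
        # Hypothetical helper, for illustration only: the pipeline code performs
        # similar isinstance checks when deciding how to handle wrapped modules.
        return isinstance(module, FSDP2)


    print(_is_fsdp2_module(nn.Linear(4, 4)))  # False in both environments

Because the stub only has to satisfy isinstance checks, no methods or attributes of
the real FSDPModule need to be mocked when running under deploy.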