diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py
index f26776917..e19340737 100644
--- a/torchrec/distributed/embedding.py
+++ b/torchrec/distributed/embedding.py
@@ -595,8 +595,8 @@ def __init__(
                 self._lookups[index] = DistributedDataParallel(
                     module=lookup,
                     device_ids=(
-                        [device]
-                        if self._device and self._device.type == "cuda"
+                        [self._device]
+                        if self._device is not None and self._device.type == "cuda"
                         else None
                     ),
                     process_group=env.process_group,
diff --git a/torchrec/distributed/embedding_tower_sharding.py b/torchrec/distributed/embedding_tower_sharding.py
index b9511c2ab..9a969a897 100644
--- a/torchrec/distributed/embedding_tower_sharding.py
+++ b/torchrec/distributed/embedding_tower_sharding.py
@@ -168,7 +168,7 @@ def __init__(
         # Hierarchical DDP
         self.interaction = DistributedDataParallel(
             module=module.interaction.to(self._device),
-            device_ids=[self._device],
+            device_ids=[self._device] if self._device is not None else None,
             process_group=self._intra_pg,
             gradient_as_bucket_view=True,
             broadcast_buffers=False,
@@ -589,7 +589,7 @@ def __init__(
         # Hierarchical DDP
         self.interactions[i] = DistributedDataParallel(
             module=tower.interaction.to(self._device),
-            device_ids=[self._device],
+            device_ids=[self._device] if self._device is not None else None,
             process_group=self._intra_pg,
             gradient_as_bucket_view=True,
             broadcast_buffers=False,
diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py
index ce707035c..052049f4c 100644
--- a/torchrec/distributed/embeddingbag.py
+++ b/torchrec/distributed/embeddingbag.py
@@ -695,8 +695,9 @@ def __init__(
                 self._lookups[i] = DistributedDataParallel(
                     module=lookup,
                     device_ids=(
-                        [device]
-                        if self._device and (self._device.type in {"cuda", "mtia"})
+                        [self._device]
+                        if self._device is not None
+                        and (self._device.type in {"cuda", "mtia"})
                         else None
                     ),
                     process_group=env.process_group,
diff --git a/torchrec/distributed/fused_embeddingbag.py b/torchrec/distributed/fused_embeddingbag.py
index aadd8d923..baba793d5 100644
--- a/torchrec/distributed/fused_embeddingbag.py
+++ b/torchrec/distributed/fused_embeddingbag.py
@@ -70,7 +70,7 @@ def __init__(
             if isinstance(sharding, DpPooledEmbeddingSharding):
                 self._lookups[index] = DistributedDataParallel(
                     module=lookup,
-                    device_ids=[device],
+                    device_ids=[device] if device is not None else None,
                     process_group=env.process_group,
                     gradient_as_bucket_view=True,
                     broadcast_buffers=False,
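
Note (not part of the patch): all five hunks converge on the same guard. `DistributedDataParallel` only accepts `device_ids=[device]` for a module living on a single accelerator device; for CPU modules `device_ids` must be `None`. The diff also replaces truthiness checks with an explicit `is not None` (a `torch.device` is always truthy, so the old form worked but obscured the intent) and uses `self._device` instead of a possibly different local `device`. Below is a minimal, self-contained sketch of the pattern, assuming a single-process gloo group; `wrap_ddp` is a hypothetical helper, not a torchrec API.

```python
# Sketch of the device_ids guard adopted by the patch; `wrap_ddp` is a
# hypothetical helper for illustration, not part of torchrec.
import os
from typing import Optional

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel


def wrap_ddp(
    module: torch.nn.Module, device: Optional[torch.device]
) -> DistributedDataParallel:
    # device_ids=[device] is only legal for single-device accelerator
    # modules (e.g. cuda/mtia); CPU modules require device_ids=None.
    # The explicit `is not None` mirrors the patch and keeps the check
    # consistent across call sites.
    return DistributedDataParallel(
        module=module,
        device_ids=(
            [device]
            if device is not None and device.type in {"cuda", "mtia"}
            else None
        ),
        gradient_as_bucket_view=True,
        broadcast_buffers=False,
    )


if __name__ == "__main__":
    # Single-process group so the sketch runs on a plain CPU box.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    device = torch.device("cuda", 0) if torch.cuda.is_available() else None
    module = torch.nn.Linear(8, 8).to(device if device is not None else "cpu")
    ddp = wrap_ddp(module, device)
    print(ddp(torch.zeros(2, 8, device=device or "cpu")).shape)
    dist.destroy_process_group()
```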