diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py
index 3173e710d..85de6e3b8 100644
--- a/torchrec/distributed/train_pipeline/train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/train_pipelines.py
@@ -21,6 +21,7 @@
     Deque,
     Dict,
     Generic,
+    Iterable,
     Iterator,
     List,
     Optional,
@@ -911,17 +912,40 @@ def embedding_backward(self, context: EmbeddingTrainPipelineContext) -> None:
             if cast(int, context.index) % 2 == 0
             else self._embedding_odd_streams
         )
-        for stream, emb_tensors, detached_emb_tensors in zip(
+        assert len(context.embedding_features) == len(context.embedding_tensors)
+        for stream, emb_tensors, embedding_features, detached_emb_tensors in zip(
             streams,
             context.embedding_tensors,
+            context.embedding_features,
             context.detached_embedding_tensors,
         ):
             with self._stream_context(stream):
                 grads = [tensor.grad for tensor in detached_emb_tensors]
                 if stream:
                     stream.wait_stream(default_stream)
-                # pyre-ignore
-                torch.autograd.backward(emb_tensors, grads)
+                # Some embeddings may never get used in the final loss computation,
+                # so the grads will be `None`. If we don't exclude these, it will fail
+                # with error: "grad can be implicitly created only for scalar outputs"
+                # Alternatively, if the tensor has only 1 element, pytorch can still
+                # figure out how to do autograd
+                embs_to_backprop, grads_to_use, invalid_features = [], [], []
+                assert len(embedding_features) == len(emb_tensors)
+                for features, tensor, grad in zip(
+                    embedding_features, emb_tensors, grads
+                ):
+                    if tensor.numel() == 1 or grad is not None:
+                        embs_to_backprop.append(tensor)
+                        grads_to_use.append(grad)
+                    else:
+                        if isinstance(features, Iterable):
+                            invalid_features.extend(features)
+                        else:
+                            invalid_features.append(features)
+                if invalid_features and context.index == 0:
+                    logger.warning(
+                        f"SemiSync, the following features have no gradients: {invalid_features}"
+                    )
+                torch.autograd.backward(embs_to_backprop, grads_to_use)
 
     def copy_batch_to_gpu(
         self,
diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py
index eba7db8bb..db6182caa 100644
--- a/torchrec/distributed/train_pipeline/utils.py
+++ b/torchrec/distributed/train_pipeline/utils.py
@@ -128,6 +128,7 @@ class PrefetchTrainPipelineContext(TrainPipelineContext):
 class EmbeddingTrainPipelineContext(TrainPipelineContext):
     embedding_a2a_requests: Dict[str, Multistreamable] = field(default_factory=dict)
     embedding_tensors: List[List[torch.Tensor]] = field(default_factory=list)
+    embedding_features: List[List[Union[str, List[str]]]] = field(default_factory=list)
     detached_embedding_tensors: List[List[torch.Tensor]] = field(default_factory=list)
 
 
@@ -408,6 +409,8 @@ def __call__(self, *input, **kwargs) -> Awaitable:
             # pyre-ignore [16]
             self._context.embedding_tensors.append(tensors)
             # pyre-ignore [16]
+            self._context.embedding_features.append(list(embeddings.keys()))
+            # pyre-ignore [16]
             self._context.detached_embedding_tensors.append(detached_tensors)
         else:
             assert isinstance(embeddings, KeyedTensor)
@@ -418,6 +421,13 @@ def __call__(self, *input, **kwargs) -> Awaitable:
                 tensors.append(tensor)
                 detached_tensors.append(detached_tensor)
             self._context.embedding_tensors.append(tensors)
+            # KeyedTensor is returned by EmbeddingBagCollections and its variants
+            # KeyedTensor holds dense data from multiple features and .values()
+            # returns a single concatenated dense tensor. To ensure that
+            # context.embedding_tensors[i] has the same length as
+            # context.embedding_features[i], we pass in a list with a single item:
+            # a list containing all the embedding feature names.
+            self._context.embedding_features.append([list(embeddings.keys())])
             self._context.detached_embedding_tensors.append(detached_tensors)
 
         return LazyNoWait(embeddings)
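
The sketch below is not part of the patch; it is a minimal, standalone PyTorch illustration of the failure mode the `embedding_backward` hunk guards against. The tensor names (`used`, `unused`) are hypothetical stand-ins for embedding outputs: a detached embedding that never feeds the loss keeps `grad is None`, and passing that `None` for a non-scalar output to `torch.autograd.backward` raises the quoted RuntimeError, while filtering those entries out lets the backward call proceed.

```python
import torch

# Two stand-in "embedding" outputs; only the first one feeds the loss.
used = torch.randn(4, 8, requires_grad=True)
unused = torch.randn(4, 8, requires_grad=True)

# Mirror the pipeline's detach-and-retain-grad pattern.
used_detached = used.detach().requires_grad_()
unused_detached = unused.detach().requires_grad_()

loss = (used_detached * 2.0).sum()  # `unused_detached` never reaches the loss
loss.backward()

grads = [used_detached.grad, unused_detached.grad]
print(grads[1])  # None: no gradient ever flowed into the unused embedding

# Passing that None grad for a non-scalar output would raise:
#   RuntimeError: grad can be implicitly created only for scalar outputs
# torch.autograd.backward([used, unused], grads)

# Filter as the patch does: keep tensors that have a grad, or that are
# scalars (for which autograd can implicitly create a grad of 1).
outputs, output_grads = [], []
for tensor, grad in zip([used, unused], grads):
    if tensor.numel() == 1 or grad is not None:
        outputs.append(tensor)
        output_grads.append(grad)

torch.autograd.backward(outputs, output_grads)
print(used.grad)    # gradient of 2s propagated back
print(unused.grad)  # still None; its feature would be reported in the warning
```

In the `KeyedTensor` branch of `utils.py`, `.values()` yields one concatenated tensor for all features, so all feature names are appended as a single nested list; this keeps `embedding_features[i]` and `embedding_tensors[i]` the same length, which the asserts in `embedding_backward` rely on.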