Ignore non-existent grad tensors for semi-sync (#2490)
Summary:

During semi-sync, we need to ignore embedding tensor grads that are `None`; otherwise `torch.autograd.backward` fails with the error `grad can be implicitly created only for scalar outputs`. This is a valid scenario if, for example, the embeddings are looked up but never used in the final loss computation.
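A minimal, self-contained sketch (not part of this commit's code) of the failure mode and of the filtering now done in `embedding_backward`; the tensor names below are illustrative:

import torch

x = torch.randn(3, requires_grad=True)
used = x * 2      # feeds the loss
unused = x * 3    # computed but never used in the loss

# Only `used` receives a gradient in this sketch; `unused` has none.
grads = [torch.ones_like(used), None]

# torch.autograd.backward([used, unused], grads) raises:
#   RuntimeError: grad can be implicitly created only for scalar outputs
# because a None grad cannot be implicitly created for a non-scalar tensor.

# The commit's fix: keep only tensors that have a grad, or single-element
# tensors for which PyTorch can create the implicit grad itself.
kept = [
    (t, g)
    for t, g in zip([used, unused], grads)
    if t.numel() == 1 or g is not None
]
torch.autograd.backward([t for t, _ in kept], [g for _, g in kept])
print(x.grad)  # gradient flows only through `used`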

Reviewed By: che-sh

Differential Revision: D63379382
sarckk authored and facebook-github-bot committed Oct 18, 2024
1 parent d1a2990 commit 953473e
Showing 2 changed files with 37 additions and 3 deletions.
30 changes: 27 additions & 3 deletions torchrec/distributed/train_pipeline/train_pipelines.py
@@ -21,6 +21,7 @@
Deque,
Dict,
Generic,
Iterable,
Iterator,
List,
Optional,
@@ -911,17 +912,40 @@ def embedding_backward(self, context: EmbeddingTrainPipelineContext) -> None:
if cast(int, context.index) % 2 == 0
else self._embedding_odd_streams
)
for stream, emb_tensors, detached_emb_tensors in zip(
assert len(context.embedding_features) == len(context.embedding_tensors)
for stream, emb_tensors, embedding_features, detached_emb_tensors in zip(
streams,
context.embedding_tensors,
context.embedding_features,
context.detached_embedding_tensors,
):
with self._stream_context(stream):
grads = [tensor.grad for tensor in detached_emb_tensors]
if stream:
stream.wait_stream(default_stream)
# pyre-ignore
torch.autograd.backward(emb_tensors, grads)
# Some embeddings may never get used in the final loss computation,
# so the grads will be `None`. If we don't exclude these, it will fail
# with error: "grad can be implicitly created only for scalar outputs"
# Alternatively, if the tensor has only 1 element, pytorch can still
# figure out how to do autograd
embs_to_backprop, grads_to_use, invalid_features = [], [], []
assert len(embedding_features) == len(emb_tensors)
for features, tensor, grad in zip(
embedding_features, emb_tensors, grads
):
if tensor.numel() == 1 or grad is not None:
embs_to_backprop.append(tensor)
grads_to_use.append(grad)
else:
if isinstance(features, Iterable):
invalid_features.extend(features)
else:
invalid_features.append(features)
if invalid_features and context.index == 0:
logger.warning(
f"SemiSync, the following features have no gradients: {invalid_features}"
)
torch.autograd.backward(embs_to_backprop, grads_to_use)

def copy_batch_to_gpu(
self,
10 changes: 10 additions & 0 deletions torchrec/distributed/train_pipeline/utils.py
@@ -128,6 +128,7 @@ class PrefetchTrainPipelineContext(TrainPipelineContext):
class EmbeddingTrainPipelineContext(TrainPipelineContext):
embedding_a2a_requests: Dict[str, Multistreamable] = field(default_factory=dict)
embedding_tensors: List[List[torch.Tensor]] = field(default_factory=list)
embedding_features: List[List[Union[str, List[str]]]] = field(default_factory=list)
detached_embedding_tensors: List[List[torch.Tensor]] = field(default_factory=list)


@@ -408,6 +409,8 @@ def __call__(self, *input, **kwargs) -> Awaitable:
# pyre-ignore [16]
self._context.embedding_tensors.append(tensors)
# pyre-ignore [16]
self._context.embedding_features.append(list(embeddings.keys()))
# pyre-ignore [16]
self._context.detached_embedding_tensors.append(detached_tensors)
else:
assert isinstance(embeddings, KeyedTensor)
@@ -418,6 +421,13 @@ def __call__(self, *input, **kwargs) -> Awaitable:
tensors.append(tensor)
detached_tensors.append(detached_tensor)
self._context.embedding_tensors.append(tensors)
# KeyedTensor is returned by EmbeddingBagCollections and its variants
# KeyedTensor holds dense data from multiple features and .values()
# returns a single concatenated dense tensor. To ensure that
# context.embedding_tensors[i] has the same length as
# context.embedding_features[i], we pass in a list with a single item:
# a list containing all the embedding feature names.
self._context.embedding_features.append([list(embeddings.keys())])
self._context.detached_embedding_tensors.append(detached_tensors)

return LazyNoWait(embeddings)
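As a hedged aside (not part of this commit), the snippet below shows the KeyedTensor behavior the comment above relies on: `.values()` returns one concatenated dense tensor while `.keys()` lists every pooled feature, so a single list of all feature names is appended per KeyedTensor. The keys and sizes are made up:

import torch
from torchrec.sparse.jagged_tensor import KeyedTensor

kt = KeyedTensor(
    keys=["f1", "f2"],
    length_per_key=[4, 8],        # per-feature embedding dims (illustrative)
    values=torch.randn(2, 12),    # batch of 2; 4 + 8 = 12 columns concatenated
)

print(kt.keys())          # ['f1', 'f2']
print(kt.values().shape)  # torch.Size([2, 12]) -- one dense tensor for both features
# So embedding_tensors gets one entry holding this single tensor, and
# embedding_features gets the matching single item: [["f1", "f2"]].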
