
Commit

Call .wait_tensor() in compiled region for dist.Work created in eager region

Summary:
In a compiled region, instead of calling `dist.Work.wait()`, we call `torch.ops._c10d_functional.wait_tensor()` on the `dist.Work`'s output tensor. This way, the `wait_tensor()` op is captured within the torch.compile graph (instead of graph-breaking on `dist.Work.wait()`), and the tensor is still waited on properly within the graph.

This diff also depends on pytorch/pytorch#137763 to function properly.
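To illustrate the pattern, here is a minimal sketch with a hypothetical toy collective (not code from this diff; it assumes an initialized process group and uses `torch.compiler.is_compiling()` in place of the `is_torchdynamo_compiling()` helper the diff uses):

import torch
import torch.distributed as dist

def allreduce_and_scale(x: torch.Tensor) -> torch.Tensor:
    # Issue an async collective; it returns a dist.Work handle.
    work = dist.all_reduce(x, async_op=True)
    if torch.compiler.is_compiling():
        # Compiled region: wait on the output tensor so the wait is
        # captured as a graph op instead of graph-breaking on Work.wait().
        torch.ops._c10d_functional.wait_tensor(x)
    else:
        # Eager region: block on the Work handle as before.
        work.wait()
    return x * 2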

Differential Revision: D64275115
yf225 authored and facebook-github-bot committed Oct 13, 2024
1 parent b6e784e commit a7df96c
Showing 1 changed file with 26 additions and 17 deletions.
torchrec/distributed/comm_ops.py: 43 changes (26 additions & 17 deletions)
@@ -124,6 +124,15 @@ def _wait_impl(self) -> W:
         return ret
 
 
+def wait_req(req: Request[W]) -> None:
+    if is_torchdynamo_compiling():
+        assert req.tensor is not None
+        torch.ops._c10d_functional.wait_tensor(req.tensor)
+    else:
+        assert isinstance(req.req, dist.Work)
+        req.req.wait()
+
+
 @dataclass
 class All2AllPooledInfo(object):
     """
@@ -1334,7 +1343,7 @@ def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]:
         myreq = ctx.myreq
         a2ai = myreq.a2ai
         assert myreq.req is not None
-        myreq.req.wait()
+        wait_req(myreq)
         myreq.req = None
         grad_output = myreq.tensor
         dim_sum_per_rank = a2ai.dim_sum_per_rank
@@ -1368,7 +1377,7 @@ def forward(
         a2ai = myreq.a2ai
         ctx.a2ai = a2ai
         assert myreq.req is not None
-        myreq.req.wait()
+        wait_req(myreq)
         sharded_output_embeddings = myreq.tensor
         myreq.req = None
         myreq.tensor = None
@@ -1573,9 +1582,9 @@ def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]:
         myreq = ctx.myreq
         a2ai = myreq.a2ai
         assert myreq.req is not None
-        myreq.req.wait()
+        wait_req(myreq)
         if isinstance(myreq.req, dist.Work):
-            myreq.req.wait()
+            wait_req(myreq)
 
         myreq.req = None
         grad_output = myreq.tensor
@@ -1606,7 +1615,7 @@ def forward(
         ctx.a2ai = a2ai
         assert myreq.req is not None
         if isinstance(myreq.req, dist.Work):
-            myreq.req.wait()
+            wait_req(myreq)
         sharded_output_embeddings = myreq.tensor
         myreq.req = None
         myreq.tensor = None
@@ -1797,7 +1806,7 @@ def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]:
             a2ai.permuted_lengths_after_sparse_data_all2all
         )
         assert myreq.req is not None
-        myreq.req.wait()
+        wait_req(myreq)
         sharded_grad_input = myreq.tensor
         if a2ai.codecs is not None:
             codecs = none_throws(a2ai.codecs)
@@ -1845,7 +1854,7 @@ def forward(
         D = a2ai.embedding_dim
         ctx.a2ai = a2ai
         assert myreq.req is not None
-        myreq.req.wait()
+        wait_req(myreq)
         myreq.req = None
         sharded_output_embeddings = myreq.tensor
         myreq.tensor = None
@@ -1952,7 +1961,7 @@ def forward(
     def backward(ctx, *grad_output):
         a2ai = ctx.a2ai
         myreq = ctx.myreq
-        myreq.req.wait()
+        wait_req(myreq)
         myreq.req = None
         grad_input = myreq.tensor
         if a2ai.codecs is not None:
@@ -1980,7 +1989,7 @@ def forward(
         a2ai = myreq.a2ai
         ctx.a2ai = a2ai
         assert myreq.req is not None
-        myreq.req.wait()
+        wait_req(myreq)
         myreq.req = None
         output = myreq.tensor
         myreq.tensor = None
@@ -2067,7 +2076,7 @@ def forward(
     def backward(ctx, *unused: Tensor) -> Tuple[Optional[Tensor], ...]:
         myreq = ctx.myreq
         assert myreq.req is not None
-        myreq.req.wait()
+        wait_req(myreq)
         myreq.req = None
         grad_inputs = list(myreq.tensor)
         rsi = myreq.rsi
@@ -2095,7 +2104,7 @@ def forward(
         *dummy_tensor: Tensor,
     ) -> Tensor:
         assert myreq.req is not None
-        myreq.req.wait()
+        wait_req(myreq)
         myreq.req = None
         output = myreq.tensor
         myreq.tensor = None
@@ -2174,7 +2183,7 @@ def forward(
     # pyre-fixme[2]: Parameter must be annotated.
     def backward(ctx, *unused: Tensor) -> Tuple[Optional[Tensor], ...]:
         myreq = ctx.myreq
-        myreq.req.wait()
+        wait_req(myreq)
         myreq.req = None
         grad_inputs = myreq.tensor
         rsi = myreq.rsi
@@ -2199,7 +2208,7 @@ def forward(
         *dummy_Tensor: Tensor,
     ) -> Tensor:
         assert myreq.req is not None
-        myreq.req.wait()
+        wait_req(myreq)
         myreq.req = None
         output = myreq.tensor
         myreq.tensor = None
@@ -2270,7 +2279,7 @@ def forward(
     def backward(ctx, *unused: Tensor) -> Tuple[Optional[Tensor], ...]:
         myreq = ctx.myreq
         assert myreq.req is not None
-        myreq.req.wait()
+        wait_req(myreq)
         myreq.req = None
         agi = myreq.agi
         grad_input = myreq.tensor
@@ -2296,7 +2305,7 @@ def forward(
         *dummy_tensor: Tensor,
     ) -> Tensor:
         assert myreq.req is not None
-        myreq.req.wait()
+        wait_req(myreq)
         myreq.req = None
         outputs = myreq.tensor
         myreq.tensor = None
@@ -2382,7 +2391,7 @@ def backward(ctx, *unused: Tensor) -> Tuple[Optional[Tensor], ...]:
     def backward(ctx, *unused: Tensor) -> Tuple[Optional[Tensor], ...]:
         myreq = ctx.myreq
         assert myreq.req is not None
-        myreq.req.wait()
+        wait_req(myreq)
         myreq.req = None
         grad_input = myreq.tensor
         rsi = myreq.rsi
@@ -2407,7 +2416,7 @@ def forward(
         *dummy_tensor: Tensor,
     ) -> Tensor:
         assert myreq.req is not None
-        myreq.req.wait()
+        wait_req(myreq)
         myreq.req = None
         # pyre-ignore
         output: torch.Tensor = myreq.tensor
