Call .wait_tensor() in compiled region for dist.Work created in eager region #2485
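This PR adds a small wait_req helper to torchrec/distributed/comm_ops.py and routes every myreq.req.wait() call site in the all-to-all, reduce-scatter, and all-gather autograd functions through it. The reason: a dist.Work handle created in an eager region cannot be waited on inside a torch.compile'd region, so under compilation the wait must instead go through the traceable functional-collectives op torch.ops._c10d_functional.wait_tensor on the request's output tensor. A minimal self-contained sketch of the pattern follows; the is_torchdynamo_compiling import shown here is an assumption for the sketch (comm_ops.py already has its own compile-detection helper in scope), and the untyped req parameter stands in for torchrec's Request object:

import torch
import torch.distributed as dist

# Assumption for this sketch: recent PyTorch exposes this check; torchrec
# imports an equivalent helper of its own.
from torch.compiler import is_compiling as is_torchdynamo_compiling


def wait_req(req) -> None:
    # `req` mirrors torchrec's Request, which carries both the eager
    # dist.Work handle (.req) and the collective's output tensor (.tensor).
    if is_torchdynamo_compiling():
        # Compiled region: dist.Work is opaque to Dynamo, so wait on the
        # tensor via the functional-collectives wait op instead.
        assert req.tensor is not None
        torch.ops._c10d_functional.wait_tensor(req.tensor)
    else:
        # Eager region: block on the process-group work handle directly.
        assert isinstance(req.req, dist.Work)
        req.req.wait()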

Open · wants to merge 1 commit into main
43 changes: 26 additions & 17 deletions torchrec/distributed/comm_ops.py
@@ -124,6 +124,15 @@ def _wait_impl(self) -> W:
         return ret


+def wait_req(req: Request[W]) -> None:
+    if is_torchdynamo_compiling():
+        assert req.tensor is not None
+        torch.ops._c10d_functional.wait_tensor(req.tensor)
+    else:
+        assert isinstance(req.req, dist.Work)
+        req.req.wait()
+
+
 @dataclass
 class All2AllPooledInfo(object):
     """
@@ -1334,7 +1343,7 @@ def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]:
         myreq = ctx.myreq
         a2ai = myreq.a2ai
         assert myreq.req is not None
-        myreq.req.wait()
+        wait_req(myreq)
         myreq.req = None
         grad_output = myreq.tensor
         dim_sum_per_rank = a2ai.dim_sum_per_rank
@@ -1368,7 +1377,7 @@ def forward(
         a2ai = myreq.a2ai
         ctx.a2ai = a2ai
         assert myreq.req is not None
-        myreq.req.wait()
+        wait_req(myreq)
         sharded_output_embeddings = myreq.tensor
         myreq.req = None
         myreq.tensor = None
@@ -1573,9 +1582,9 @@ def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]:
         myreq = ctx.myreq
         a2ai = myreq.a2ai
         assert myreq.req is not None
-        myreq.req.wait()
+        wait_req(myreq)
         if isinstance(myreq.req, dist.Work):
-            myreq.req.wait()
+            wait_req(myreq)

         myreq.req = None
         grad_output = myreq.tensor
@@ -1606,7 +1615,7 @@ def forward(
         ctx.a2ai = a2ai
         assert myreq.req is not None
         if isinstance(myreq.req, dist.Work):
-            myreq.req.wait()
+            wait_req(myreq)
         sharded_output_embeddings = myreq.tensor
         myreq.req = None
         myreq.tensor = None
@@ -1797,7 +1806,7 @@ def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]:
             a2ai.permuted_lengths_after_sparse_data_all2all
         )
         assert myreq.req is not None
-        myreq.req.wait()
+        wait_req(myreq)
         sharded_grad_input = myreq.tensor
         if a2ai.codecs is not None:
             codecs = none_throws(a2ai.codecs)
@@ -1845,7 +1854,7 @@ def forward(
         D = a2ai.embedding_dim
         ctx.a2ai = a2ai
         assert myreq.req is not None
-        myreq.req.wait()
+        wait_req(myreq)
         myreq.req = None
         sharded_output_embeddings = myreq.tensor
         myreq.tensor = None
@@ -1952,7 +1961,7 @@ def forward(
     def backward(ctx, *grad_output):
         a2ai = ctx.a2ai
         myreq = ctx.myreq
-        myreq.req.wait()
+        wait_req(myreq)
         myreq.req = None
         grad_input = myreq.tensor
         if a2ai.codecs is not None:
@@ -1980,7 +1989,7 @@ def forward(
         a2ai = myreq.a2ai
         ctx.a2ai = a2ai
         assert myreq.req is not None
-        myreq.req.wait()
+        wait_req(myreq)
         myreq.req = None
         output = myreq.tensor
         myreq.tensor = None
@@ -2067,7 +2076,7 @@ def forward(
     def backward(ctx, *unused: Tensor) -> Tuple[Optional[Tensor], ...]:
         myreq = ctx.myreq
         assert myreq.req is not None
-        myreq.req.wait()
+        wait_req(myreq)
         myreq.req = None
         grad_inputs = list(myreq.tensor)
         rsi = myreq.rsi
@@ -2095,7 +2104,7 @@ def forward(
         *dummy_tensor: Tensor,
     ) -> Tensor:
         assert myreq.req is not None
-        myreq.req.wait()
+        wait_req(myreq)
         myreq.req = None
         output = myreq.tensor
         myreq.tensor = None
@@ -2174,7 +2183,7 @@ def forward(
     # pyre-fixme[2]: Parameter must be annotated.
     def backward(ctx, *unused: Tensor) -> Tuple[Optional[Tensor], ...]:
         myreq = ctx.myreq
-        myreq.req.wait()
+        wait_req(myreq)
         myreq.req = None
         grad_inputs = myreq.tensor
         rsi = myreq.rsi
@@ -2199,7 +2208,7 @@ def forward(
         *dummy_Tensor: Tensor,
     ) -> Tensor:
         assert myreq.req is not None
-        myreq.req.wait()
+        wait_req(myreq)
         myreq.req = None
         output = myreq.tensor
         myreq.tensor = None
@@ -2270,7 +2279,7 @@ def forward(
     def backward(ctx, *unused: Tensor) -> Tuple[Optional[Tensor], ...]:
         myreq = ctx.myreq
         assert myreq.req is not None
-        myreq.req.wait()
+        wait_req(myreq)
         myreq.req = None
         agi = myreq.agi
         grad_input = myreq.tensor
@@ -2296,7 +2305,7 @@ def forward(
         *dummy_tensor: Tensor,
     ) -> Tensor:
         assert myreq.req is not None
-        myreq.req.wait()
+        wait_req(myreq)
         myreq.req = None
         outputs = myreq.tensor
         myreq.tensor = None
@@ -2382,7 +2391,7 @@ def forward(
     def backward(ctx, *unused: Tensor) -> Tuple[Optional[Tensor], ...]:
         myreq = ctx.myreq
         assert myreq.req is not None
-        myreq.req.wait()
+        wait_req(myreq)
         myreq.req = None
         grad_input = myreq.tensor
         rsi = myreq.rsi
@@ -2407,7 +2416,7 @@ def forward(
         *dummy_tensor: Tensor,
     ) -> Tensor:
         assert myreq.req is not None
-        myreq.req.wait()
+        wait_req(myreq)
         myreq.req = None
         # pyre-ignore
         output: torch.Tensor = myreq.tensor
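All 17 deleted wait calls follow the same mechanical rewrite. As a before/after sketch (myreq and wait_req exactly as in the diff; the enclosing autograd.Function methods elided):

# Before: eager-only, since Dynamo cannot trace a blocking wait on a
# dist.Work handle that was created outside the compiled region.
assert myreq.req is not None
myreq.req.wait()

# After: wait_req() dispatches to the traceable functional-collective wait
# (torch.ops._c10d_functional.wait_tensor) when compiling, and falls back
# to dist.Work.wait() in eager execution.
assert myreq.req is not None
wait_req(myreq)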