Skip to content

Commit

Permalink
add forward/backward test for _fbgemm_permute_pooled_embs (#2480)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2480

# context
* S443491 is caused by using a customized [version](https://www.internalfb.com/code/fbsource/[552b2a3cb49a261daa48b68b3647e8a951a3aa1b]/fbcode/minimal_viable_ai/models/main_feed_mtml/pytorch_modules.py?lines=2610) (fb.permute_pooled_embs_auto_grad) of fbgemm.permute_pooled_embs_auto_grad
* the fb version doesn't dispatch to autograd but relied on a bug in fbgemm.permute_pooled_embs_auto_grad, which was fixed by D48574563
* The SEV was mitigated by switching to fbgemm version: D62040883
* this diff is to add more tests regarding fbgemm.permute_pooled_embs_auto_grad

# details
* `permute_pooled_embs_auto_grad` is called in `_fbgemm_permute_pooled_embs` function
* add forward and backward test for `_fbgemm_permute_pooled_embs` function

Reviewed By: ge0405

Differential Revision: D64195848

fbshipit-source-id: 237ad75028eb9583bb02a2f305defb083f0f280d
  • Loading branch information
TroyGarden authored and facebook-github-bot committed Oct 11, 2024
1 parent be7cadb commit b6e784e
Showing 1 changed file with 22 additions and 4 deletions.
26 changes: 22 additions & 4 deletions torchrec/sparse/tests/test_jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from torch.testing import FileCheck
from torchrec.fx import symbolic_trace
from torchrec.sparse.jagged_tensor import (
_fbgemm_permute_pooled_embs,
_kt_regroup_arguments,
_regroup_keyed_tensors,
ComputeJTDictToKJT,
Expand Down Expand Up @@ -2342,6 +2343,7 @@ def test_regroup_multiple_kt(self) -> None:
KeyedTensor.regroup,
regroup_kts,
permute_multi_embedding,
_fbgemm_permute_pooled_embs,
],
device_str=["cpu", "cuda", "meta"],
)
Expand Down Expand Up @@ -2376,6 +2378,7 @@ def test_regroup_kts(
KeyedTensor.regroup,
regroup_kts,
permute_multi_embedding,
_fbgemm_permute_pooled_embs,
],
device_str=["cpu", "cuda", "meta"],
)
Expand Down Expand Up @@ -2446,18 +2449,33 @@ def test_regroup_backward_skips_and_duplicates(self) -> None:
torch.allclose(actual_kt_0_grad, expected_kt_0_grad)
torch.allclose(actual_kt_1_grad, expected_kt_1_grad)

def test_regroup_backward(self) -> None:
@repeat_test(
regroup_func=[
KeyedTensor.regroup,
regroup_kts,
permute_multi_embedding,
_fbgemm_permute_pooled_embs,
],
device_str=["cpu", "cuda"],
)
def test_regroup_backward(
self, regroup_func: Callable[..., List[torch.Tensor]], device_str: str
) -> None:
if device_str == "cuda" and not torch.cuda.is_available():
return
else:
device = torch.device(device_str)
kts = build_kts(
dense_features=20,
sparse_features=20,
dim_dense=64,
dim_sparse=128,
batch_size=128,
device=torch.device("cpu"),
device=device,
run_backward=True,
)
groups = build_groups(kts=kts, num_groups=2, skips=False, duplicates=False)
labels = torch.randint(0, 1, (128,), device=torch.device("cpu")).float()
labels = torch.randint(0, 1, (128,), device=device).float()

tensor_groups = KeyedTensor.regroup(kts, groups)
pred0 = tensor_groups[0].sum(dim=1).mul(tensor_groups[1].sum(dim=1))
Expand All @@ -2473,7 +2491,7 @@ def test_regroup_backward(self) -> None:
kts[0].values().grad = None
kts[1].values().grad = None

tensor_groups = _regroup_keyed_tensors(kts, groups)
tensor_groups = regroup_func(kts, groups)
pred1 = tensor_groups[0].sum(dim=1).mul(tensor_groups[1].sum(dim=1))
loss = torch.nn.functional.l1_loss(pred1, labels).sum()
expected_kt_0_grad = torch.autograd.grad(
Expand Down

0 comments on commit b6e784e

Please sign in to comment.