Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement unlift_and_swap function #2477

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 40 additions & 16 deletions torchrec/ir/tests/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from parameterized import parameterized_class
from torch import nn
from torchrec.ir.serializer import JsonSerializer

from torchrec.ir.utils import (
decapsulate_ir_modules,
encapsulate_ir_modules,
mark_dynamic_kjt,
unlift_and_swap_modules,
)

from torchrec.modules.embedding_configs import EmbeddingBagConfig
Expand Down Expand Up @@ -92,7 +94,38 @@ def deserialize_from_dict(
return CompoundModule(ebc, comp, mlist)


@parameterized_class(
[
{"deserialize_with_swap": True},
{"deserialize_with_swap": False},
],
class_name_func=lambda cls, _, params: f"{cls.__name__}{'Swap' if params['deserialize_with_swap'] else ''}",
)
class TestJsonSerializer(unittest.TestCase):
def deserialize_model(
self,
ep: torch.export.ExportedProgram,
*,
device: Optional[torch.device] = None,
finalize_interpreter_modules: bool = False,
short_circuit_pytree_ebc_regroup: bool = False,
) -> torch.nn.Module:

if self.deserialize_with_swap: # pyre-ignore[16]
deserialized_model = unlift_and_swap_modules(ep, JsonSerializer, device)

else:
unflatten = torch.export.unflatten(ep)
deserialized_model = decapsulate_ir_modules(
unflatten,
JsonSerializer,
device,
finalize_interpreter_modules=finalize_interpreter_modules,
short_circuit_pytree_ebc_regroup=short_circuit_pytree_ebc_regroup,
)

return deserialized_model

# in the model we have 5 duplicated EBCs, 1 fpEBC with fpCollection, and 1 fpEBC with fpDict
def generate_model(self) -> nn.Module:
class Model(nn.Module):
Expand Down Expand Up @@ -200,8 +233,7 @@ def test_serialize_deserialize_ebc(self) -> None:
self.assertEqual(eager_out[i].shape, tensor.shape)

# Deserialize EBC
unflatten_ep = torch.export.unflatten(ep)
deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)
deserialized_model = self.deserialize_model(ep)

# check EBC config
for i in range(5):
Expand Down Expand Up @@ -285,8 +317,7 @@ def test_dynamic_shape_ebc(self) -> None:
self.assertEqual(eager_out[i].shape, tensor.shape)

# Deserialize EBC
unflatten_ep = torch.export.unflatten(ep)
deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)
deserialized_model = self.deserialize_model(ep)
deserialized_model.load_state_dict(model.state_dict())

# Run forward on deserialized model
Expand Down Expand Up @@ -339,10 +370,7 @@ def test_deserialized_device(self) -> None:
if device == "cuda" and not torch.cuda.is_available():
continue
device = torch.device(device)
unflatten_ep = torch.export.unflatten(ep)
deserialized_model = decapsulate_ir_modules(
unflatten_ep, JsonSerializer, device
)
deserialized_model = self.deserialize_model(ep, device=device)
for name, m in deserialized_model.named_modules():
if hasattr(m, "device"):
assert m.device.type == device.type, f"{name} should be on {device}"
Expand Down Expand Up @@ -419,8 +447,7 @@ def forward(self, features: KeyedJaggedTensor) -> List[torch.Tensor]:
self.assertEqual(x.shape, y.shape)

# Deserialize
unflatten_ep = torch.export.unflatten(ep)
deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)
deserialized_model = self.deserialize_model(ep)
# Check if Compound Module is deserialized correctly
self.assertIsInstance(deserialized_model.comp, CompoundModule)
self.assertIsInstance(deserialized_model.comp.comp, CompoundModule)
Expand Down Expand Up @@ -514,8 +541,7 @@ def forward(
for key in eager_out.keys():
self.assertEqual(ep_output[key].shape, eager_out[key].shape)
# Deserialize EBC
unflatten_ep = torch.export.unflatten(ep)
deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)
deserialized_model = self.deserialize_model(ep)
self.assertFalse(deserialized_model.regroup._is_inited)
deserialized_out = deserialized_model(id_list_features)
self.assertTrue(deserialized_model.regroup._is_inited)
Expand Down Expand Up @@ -592,10 +618,8 @@ def forward(
# Allows KJT to not be unflattened and run a forward on unflattened EP
preserve_module_call_signature=(tuple(sparse_fqns)),
)
unflatten_ep = torch.export.unflatten(ep)
deserialized_model = decapsulate_ir_modules(
unflatten_ep,
JsonSerializer,
deserialized_model = self.deserialize_model(
ep,
short_circuit_pytree_ebc_regroup=True,
finalize_interpreter_modules=True,
)
Expand Down
41 changes: 41 additions & 0 deletions torchrec/ir/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from torch import nn
from torch.export import Dim, ShapesCollection
from torch.export._swap import _swap_modules
from torch.export.dynamic_shapes import _Dim as DIM
from torch.export.unflatten import InterpreterModule
from torch.fx import Node
Expand Down Expand Up @@ -168,6 +169,46 @@ def decapsulate_ir_modules(
return module


def unlift_and_swap_modules(
ep: torch.export.ExportedProgram,
serializer: Type[SerializerInterface] = DEFAULT_SERIALIZER_CLS,
device: Optional[torch.device] = None,
) -> torch.fx.GraphModule:
"""
Unlifts the given ExportedProgram into a fx.GraphModule, and then swaps
previously traced modules with new eager modules specified in the
`serializer`.

Args:
ep (ExportedProgram): Exported program
serializer: TorchRec serializer which will deserialize stored metadata
on the ExportedProgram to initialize new eager TorchRec modules
device: Device to initialize new eager modules on
"""

gm = ep.module()

gm.graph.eliminate_dead_code()
module_fqn_to_swap = {
key[: -(len("ir_metadata") + 1)]
for key in ep.constants.keys()
if "ir_metadata" in key
}

def get_submodule(model: torch.nn.Module, fqn: str) -> torch.nn.Module:
for attr in fqn.split("."):
model = getattr(model, attr)
return model

modules_to_swap = {
fqn: serializer.decapsulate_module(get_submodule(gm, fqn), device)
for fqn in module_fqn_to_swap
}
gm = _swap_modules(ep, modules_to_swap)

return gm


def _get_dim(name: str, min: Optional[int] = None, max: Optional[int] = None) -> DIM:
"""
Returns a Dim object with the given name and min/max. If the name is not unique, it will append a suffix to the name.
Expand Down
Loading