diff --git a/torchrec/ir/tests/test_serializer.py b/torchrec/ir/tests/test_serializer.py
index 805eea0bc..55117aa71 100644
--- a/torchrec/ir/tests/test_serializer.py
+++ b/torchrec/ir/tests/test_serializer.py
@@ -14,6 +14,7 @@ from typing import Any, Callable, Dict, List, Optional, Union

 import torch
+from parameterized import parameterized_class
 from torch import nn
 from torchrec.ir.serializer import JsonSerializer
@@ -21,6 +22,7 @@
     decapsulate_ir_modules,
     encapsulate_ir_modules,
     mark_dynamic_kjt,
+    unlift_and_swap_modules,
 )
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
@@ -92,7 +94,38 @@ def deserialize_from_dict(
         return CompoundModule(ebc, comp, mlist)


+@parameterized_class(
+    [
+        {"deserialize_with_swap": True},
+        {"deserialize_with_swap": False},
+    ],
+    class_name_func=lambda cls, _, params: f"{cls.__name__}{'Swap' if params['deserialize_with_swap'] else ''}",
+)
 class TestJsonSerializer(unittest.TestCase):
+    def deserialize_model(
+        self,
+        ep: torch.export.ExportedProgram,
+        *,
+        device: Optional[torch.device] = None,
+        finalize_interpreter_modules: bool = False,
+        short_circuit_pytree_ebc_regroup: bool = False,
+    ) -> torch.nn.Module:
+
+        if self.deserialize_with_swap:  # pyre-ignore[16]
+            deserialized_model = unlift_and_swap_modules(ep, JsonSerializer, device)
+
+        else:
+            unflatten = torch.export.unflatten(ep)
+            deserialized_model = decapsulate_ir_modules(
+                unflatten,
+                JsonSerializer,
+                device,
+                finalize_interpreter_modules=finalize_interpreter_modules,
+                short_circuit_pytree_ebc_regroup=short_circuit_pytree_ebc_regroup,
+            )
+
+        return deserialized_model
+
     # in the model we have 5 duplicated EBCs, 1 fpEBC with fpCollection, and 1 fpEBC with fpDict
     def generate_model(self) -> nn.Module:
         class Model(nn.Module):
@@ -200,8 +233,7 @@ def test_serialize_deserialize_ebc(self) -> None:
             self.assertEqual(eager_out[i].shape, tensor.shape)

         # Deserialize EBC
-        unflatten_ep = torch.export.unflatten(ep)
-        deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)
+        deserialized_model = self.deserialize_model(ep)

         # check EBC config
         for i in range(5):
@@ -285,8 +317,7 @@ def test_dynamic_shape_ebc(self) -> None:
             self.assertEqual(eager_out[i].shape, tensor.shape)

         # Deserialize EBC
-        unflatten_ep = torch.export.unflatten(ep)
-        deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)
+        deserialized_model = self.deserialize_model(ep)
         deserialized_model.load_state_dict(model.state_dict())

         # Run forward on deserialized model
@@ -339,10 +370,7 @@ def test_deserialized_device(self) -> None:
             if device == "cuda" and not torch.cuda.is_available():
                 continue
             device = torch.device(device)
-            unflatten_ep = torch.export.unflatten(ep)
-            deserialized_model = decapsulate_ir_modules(
-                unflatten_ep, JsonSerializer, device
-            )
+            deserialized_model = self.deserialize_model(ep, device=device)
             for name, m in deserialized_model.named_modules():
                 if hasattr(m, "device"):
                     assert m.device.type == device.type, f"{name} should be on {device}"
@@ -419,8 +447,7 @@ def forward(self, features: KeyedJaggedTensor) -> List[torch.Tensor]:
             self.assertEqual(x.shape, y.shape)

         # Deserialize
-        unflatten_ep = torch.export.unflatten(ep)
-        deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)
+        deserialized_model = self.deserialize_model(ep)
         # Check if Compound Module is deserialized correctly
         self.assertIsInstance(deserialized_model.comp, CompoundModule)
         self.assertIsInstance(deserialized_model.comp.comp, CompoundModule)
@@ -514,8 +541,7 @@ def forward(
         for key in eager_out.keys():
             self.assertEqual(ep_output[key].shape, eager_out[key].shape)
         # Deserialize EBC
-        unflatten_ep = torch.export.unflatten(ep)
-        deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)
+        deserialized_model = self.deserialize_model(ep)
         self.assertFalse(deserialized_model.regroup._is_inited)
         deserialized_out = deserialized_model(id_list_features)
         self.assertTrue(deserialized_model.regroup._is_inited)
@@ -592,10 +618,8 @@ def forward(
             # Allows KJT to not be unflattened and run a forward on unflattened EP
             preserve_module_call_signature=(tuple(sparse_fqns)),
         )
-        unflatten_ep = torch.export.unflatten(ep)
-        deserialized_model = decapsulate_ir_modules(
-            unflatten_ep,
-            JsonSerializer,
+        deserialized_model = self.deserialize_model(
+            ep,
            short_circuit_pytree_ebc_regroup=True,
            finalize_interpreter_modules=True,
        )
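Note on the test changes above: `@parameterized_class` generates two test classes from the same body (here `TestJsonSerializer` and `TestJsonSerializerSwap`), so every existing test now exercises both deserialization paths via the shared `deserialize_model` helper. For readers unfamiliar with the pattern, here is a minimal, self-contained sketch, not part of the diff; the class and attribute usage are illustrative:

    import unittest

    from parameterized import parameterized_class


    # Two classes are generated from one body; class_name_func controls their
    # names, producing "TestExample" and "TestExampleSwap" (made-up names).
    @parameterized_class(
        [
            {"deserialize_with_swap": True},
            {"deserialize_with_swap": False},
        ],
        class_name_func=lambda cls, _, params: (
            f"{cls.__name__}{'Swap' if params['deserialize_with_swap'] else ''}"
        ),
    )
    class TestExample(unittest.TestCase):
        def test_flag_is_injected(self) -> None:
            # parameterized_class sets `deserialize_with_swap` as a class
            # attribute before each generated class runs its tests.
            self.assertIsInstance(self.deserialize_with_swap, bool)


    if __name__ == "__main__":
        unittest.main()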
diff --git a/torchrec/ir/utils.py b/torchrec/ir/utils.py
index fb0f137ff..4aa3dfbae 100644
--- a/torchrec/ir/utils.py
+++ b/torchrec/ir/utils.py
@@ -18,6 +18,7 @@
 from torch import nn
 from torch.export import Dim, ShapesCollection
+from torch.export._swap import _swap_modules
 from torch.export.dynamic_shapes import _Dim as DIM
 from torch.export.unflatten import InterpreterModule
 from torch.fx import Node
@@ -168,6 +169,46 @@ def decapsulate_ir_modules(
     return module


+def unlift_and_swap_modules(
+    ep: torch.export.ExportedProgram,
+    serializer: Type[SerializerInterface] = DEFAULT_SERIALIZER_CLS,
+    device: Optional[torch.device] = None,
+) -> torch.fx.GraphModule:
+    """
+    Unlifts the given ExportedProgram into an fx.GraphModule, and then swaps
+    the previously traced modules with new eager modules specified in the
+    `serializer`.
+
+    Args:
+        ep (ExportedProgram): Exported program
+        serializer: TorchRec serializer which will deserialize stored metadata
+            on the ExportedProgram to initialize new eager TorchRec modules
+        device: Device to initialize new eager modules on
+    """
+
+    gm = ep.module()
+
+    gm.graph.eliminate_dead_code()
+    module_fqn_to_swap = {
+        key[: -(len("ir_metadata") + 1)]
+        for key in ep.constants.keys()
+        if "ir_metadata" in key
+    }
+
+    def get_submodule(model: torch.nn.Module, fqn: str) -> torch.nn.Module:
+        for attr in fqn.split("."):
+            model = getattr(model, attr)
+        return model
+
+    modules_to_swap = {
+        fqn: serializer.decapsulate_module(get_submodule(gm, fqn), device)
+        for fqn in module_fqn_to_swap
+    }
+    gm = _swap_modules(ep, modules_to_swap)
+
+    return gm
+
+
 def _get_dim(name: str, min: Optional[int] = None, max: Optional[int] = None) -> DIM:
     """
     Returns a Dim object with the given name and min/max. If the name is not unique, it will append a suffix to the name.
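For reference, an end-to-end sketch of the new `unlift_and_swap_modules` entry point, mirroring how `deserialize_model` drives it in the tests above. The wrapper module, table config, and input values are illustrative assumptions, not taken from the diff; the export flags follow the tests (`strict=False`, `preserve_module_call_signature`):

    import torch

    from torchrec.ir.serializer import JsonSerializer
    from torchrec.ir.utils import encapsulate_ir_modules, unlift_and_swap_modules
    from torchrec.modules.embedding_configs import EmbeddingBagConfig
    from torchrec.modules.embedding_modules import EmbeddingBagCollection
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


    class Wrapper(torch.nn.Module):  # illustrative wrapper, not from the diff
        def __init__(self) -> None:
            super().__init__()
            self.ebc = EmbeddingBagCollection(
                tables=[
                    EmbeddingBagConfig(
                        name="t1",
                        embedding_dim=4,
                        num_embeddings=10,
                        feature_names=["f1"],
                    )
                ],
                is_weighted=False,
            )

        def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
            return self.ebc(features).values()


    model = Wrapper()
    kjt = KeyedJaggedTensor.from_lengths_sync(
        keys=["f1"],
        values=torch.tensor([0, 1, 2]),
        lengths=torch.tensor([1, 2]),
    )

    # Swap TorchRec modules for serializable IR meta modules and record
    # their FQNs so export preserves the module call signatures.
    model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
    ep = torch.export.export(
        model,
        (kjt,),
        strict=False,
        preserve_module_call_signature=tuple(sparse_fqns),
    )

    # One call replaces the former two-step
    # torch.export.unflatten(ep) + decapsulate_ir_modules(...) flow.
    deserialized = unlift_and_swap_modules(ep, JsonSerializer)
    out = deserialized(kjt)

The design difference between the two paths: `decapsulate_ir_modules` walks an unflattened module hierarchy and rebuilds eager modules in place, while `unlift_and_swap_modules` unlifts the ExportedProgram itself and delegates the graph surgery to `torch.export._swap._swap_modules`, keying off the `ir_metadata` constants to find which FQNs to swap.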