From 39b7d8020cc616d833da866af6099c4274143cdf Mon Sep 17 00:00:00 2001 From: Angela Yi Date: Wed, 9 Oct 2024 11:20:30 -0700 Subject: [PATCH] Implement unlift_and_swap function Summary: Implement a new swapping API which does the following: 1. Takes an exported program and torchrec serializer 2. Constructs torchrec modules based on serialized metadata stored in the exported program 3. Unlifts the exported program to a normal nn.Module 4. Partitions out nodes and swaps them out with calls to the torchrec modules Differential Revision: D63683421 --- torchrec/ir/utils.py | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/torchrec/ir/utils.py b/torchrec/ir/utils.py index fb0f137ff..4aa3dfbae 100644 --- a/torchrec/ir/utils.py +++ b/torchrec/ir/utils.py @@ -18,6 +18,7 @@ from torch import nn from torch.export import Dim, ShapesCollection +from torch.export._swap import _swap_modules from torch.export.dynamic_shapes import _Dim as DIM from torch.export.unflatten import InterpreterModule from torch.fx import Node @@ -168,6 +169,46 @@ def decapsulate_ir_modules( return module +def unlift_and_swap_modules( + ep: torch.export.ExportedProgram, + serializer: Type[SerializerInterface] = DEFAULT_SERIALIZER_CLS, + device: Optional[torch.device] = None, +) -> torch.fx.GraphModule: + """ + Unlifts the given ExportedProgram into a fx.GraphModule, and then swaps + previously traced modules with new eager modules specified in the + `serializer`. + + Args: + ep (ExportedProgram): Exported program + serializer: TorchRec serializer which will deserialize stored metadata + on the ExportedProgram to initialize new eager TorchRec modules + device: Device to initialize new eager modules on + """ + + gm = ep.module() + + gm.graph.eliminate_dead_code() + module_fqn_to_swap = { + key[: -(len("ir_metadata") + 1)] + for key in ep.constants.keys() + if "ir_metadata" in key + } + + def get_submodule(model: torch.nn.Module, fqn: str) -> torch.nn.Module: + for attr in fqn.split("."): + model = getattr(model, attr) + return model + + modules_to_swap = { + fqn: serializer.decapsulate_module(get_submodule(gm, fqn), device) + for fqn in module_fqn_to_swap + } + gm = _swap_modules(ep, modules_to_swap) + + return gm + + def _get_dim(name: str, min: Optional[int] = None, max: Optional[int] = None) -> DIM: """ Returns a Dim object with the given name and min/max. If the name is not unique, it will append a suffix to the name.