Skip to content

Commit

Permalink
Implement unlift_and_swap function
Browse files Browse the repository at this point in the history
Summary:
Implement a new swapping API which does the following:
1. Takes an exported program and torchrec serializer
2. Constructs torchrec modules based on serialized metadata stored in the exported program
3. Unlifts the exported program to a normal nn.Module
4. Partitions out nodes and swaps them out with calls to the torchrec modules

Differential Revision: D63683421
  • Loading branch information
angelayi authored and facebook-github-bot committed Oct 9, 2024
1 parent 496b1ac commit 39b7d80
Showing 1 changed file with 41 additions and 0 deletions.
41 changes: 41 additions & 0 deletions torchrec/ir/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from torch import nn
from torch.export import Dim, ShapesCollection
from torch.export._swap import _swap_modules
from torch.export.dynamic_shapes import _Dim as DIM
from torch.export.unflatten import InterpreterModule
from torch.fx import Node
Expand Down Expand Up @@ -168,6 +169,46 @@ def decapsulate_ir_modules(
return module


def unlift_and_swap_modules(
ep: torch.export.ExportedProgram,
serializer: Type[SerializerInterface] = DEFAULT_SERIALIZER_CLS,
device: Optional[torch.device] = None,
) -> torch.fx.GraphModule:
"""
Unlifts the given ExportedProgram into a fx.GraphModule, and then swaps
previously traced modules with new eager modules specified in the
`serializer`.
Args:
ep (ExportedProgram): Exported program
serializer: TorchRec serializer which will deserialize stored metadata
on the ExportedProgram to initialize new eager TorchRec modules
device: Device to initialize new eager modules on
"""

gm = ep.module()

gm.graph.eliminate_dead_code()
module_fqn_to_swap = {
key[: -(len("ir_metadata") + 1)]
for key in ep.constants.keys()
if "ir_metadata" in key
}

def get_submodule(model: torch.nn.Module, fqn: str) -> torch.nn.Module:
for attr in fqn.split("."):
model = getattr(model, attr)
return model

modules_to_swap = {
fqn: serializer.decapsulate_module(get_submodule(gm, fqn), device)
for fqn in module_fqn_to_swap
}
gm = _swap_modules(ep, modules_to_swap)

return gm


def _get_dim(name: str, min: Optional[int] = None, max: Optional[int] = None) -> DIM:
"""
Returns a Dim object with the given name and min/max. If the name is not unique, it will append a suffix to the name.
Expand Down

0 comments on commit 39b7d80

Please sign in to comment.