2024-09-18 nightly release (90d62cb)
pytorchbot committed Sep 18, 2024
1 parent 0b21e72 commit 745a0a9
Showing 15 changed files with 881 additions and 14 deletions.
10 changes: 10 additions & 0 deletions docs/source/_templates/layout.html
@@ -0,0 +1,10 @@
{% extends "!layout.html" %}

{% block footer %}
{{ super() }}

<script type="text/javascript">
var collapsedSections = ['Introduction', 'All API References']
</script>

{% endblock %}
12 changes: 12 additions & 0 deletions docs/source/conf.py
@@ -69,6 +69,18 @@
html_theme = "pytorch_sphinx_theme"
html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]

# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
html_theme_options = {
"pytorch_project": "torchrec",
"display_version": True,
"logo_only": True,
"collapse_navigation": False,
"includehidden": True,
}

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
12 changes: 10 additions & 2 deletions docs/source/index.rst
@@ -29,8 +29,16 @@ TorchRec API
------------

.. toctree::
:maxdepth: 2
:caption: Contents:
:maxdepth: 1
:caption: Introduction
:hidden:

overview.rst

.. toctree::
:maxdepth: 1
:caption: All API References
:hidden:

torchrec.datasets.rst
torchrec.datasets.scripts.rst
23 changes: 23 additions & 0 deletions docs/source/overview.rst
@@ -0,0 +1,23 @@
.. _overview_label:

==================
TorchRec Overview
==================

TorchRec is the PyTorch recommendation system library, designed to provide common primitives
for creating state-of-the-art personalization models and a path to production. TorchRec is
widely used in Meta's production recommendation models for both training and inference workflows.

Why TorchRec?
------------------

TorchRec is designed to address the unique challenges of building, scaling, and deploying massive
recommendation system models, a problem space that is not a focus of core PyTorch. More specifically,
TorchRec provides the following primitives for general recommendation systems:

- **Specialized Components**: TorchRec provides simple, specialized modules that are common in authoring recommendation systems, with a focus on embedding tables (see the usage sketch below)
- **Advanced Sharding Techniques**: TorchRec provides flexible and customizable methods for sharding massive embedding tables: Row-Wise, Column-Wise, Table-Wise, and so on. TorchRec can automatically determine the best sharding plan for a given device topology for efficient training and memory balance
- **Distributed Training**: While PyTorch supports basic distributed training, TorchRec extends these capabilities with more sophisticated model parallelism techniques specifically designed for the massive scale of recommendation systems
- **Highly Optimized**: TorchRec's training and inference components are heavily optimized on top of FBGEMM; after all, TorchRec powers some of the largest recommendation system models at Meta
- **Frictionless Path to Deployment**: TorchRec provides simple APIs for transforming a trained model for inference and loading it into a C++ environment for highly optimized inference
- **Integration with PyTorch Ecosystem**: TorchRec is built on top of PyTorch, meaning it integrates seamlessly with existing PyTorch code, tools, and workflows. This allows developers to leverage their existing knowledge and codebase while utilizing advanced features for recommendation systems. By being a part of the PyTorch ecosystem, TorchRec benefits from the robust community support, continuous updates, and improvements that come with PyTorch.
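
To make the specialized-components bullet concrete, the sketch below creates and queries an unsharded EmbeddingBagCollection. It is a minimal, illustrative example and not part of this commit: the table names, feature names, and sizes are made up, and the EmbeddingBagConfig, EmbeddingBagCollection, and KeyedJaggedTensor APIs are used as documented in TorchRec.

import torch
from torchrec import EmbeddingBagCollection, EmbeddingBagConfig, KeyedJaggedTensor

# Two illustrative embedding tables, each serving one sparse feature.
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="t1", embedding_dim=16, num_embeddings=1000, feature_names=["f1"]
        ),
        EmbeddingBagConfig(
            name="t2", embedding_dim=16, num_embeddings=1000, feature_names=["f2"]
        ),
    ],
    device=torch.device("cpu"),
)

# A batch of 2 examples: f1 has ids [0, 1] and [2]; f2 has ids [3] and [4, 5].
features = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5]),
    lengths=torch.tensor([2, 1, 1, 2]),
)

pooled = ebc(features)     # KeyedTensor of pooled embeddings, one group per feature
print(pooled["f1"].shape)  # torch.Size([2, 16])

Scaling this module up is where the sharding and distributed-training bullets come in, typically by wrapping the model with TorchRec's DistributedModelParallel; that part is outside the scope of this sketch.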
23 changes: 20 additions & 3 deletions torchrec/ir/tests/test_serializer.py
@@ -557,7 +557,7 @@ def test_key_order_with_ebc_and_regroup(self) -> None:
ebc2.load_state_dict(ebc1.state_dict())
regroup = KTRegroupAsDict([["f1", "f3"], ["f2"]], ["odd", "even"])

class myModel(nn.Module):
class mySparse(nn.Module):
def __init__(self, ebc, regroup):
super().__init__()
self.ebc = ebc
@@ -569,6 +569,17 @@ def forward(
) -> Dict[str, torch.Tensor]:
return self.regroup([self.ebc(features)])

class myModel(nn.Module):
def __init__(self, ebc, regroup):
super().__init__()
self.sparse = mySparse(ebc, regroup)

def forward(
self,
features: KeyedJaggedTensor,
) -> Dict[str, torch.Tensor]:
return self.sparse(features)

model = myModel(ebc1, regroup)
eager_out = model(id_list_features)

@@ -582,11 +593,17 @@ def forward(
preserve_module_call_signature=(tuple(sparse_fqns)),
)
unflatten_ep = torch.export.unflatten(ep)
deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)
deserialized_model = decapsulate_ir_modules(
unflatten_ep,
JsonSerializer,
short_circuit_pytree_ebc_regroup=True,
finalize_interpreter_modules=True,
)

# we export the model with ebc1 and unflatten the model,
# and then swap in ebc2 (you can think of this as the sharding process
# resulting in a sharded EBC), so that we can mimic the key-order change
deserialized_model.ebc = ebc2
deserialized_model.sparse.ebc = ebc2

deserialized_out = deserialized_model(id_list_features)
for key in eager_out.keys():
118 changes: 118 additions & 0 deletions torchrec/ir/utils.py
@@ -10,6 +10,7 @@
#!/usr/bin/env python3

import logging
import operator
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Type

@@ -18,7 +19,12 @@
from torch import nn
from torch.export import Dim, ShapesCollection
from torch.export.dynamic_shapes import _Dim as DIM
from torch.export.unflatten import InterpreterModule
from torch.fx import Node
from torchrec.ir.types import SerializerInterface
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
from torchrec.modules.regroup import KTRegroupAsDict
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


@@ -129,6 +135,8 @@ def decapsulate_ir_modules(
module: nn.Module,
serializer: Type[SerializerInterface] = DEFAULT_SERIALIZER_CLS,
device: Optional[torch.device] = None,
finalize_interpreter_modules: bool = False,
short_circuit_pytree_ebc_regroup: bool = False,
) -> nn.Module:
"""
Takes a module and decapsulates its embedding modules by retrieving the buffer.
@@ -147,6 +155,16 @@
# we use "ir_metadata" as a convention to identify the deserializable module
if "ir_metadata" in dict(module.named_buffers()):
module = serializer.decapsulate_module(module, device)

if short_circuit_pytree_ebc_regroup:
module = _short_circuit_pytree_ebc_regroup(module)
assert finalize_interpreter_modules, "need finalize_interpreter_modules=True"

if finalize_interpreter_modules:
for mod in module.modules():
if isinstance(mod, InterpreterModule):
mod.finalize()

return module


@@ -233,3 +251,103 @@ def move_to_copy_nodes_to_device(
nodes.kwargs = new_kwargs

return unflattened_module


def _short_circuit_pytree_ebc_regroup(module: nn.Module) -> nn.Module:
"""
Bypass the pytree flatten and unflatten functions between EBC and KTRegroupAsDict to avoid the key-order issue.
https://fb.workplace.com/groups/1028545332188949/permalink/1042204770823005/
EBC ==> (out-going) pytree.flatten ==> tensors and specs ==> (in-coming) pytree.unflatten ==> KTRegroupAsDict
"""
ebc_fqns: List[str] = []
regroup_fqns: List[str] = []
for fqn, m in module.named_modules():
if isinstance(m, FeatureProcessedEmbeddingBagCollection):
ebc_fqns.append(fqn)
elif isinstance(m, EmbeddingBagCollection):
if len(ebc_fqns) > 0 and fqn.startswith(ebc_fqns[-1]):
continue
ebc_fqns.append(fqn)
elif isinstance(m, KTRegroupAsDict):
regroup_fqns.append(fqn)
if len(ebc_fqns) == len(regroup_fqns) == 0:
# nothing happens if there is no EBC or KTRegroupAsDict (e.g., the PEA case)
return module
elif len(regroup_fqns) == 0:
# model only contains EBCs; pytree.flatten on the KTs (from EBC) has a performance impact
logger.warning(
"Expect perf impact if KTRegroupAsDict is not used together with EBCs."
)
return module
elif len(ebc_fqns) == 0:
# model only contains KTRegroupAsDict; the KTs are not from an EBC, so be careful
logger.warning("KTRegroupAsDict is not from EBC, need to be careful.")
return module
else:
return prune_pytree_flatten_unflatten(
module, in_fqns=regroup_fqns, out_fqns=ebc_fqns
)


def prune_pytree_flatten_unflatten(
module: nn.Module, in_fqns: List[str], out_fqns: List[str]
) -> nn.Module:
"""
Remove the pytree flatten and unflatten functions between the given in_fqns and out_fqns.
"preserved module" ==> (out-going) pytree.flatten ==> [tensors and specs]
[tensors and specs] ==> (in-coming) pytree.unflatten ==> "preserved module"
"""

def _get_graph_node(mod: nn.Module, fqn: str) -> Tuple[nn.Module, Node]:
for node in mod.graph.nodes:
if node.op == "call_module" and node.target == fqn:
return mod, node
assert "." in fqn, f"can't find {fqn} in the graph of {mod}"
curr, fqn = fqn.split(".", maxsplit=1)
mod = getattr(mod, curr)
return _get_graph_node(mod, fqn)

# remove tree_unflatten from the in_fqns (in-coming nodes)
for fqn in in_fqns:
submodule, node = _get_graph_node(module, fqn)
assert len(node.args) == 1
getitem_getitem: Node = node.args[0] # pyre-ignore[9]
assert (
getitem_getitem.op == "call_function"
and getitem_getitem.target == operator.getitem
)
tree_unflatten_getitem = node.args[0].args[0] # pyre-ignore[16]
assert (
tree_unflatten_getitem.op == "call_function"
and tree_unflatten_getitem.target == operator.getitem
)
tree_unflatten = tree_unflatten_getitem.args[0]
assert (
tree_unflatten.op == "call_function"
and tree_unflatten.target == torch.utils._pytree.tree_unflatten
)
logger.info(f"Removing tree_unflatten from {fqn}")
input_nodes = tree_unflatten.args[0]
node.args = (input_nodes,)
submodule.graph.eliminate_dead_code()

# remove tree_flatten_spec from the out_fqns (out-going nodes)
for fqn in out_fqns:
submodule, node = _get_graph_node(module, fqn)
users = list(node.users.keys())
assert (
len(users) == 1
and users[0].op == "call_function"
and users[0].target == torch.fx._pytree.tree_flatten_spec
)
tree_flatten_users = list(users[0].users.keys())
assert (
len(tree_flatten_users) == 1
and tree_flatten_users[0].op == "call_function"
and tree_flatten_users[0].target == operator.getitem
)
logger.info(f"Removing tree_flatten_spec from {fqn}")
getitem_node = tree_flatten_users[0]
getitem_node.replace_all_uses_with(node)
submodule.graph.eliminate_dead_code()
return module
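
For context, the two new decapsulate_ir_modules flags are intended for the end of the export/unflatten flow exercised in test_serializer.py above. The following is a rough, assumption-laden sketch of that flow rather than a definitive recipe: the toy module and feature batch are made up, and the encapsulate_ir_modules helper (and its returned (module, sparse_fqns) pair) is inferred from the surrounding test code.

from typing import Dict

import torch
from torch import nn
from torchrec import EmbeddingBagCollection, EmbeddingBagConfig, KeyedJaggedTensor
from torchrec.ir.serializer import JsonSerializer
from torchrec.ir.utils import decapsulate_ir_modules, encapsulate_ir_modules
from torchrec.modules.regroup import KTRegroupAsDict


class TinySparse(nn.Module):
    # Toy EBC -> KTRegroupAsDict pipeline, mirroring mySparse in the test above.
    def __init__(self) -> None:
        super().__init__()
        self.ebc = EmbeddingBagCollection(
            tables=[
                EmbeddingBagConfig(
                    name="t1", embedding_dim=8, num_embeddings=100, feature_names=["f1"]
                ),
                EmbeddingBagConfig(
                    name="t2", embedding_dim=8, num_embeddings=100, feature_names=["f2"]
                ),
            ]
        )
        self.regroup = KTRegroupAsDict([["f1"], ["f2"]], ["group_a", "group_b"])

    def forward(self, features: KeyedJaggedTensor) -> Dict[str, torch.Tensor]:
        return self.regroup([self.ebc(features)])


model = TinySparse()
features = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1", "f2"],
    values=torch.tensor([1, 2, 3, 4]),
    lengths=torch.tensor([1, 1, 1, 1]),
)

# Swap TorchRec modules for serializable stand-ins, export, then restore them.
model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
ep = torch.export.export(
    model,
    (features,),
    strict=False,
    preserve_module_call_signature=tuple(sparse_fqns),
)
deserialized = decapsulate_ir_modules(
    torch.export.unflatten(ep),
    JsonSerializer,
    short_circuit_pytree_ebc_regroup=True,  # prune pytree flatten/unflatten around the EBC
    finalize_interpreter_modules=True,      # must be enabled when short-circuiting
)
out = deserialized(features)  # Dict[str, torch.Tensor] of regrouped pooled embeddings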
11 changes: 11 additions & 0 deletions torchrec/metrics/metric_module.py
@@ -25,11 +25,13 @@
from torchrec.metrics.ctr import CTRMetric
from torchrec.metrics.mae import MAEMetric
from torchrec.metrics.metrics_config import (
BatchSizeStage,
MetricsConfig,
RecMetricEnum,
RecMetricEnumBase,
RecTaskInfo,
StateMetricEnum,
validate_batch_size_stages,
)
from torchrec.metrics.metrics_namespace import (
compose_customized_metric_key,
@@ -456,16 +458,25 @@ def generate_metric_module(
state_metrics_mapping: Dict[StateMetricEnum, StateMetric],
device: torch.device,
process_group: Optional[dist.ProcessGroup] = None,
batch_size_stages: Optional[List[BatchSizeStage]] = None,
) -> RecMetricModule:
rec_metrics = _generate_rec_metrics(
metrics_config, world_size, my_rank, batch_size, process_group
)
"""
batch_size_stages is currently only used by ThroughputMetric to keep total_examples correct, so
that different training jobs have aligned metrics.
TODO: update metrics other than ThroughputMetric if they depend on batch_size.
"""
validate_batch_size_stages(batch_size_stages)

if metrics_config.throughput_metric:
throughput_metric = ThroughputMetric(
batch_size=batch_size,
world_size=world_size,
window_seconds=metrics_config.throughput_metric.window_size,
warmup_steps=metrics_config.throughput_metric.warmup_steps,
batch_size_stages=batch_size_stages,
)
else:
throughput_metric = None
37 changes: 37 additions & 0 deletions torchrec/metrics/metrics_config.py
@@ -196,3 +196,40 @@ class MetricsConfig:
throughput_metric=None,
state_metrics=[],
)


@dataclass
class BatchSizeStage:
"""
BatchSizeStage class for defining a variable batch size stage.
For a List[BatchSizeStage], max_iters should be in ascending order, and the last stage should have max_iters=None.
Attributes
----------
batch_size(int): A multiple of base_batch_size.
max_iters(Optional[int]): The maximum number of iterations for the stage.
The stage is effective when previous BatchSizeStage.max_iters < iter <= max_iters.
max_iters is the absolute train iteration count, not the relative count within each stage.
"""

batch_size: int = 0
max_iters: Optional[int] = 0


def validate_batch_size_stages(
batch_size_stages: Optional[List[BatchSizeStage]],
) -> None:
if not batch_size_stages:
return

if len(batch_size_stages) == 0:
raise ValueError("Batch size stages should not be empty")

for i in range(len(batch_size_stages) - 1):
if batch_size_stages[i].batch_size >= batch_size_stages[i + 1].batch_size:
raise ValueError(
f"Batch size should be in ascending order. Got {batch_size_stages}"
)
if batch_size_stages[-1].max_iters is not None:
raise ValueError(
f"Batch size stages last stage should have max_iters = None, but get {batch_size_stages[-1].max_iters}"
)
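
As a usage sketch for the new config (assuming the ThroughputMetric arguments other than batch_size_stages mirror the test below), a schedule with ascending batch sizes and an unbounded final stage is built and validated like this:

from torchrec.metrics.metrics_config import BatchSizeStage, validate_batch_size_stages
from torchrec.metrics.throughput import ThroughputMetric

# Ramp-up schedule: per-rank batch size 256 until iteration 1000, then 512 afterwards.
batch_size_stages = [
    BatchSizeStage(batch_size=256, max_iters=1000),
    BatchSizeStage(batch_size=512, max_iters=None),
]
validate_batch_size_stages(batch_size_stages)  # raises ValueError for a malformed schedule

throughput_metric = ThroughputMetric(
    batch_size=512,  # base batch size; illustrative value
    world_size=64,   # illustrative value
    window_seconds=100,
    batch_size_stages=batch_size_stages,
)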
28 changes: 28 additions & 0 deletions torchrec/metrics/tests/test_throughput.py
@@ -12,6 +12,8 @@
import unittest
from unittest.mock import Mock, patch

from torchrec.metrics.metrics_config import BatchSizeStage

from torchrec.metrics.throughput import ThroughputMetric


@@ -140,3 +142,29 @@ def test_warmup_checkpointing(self) -> None:

# Mimic trainer crashing and loading a checkpoint
throughput_metric._steps = 0

@patch(THROUGHPUT_PATH + ".time.monotonic")
def test_batch_size_schedule(self, time_mock: Mock) -> None:
batch_size_stages = [BatchSizeStage(256, 1), BatchSizeStage(512, None)]
time_mock.return_value = 1
throughput_metric = ThroughputMetric(
batch_size=self.batch_size,
world_size=self.world_size,
window_seconds=100,
batch_size_stages=batch_size_stages,
)

total_examples = 0
throughput_metric.update()
total_examples += batch_size_stages[0].batch_size * self.world_size
self.assertEqual(
throughput_metric.compute(),
{"throughput-throughput|total_examples": total_examples},
)

throughput_metric.update()
total_examples += batch_size_stages[1].batch_size * self.world_size
self.assertEqual(
throughput_metric.compute(),
{"throughput-throughput|total_examples": total_examples},
)
