2024-09-27 nightly release (5904f0d)
pytorchbot committed Sep 27, 2024
1 parent b934216 commit e80b613
Showing 8 changed files with 154 additions and 14 deletions.
5 changes: 2 additions & 3 deletions torchrec/distributed/shard.py
@@ -28,6 +28,7 @@
ShardingPlan,
)
from torchrec.distributed.utils import init_parameters
from torchrec.modules.utils import reset_module_states_post_sharding
from torchrec.types import CacheMixin


@@ -280,8 +281,6 @@ def _replace(_model: nn.Module, path: str = "") -> None:
init_parameters(module, device)
module = module.to(device)

for submod in module.modules():
if isinstance(submod, CacheMixin):
submod.clear_cache()
reset_module_states_post_sharding(module)

return module
9 changes: 7 additions & 2 deletions torchrec/distributed/test_utils/multi_process.py
@@ -24,6 +24,11 @@
)


# AMD's HIP runtime doesn't seem to work with forkserver; hipMalloc will fail
# Therefore we use spawn for HIP runtime until AMD fixes the issue
_MP_INIT_MODE = "forkserver" if torch.version.hip is None else "spawn"


class MultiProcessContext:
def __init__(
self,
@@ -126,7 +131,7 @@ def _run_multi_process_test(
# pyre-ignore
**kwargs,
) -> None:
ctx = multiprocessing.get_context("forkserver")
ctx = multiprocessing.get_context(_MP_INIT_MODE)
processes = []
for rank in range(world_size):
kwargs["rank"] = rank
@@ -152,7 +157,7 @@ def _run_multi_process_test_per_rank(
world_size: int,
kwargs_per_rank: List[Dict[str, Any]],
) -> None:
ctx = multiprocessing.get_context("forkserver")
ctx = multiprocessing.get_context(_MP_INIT_MODE)
processes = []
for rank in range(world_size):
kwargs = {}
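Aside (not part of the commit): a minimal, self-contained sketch of how the start-method switch above is typically exercised. The worker function and world size here are illustrative only; the real per-rank logic lives in `_run_multi_process_test`.

```python
import multiprocessing

import torch

# Same selection as in multi_process.py: forkserver on CUDA builds,
# spawn on HIP builds where forkserver breaks hipMalloc.
_MP_INIT_MODE = "forkserver" if torch.version.hip is None else "spawn"


def _worker(rank: int, world_size: int) -> None:
    # Placeholder for per-rank test logic.
    print(f"rank {rank}/{world_size} started via {_MP_INIT_MODE}")


if __name__ == "__main__":
    ctx = multiprocessing.get_context(_MP_INIT_MODE)
    processes = [ctx.Process(target=_worker, args=(rank, 2)) for rank in range(2)]
    for p in processes:
        p.start()
    for p in processes:
        p.join()
```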
10 changes: 5 additions & 5 deletions torchrec/distributed/test_utils/test_model_parallel.py
@@ -187,7 +187,7 @@ def setUp(self, backend: str = "nccl") -> None:
[
None,
{
"embeddingbags": (torch.optim.SGD, {"lr": 0.01}),
"embedding_bags": (torch.optim.SGD, {"lr": 0.01}),
"embeddings": (torch.optim.SGD, {"lr": 0.2}),
},
]
@@ -296,7 +296,7 @@ def test_sharding_dp(
[
None,
{
"embeddingbags": (torch.optim.SGD, {"lr": 0.01}),
"embedding_bags": (torch.optim.SGD, {"lr": 0.01}),
"embeddings": (torch.optim.SGD, {"lr": 0.2}),
},
]
@@ -373,7 +373,7 @@ def test_sharding_cw(
[
None,
{
"embeddingbags": (torch.optim.SGD, {"lr": 0.01}),
"embedding_bags": (torch.optim.SGD, {"lr": 0.01}),
"embeddings": (torch.optim.SGD, {"lr": 0.2}),
},
]
@@ -451,7 +451,7 @@ def test_sharding_twcw(
[
None,
{
"embeddingbags": (torch.optim.SGD, {"lr": 0.01}),
"embedding_bags": (torch.optim.SGD, {"lr": 0.01}),
"embeddings": (torch.optim.SGD, {"lr": 0.2}),
},
]
@@ -529,7 +529,7 @@ def test_sharding_tw(
[
None,
{
"embeddingbags": (torch.optim.SGD, {"lr": 0.01}),
"embedding_bags": (torch.optim.SGD, {"lr": 0.01}),
"embeddings": (torch.optim.SGD, {"lr": 0.2}),
},
]
2 changes: 1 addition & 1 deletion torchrec/distributed/test_utils/test_sharding.py
@@ -331,7 +331,7 @@ def sharding_single_rank_test(
optimizer_kwargs,
) in apply_optimizer_in_backward_config.items():
for name, param in global_model_named_params_as_dict.items():
if name not in apply_optim_name:
if apply_optim_name not in name:
continue
assert name in local_model_named_params_as_dict
local_param = local_model_named_params_as_dict[name]
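Aside (not part of the commit): the fixed check asks whether the short config key (e.g. "embedding_bags", matching the renamed keys in test_model_parallel.py above) appears inside the fully qualified parameter name, rather than the reverse. A small sketch with hypothetical parameter names:

```python
import torch

# Hypothetical names; mirrors the corrected containment check above.
apply_optimizer_in_backward_config = {
    "embedding_bags": (torch.optim.SGD, {"lr": 0.01}),
    "embeddings": (torch.optim.SGD, {"lr": 0.2}),
}
param_name = "sparse.embedding_bags.table_0.weight"

for apply_optim_name in apply_optimizer_in_backward_config:
    # Old check (`param_name not in apply_optim_name`) skipped every parameter,
    # since a fully qualified name is never a substring of the short key.
    if apply_optim_name not in param_name:
        continue
    print(f"config key '{apply_optim_name}' applies to '{param_name}'")
```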
4 changes: 2 additions & 2 deletions torchrec/distributed/train_pipeline/train_pipelines.py
@@ -84,15 +84,15 @@ class TorchCompileConfig:
Configs for torch.compile
fullgraph: bool = False, whether to compile the whole graph or not
dynamic: bool = False, whether to use dynamic shapes or not
dynamic: Optional[bool] = None, whether to use dynamic shapes or not, if None, automatic_dynamic_shapes will apply
backend: str = "inductor", which compiler to use (either inductor or aot)
compile_on_iter: int = 3, compile the model on which iteration
this is useful when we want to profile the first few iterations of training
and then start using compiled model from iteration #3 onwards
"""

fullgraph: bool = False
dynamic: bool = False
dynamic: Optional[bool] = None
backend: str = "inductor"
compile_on_iter: int = 3

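Aside (not part of the commit): with dynamic now defaulting to None, the dynamic-shapes decision is deferred to torch.compile's automatic detection instead of being forced off. A sketch using a stand-in dataclass (_CompileCfg is illustrative, not the real TorchCompileConfig):

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class _CompileCfg:  # stand-in mirroring the fields changed above
    fullgraph: bool = False
    dynamic: Optional[bool] = None  # None lets torch.compile pick dynamic shapes automatically
    backend: str = "inductor"


cfg = _CompileCfg()
model = torch.nn.Linear(8, 4)
compiled = torch.compile(
    model, fullgraph=cfg.fullgraph, dynamic=cfg.dynamic, backend=cfg.backend
)
out = compiled(torch.randn(2, 8))
```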
18 changes: 18 additions & 0 deletions torchrec/modules/utils.py
@@ -16,6 +16,7 @@
from torch.profiler import record_function
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
from torchrec.streamable import Multistreamable
from torchrec.types import CacheMixin


@dataclass
@@ -373,3 +374,20 @@ def deterministic_dedup(ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
)

return sorted_unique_ids.view(-1), last_existence_index.flatten()


def reset_module_states_post_sharding(
module: torch.nn.Module,
) -> None:
"""
Reset the module states post sharding.
Involves clearing cached tensors if they exist
from unsharded version.
"""

# Clear Cache for TorchRec modules that have cache. Normally would happen in sharding
# but cached modules might not be part of the TorchRec modules being sharded.
# For example, necessary for KTRegroupAsDict correctness,
for submod in module.modules():
if isinstance(submod, CacheMixin):
submod.clear_cache()
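Aside (not part of the commit): a usage sketch of the new helper. _CachedSubmodule is hypothetical; the only assumption about CacheMixin is the clear_cache() interface that the helper itself relies on.

```python
from typing import Optional

import torch
from torchrec.modules.utils import reset_module_states_post_sharding
from torchrec.types import CacheMixin


class _CachedSubmodule(torch.nn.Module, CacheMixin):
    """Hypothetical module that caches a tensor from its unsharded forward pass."""

    def __init__(self) -> None:
        super().__init__()
        self._cached: Optional[torch.Tensor] = torch.zeros(1)

    def clear_cache(self) -> None:
        self._cached = None


model = torch.nn.Sequential(_CachedSubmodule())
reset_module_states_post_sharding(model)  # walks submodules and clears the cache
```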
7 changes: 6 additions & 1 deletion torchrec/quant/embedding_modules.py
@@ -99,6 +99,11 @@ def _get_kjt_keys(feature: KeyedJaggedTensor) -> List[str]:
return feature.keys()


@torch.fx.wrap
def _cat_embeddings(embeddings: List[Tensor]) -> Tensor:
return embeddings[0] if len(embeddings) == 1 else torch.cat(embeddings, dim=1)


def for_each_module_of_type_do(
module: nn.Module,
module_types: List[Type[torch.nn.Module]],
@@ -511,7 +516,7 @@ def forward(

return KeyedTensor(
keys=self._embedding_names,
values=torch.cat(embeddings, dim=1),
values=_cat_embeddings(embeddings),
length_per_key=self._length_per_key,
)

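Aside (not part of the commit): the new _cat_embeddings fast path returns a lone lookup result directly instead of paying for a torch.cat copy; the @torch.fx.wrap decorator presumably keeps FX tracing from specializing on the length of the list. Its runtime behavior, shown standalone:

```python
from typing import List

import torch
from torch import Tensor


def _cat_embeddings(embeddings: List[Tensor]) -> Tensor:
    # Single lookup: return the tensor as-is and skip the torch.cat copy.
    return embeddings[0] if len(embeddings) == 1 else torch.cat(embeddings, dim=1)


single = [torch.randn(4, 8)]
multi = [torch.randn(4, 8), torch.randn(4, 16)]
assert _cat_embeddings(single) is single[0]      # no concatenation, no copy
assert _cat_embeddings(multi).shape == (4, 24)   # concatenated along dim 1
```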
113 changes: 113 additions & 0 deletions torchrec/schema/api_tests/test_optimizer_schema.py
@@ -0,0 +1,113 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import inspect
import unittest
from typing import Any, Collection, List, Mapping, Optional, Set, Tuple, Union

import torch
from torch import optim

from torchrec.distributed.types import ShardedTensor
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
from torchrec.schema.utils import is_signature_compatible


class StableKeyedOptimizer(optim.Optimizer):
def __init__(
self,
params: Mapping[str, Union[torch.Tensor, ShardedTensor]],
# pyre-ignore [2]
state: Mapping[Any, Any],
param_groups: Collection[Mapping[str, Any]],
) -> None:
pass

def init_state(
self,
sparse_grad_parameter_names: Optional[Set[str]] = None,
) -> None:
pass

def save_param_groups(self, save: bool) -> None:
pass

# pyre-ignore [2]
def add_param_group(self, param_group: Any) -> None:
pass

def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
pass


class StableCombinedOptimizer(KeyedOptimizer):
def __init__(
self,
optims: List[Union[KeyedOptimizer, Tuple[str, KeyedOptimizer]]],
) -> None:
pass

@property
def optimizers(self) -> List[Tuple[str, StableKeyedOptimizer]]:
return []

@staticmethod
def prepend_opt_key(name: str, opt_key: str) -> str:
return ""

@property
def param_groups(self) -> Collection[Mapping[str, Any]]:
return []

@property
def params(self) -> Mapping[str, Union[torch.Tensor, ShardedTensor]]:
return {}

def post_load_state_dict(self) -> None:
pass

def save_param_groups(self, save: bool) -> None:
pass

# pyre-ignore [2]
def step(self, closure: Any = None) -> None:
pass

def zero_grad(self, set_to_none: bool = False) -> None:
pass


class TestOptimizerSchema(unittest.TestCase):
def test_keyed_optimizer(self) -> None:
stable_keyed_optimizer_funcs = inspect.getmembers(
StableKeyedOptimizer, predicate=inspect.isfunction
)

for func_name, stable_func in stable_keyed_optimizer_funcs:
self.assertTrue(getattr(KeyedOptimizer, func_name, None) is not None)
self.assertTrue(
is_signature_compatible(
inspect.signature(stable_func),
inspect.signature(getattr(KeyedOptimizer, func_name)),
)
)

def test_combined_optimizer(self) -> None:
stable_combined_optimizer_funcs = inspect.getmembers(
StableCombinedOptimizer, predicate=inspect.isfunction
)

for func_name, stable_func in stable_combined_optimizer_funcs:
self.assertTrue(getattr(CombinedOptimizer, func_name, None) is not None)
self.assertTrue(
is_signature_compatible(
inspect.signature(stable_func),
inspect.signature(getattr(CombinedOptimizer, func_name)),
)
)

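Aside (not part of the commit): the new schema test leans on is_signature_compatible, which compares two inspect.Signature objects (stable vs. current) and returns a bool. Its exact compatibility rules live in torchrec.schema.utils; the sketch below only exercises the trivially compatible case of a signature against itself.

```python
import inspect

from torchrec.schema.utils import is_signature_compatible


def save_param_groups(self, save: bool) -> None:
    ...


# A signature is expected to be compatible with itself; renamed or removed
# parameters are the kind of drift the new test is meant to catch.
assert is_signature_compatible(
    inspect.signature(save_param_groups),
    inspect.signature(save_param_groups),
)
```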