Use util function to generate forwards compat kwargs #1613

Closed
wants to merge 8 commits
4 changes: 2 additions & 2 deletions returnn/datasets/basic.py
@@ -24,7 +24,7 @@
from returnn.log import log
from returnn.engine.batch import Batch, BatchSetGenerator
from returnn.datasets.util.vocabulary import Vocabulary
from returnn.util.basic import try_run, NumbersDict, OptionalNotImplementedError
from returnn.util.basic import get_fwd_compat_kwargs, try_run, NumbersDict, OptionalNotImplementedError
from returnn.util import file_cache
from returnn.tensor import TensorDict

@@ -1050,7 +1050,7 @@ def iterate_seqs(self, recurrent_net=True, used_data_keys=None):
:rtype: list[(int,NumbersDict,NumbersDict)]
"""
if self.custom_chunking_func:
sentinel_kw = {"__fwd_compatible_random_arg_%i" % int(random() * 100): None}
sentinel_kw = get_fwd_compat_kwargs()
for seq_idx, t_start, t_end in self.custom_chunking_func(
dataset=self, seq_idx_start=0, recurrent_net=recurrent_net, used_data_keys=used_data_keys, **sentinel_kw
):
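
Since the sentinel kwarg now comes from the shared helper, a user-supplied custom_chunking_func still has to tolerate unknown keyword arguments. A rough, untested sketch of such a callback (the chunking policy and names are illustrative, not part of this PR):

from returnn.util.basic import NumbersDict

def my_chunking_func(*, dataset, seq_idx_start, recurrent_net, used_data_keys=None, **_kwargs):
    """Yields (seq_idx, t_start, t_end); **_kwargs absorbs the forward-compat sentinel kwarg."""
    seq_idx = seq_idx_start
    while dataset.is_less_than_num_seqs(seq_idx):
        seq_len = dataset.get_seq_length(seq_idx)  # NumbersDict over data keys
        yield seq_idx, NumbersDict(0), seq_len  # one chunk = the whole sequence
        seq_idx += 1
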
9 changes: 3 additions & 6 deletions returnn/datasets/postprocessing.py
@@ -12,6 +12,7 @@
from returnn.datasets.util.vocabulary import Vocabulary
from returnn.tensor import Tensor, TensorDict
from returnn.tensor.dim import Dim
from returnn.util.basic import get_fwd_compat_kwargs
from .basic import init_dataset
from .cached2 import CachedDataset2

@@ -205,9 +206,7 @@ def _validate_tensor_dict_iter(inner: Iterator[TensorDict]) -> Iterator[TensorDict]:

data_iter = self._iterate_dataset()
if self._map_seq_stream is not None:
data_iter = self._map_seq_stream(
data_iter, rng=self._rng, **{f"fwd_compatible_random_kwarg_{self._rng.randint(0, 1000)}": None}
)
data_iter = self._map_seq_stream(data_iter, rng=self._rng, **get_fwd_compat_kwargs())
assert isinstance(
data_iter, Iterator
), f"map_seq_stream must produce an {Iterator.__name__}, but produced {type(data_iter).__name__}"
@@ -226,9 +225,7 @@ def _iterate_dataset(self) -> Iterator[TensorDict]:
for data_key in data_keys:
tensor_dict.data[data_key].raw_tensor = self._dataset.get_data(seq_index, data_key)
if self._map_seq is not None:
tensor_dict = self._map_seq(
tensor_dict, rng=self._rng, **{f"fwd_compatible_random_kwarg_{self._rng.randint(0, 1000)}": None}
)
tensor_dict = self._map_seq(tensor_dict, rng=self._rng, **get_fwd_compat_kwargs())
assert isinstance(
tensor_dict, TensorDict
), f"map_seq must produce a {TensorDict.__name__}, but produced {type(tensor_dict).__name__}"
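
Likewise, the map_seq / map_seq_stream callbacks of the postprocessing dataset now receive the helper-generated sentinel instead of an inline one, so user code should keep a trailing **kwargs. A hedged sketch of a map_seq callback (the data key "data" and the transformation are made-up examples):

from returnn.tensor import TensorDict

def my_map_seq(tensor_dict: TensorDict, *, rng, **_kwargs) -> TensorDict:
    """Randomly reverses the "data" sequence; **_kwargs absorbs fwd_compatible_random_kwarg_<n>."""
    if rng.randint(0, 2):  # rng is the dataset's numpy RandomState
        data = tensor_dict.data["data"].raw_tensor
        tensor_dict.data["data"].raw_tensor = data[::-1].copy()
    return tensor_dict
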
128 changes: 115 additions & 13 deletions returnn/torch/distributed.py
@@ -3,20 +3,77 @@
"""

from __future__ import annotations
from typing import Optional, Any, Dict
from abc import abstractmethod, ABC
import logging
import os
import socket
import logging
from typing import Callable, Optional, Any, Dict, Type, Union

import torch
from torch.nn.parallel import DistributedDataParallel

from returnn.config import Config
from returnn.util.basic import CollectionReadCheckCovered
from returnn.util.basic import CollectionReadCheckCovered, OptionalNotImplementedError, get_fwd_compat_kwargs

_logger = logging.getLogger("returnn.torch.distributed")


class ParamSynchronizer(ABC):
"""
Custom parameter synchronization primitive.

Contains a callback that is called after every train step to synchronize model parameters
across processes/nodes.
"""

@abstractmethod
def __init__(self, *, rank: int, size: int, local_rank: int, local_size: int, **_kwargs):
"""
`__init__` called after the default global process group is created.
Can be used to initialize any additional custom process (sub)groups.

Note that `__init__` is passed a randomly named kwarg on every invocation to ensure forwards compatibility.

:param rank: global rank of the current process across all nodes
:param size: global world size across all nodes
:param local_rank: local rank of the current process on the current node
:param local_size: local world size on the current node
:param _kwargs: any additional kwargs
"""
super().__init__()

self.rank = rank
self.size = size
self.local_rank = local_rank
self.local_size = local_size

def make_distributed_model(self, *, module: torch.nn.Module, **kwargs) -> DistributedDataParallel:
"""
Creates an associated `DistributedDataParallel` for the given module for gradient synchronization.

This function can be left unimplemented if no gradient synchronization is done.

Note this function is passed a randomly named kwarg on every invocation to ensure forwards compatibility.
"""
raise OptionalNotImplementedError

@abstractmethod
def step(self, *, module: torch.nn.Module, train_step_idx: int, **kwargs):
"""
Parameter synchronization callback called after every train step with updated model parameters.

Note this function is passed a randomly named kwarg on every invocation to ensure forwards compatibility.

:param module: the NN being trained
:param train_step_idx: the current train step
:param kwargs: any additional kwargs
"""
raise NotImplementedError

def __call__(self, *args, **kwargs):
"""forwards to :func:``step``"""
return self.step(*args, **kwargs)


class DistributedContext:
"""
This class sets up some helper functions for torch distributed training
Expand All @@ -42,8 +99,13 @@ def __init__(self, options: Dict[str, Any]):
% (socket.gethostname(), os.getpid(), self._rank, self._size, self._local_rank, self._local_size)
)

self._custom_sync_class: Optional[Union[Callable, Type[ParamSynchronizer]]] = self._opts.get(
"synchronizer", None
)
self._custom_sync: Optional[Callable] = None
self._reduce_type = self._opts.get("reduce_type", "grad")
self._param_sync_step: Optional[int] = self._opts.get("param_sync_step", None)

if self._reduce_type == "param":
assert isinstance(self._param_sync_step, int) and self._param_sync_step > 0, (
f"reduce_type param: param_sync_step must be a positive int,"
@@ -52,6 +114,23 @@ def __init__(self, options: Dict[str, Any]):
_logger.info(f"reduce_type param: param_sync_step {self._param_sync_step}")
elif self._reduce_type == "grad":
_logger.info("reduce_type grad")
elif self._reduce_type == "custom":
if issubclass(self._custom_sync_class, ParamSynchronizer):
self._custom_sync = self._custom_sync_class(
rank=self._rank,
size=self._size,
local_rank=self._local_rank,
local_size=self._local_size,
**get_fwd_compat_kwargs(),
)
elif isinstance(self._custom_sync_class, Callable):
self._custom_sync = self._custom_sync_class
else:
raise ValueError(
f"synchronizer must either be a callable or a class inheriting from {ParamSynchronizer.__name__}"
)

_logger.info(f"reduce_type custom: {type(self._custom_sync)}")
else:
raise ValueError(f"invalid reduce_type {self._reduce_type!r}")

@@ -70,6 +149,8 @@ def _check_no_unknown_opts(self):
self._opts.get("options")
if self._reduce_type == "param":
self._opts.get("sync_on_cpu")
if self._reduce_type == "custom":
self._opts.get("synchronizer")

self._opts.assert_all_read()

@@ -102,7 +183,22 @@ def maybe_make_distributed_module(self, module: torch.nn.Module) -> Optional[DistributedDataParallel]:
"""
if self._reduce_type == "param":
return None
assert self._reduce_type == "grad"
assert self._reduce_type in ["custom", "grad"]

if self._reduce_type == "custom":
assert isinstance(self._custom_sync, (ParamSynchronizer, Callable))

if isinstance(self._custom_sync, ParamSynchronizer):
try:
return self._custom_sync.make_distributed_model(module=module, **get_fwd_compat_kwargs())
except OptionalNotImplementedError:
pass
else:
# callable short form does not have support for DistributedDataParallel
pass

return None

cls = self._opts.get("class", DistributedDataParallel)
if cls is not DistributedDataParallel:
_logger.warning(f"Using custom class {cls} instead of DistributedDataParallel, might be unsupported.")
@@ -115,7 +211,14 @@ def maybe_make_distributed_module(self, module: torch.nn.Module) -> Optional[DistributedDataParallel]:

def step_after_param_update(self, *, module: torch.nn.Module, epoch_step_idx: int):
"""one train step"""
if self._reduce_type == "param" and ((epoch_step_idx % self._param_sync_step) == (self._param_sync_step - 1)):
if self._reduce_type == "custom":
with torch.no_grad(): # TODO: do we want this for all syncers?
self._custom_sync(
module=module,
train_step_idx=epoch_step_idx,
**get_fwd_compat_kwargs(),
)
elif self._reduce_type == "param" and ((epoch_step_idx % self._param_sync_step) == (self._param_sync_step - 1)):
_sync_params_avg(module=module, sync_on_cpu=self._opts.get("sync_on_cpu", False))


@@ -127,7 +230,7 @@ def get_ctx(config=None) -> Optional[DistributedContext]:
"""
:param Config|None config:
:returns: the global context if Torch distributed is enabled, or None otherwise.
If we did not setup the context yet, it will automatically create it.
If we did not set up the context yet, it will automatically create it.
"""
global _is_set_up, _ctx
if _is_set_up:
@@ -155,7 +258,7 @@ def _sync_params_avg(*, module: torch.nn.Module, sync_on_cpu: bool = False):

if sync_on_cpu:
for param in module.parameters():
# Separately move each param to CPU (instead of the whole module), to safe CPU memory.
# Separately move each param to CPU (instead of the whole module), to save CPU memory.
param_cpu = param.to(torch.device("cpu"))
# On CPU, we are likely using Gloo, and Gloo does not support AVG
dist.all_reduce(param_cpu.data, op=dist.ReduceOp.SUM)
Expand All @@ -166,12 +269,11 @@ def _sync_params_avg(*, module: torch.nn.Module, sync_on_cpu: bool = False):
if dist.get_backend() == "gloo":
# Gloo does not support AVG
reduce_op = dist.ReduceOp.SUM
elif hasattr(dist.ReduceOp, "AVG"):
reduce_op = dist.ReduceOp.AVG
else:
if hasattr(dist.ReduceOp, "AVG"):
reduce_op = dist.ReduceOp.AVG
else:
# Older PyTorch versions do not have ReduceOp.AVG.
reduce_op = dist.ReduceOp.SUM
# Older PyTorch versions do not have ReduceOp.AVG.
reduce_op = dist.ReduceOp.SUM

for param in module.parameters():
dist.all_reduce(param.data, op=reduce_op)
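
For reference, a user-side synchronizer for the new reduce_type "custom" could look roughly like the sketch below (the class name, the sync_step parameter, and the config snippet are illustrative assumptions, not part of this PR):

import torch
import torch.distributed as dist
from returnn.torch.distributed import ParamSynchronizer


class AvgEveryNStepsSynchronizer(ParamSynchronizer):
    """Averages all parameters across workers every `sync_step` train steps."""

    def __init__(self, *, rank: int, size: int, local_rank: int, local_size: int, sync_step: int = 10, **_kwargs):
        super().__init__(rank=rank, size=size, local_rank=local_rank, local_size=local_size)
        self.sync_step = sync_step

    def step(self, *, module: torch.nn.Module, train_step_idx: int, **_kwargs):
        if (train_step_idx + 1) % self.sync_step != 0:
            return
        world_size = dist.get_world_size()
        for param in module.parameters():
            # SUM + divide instead of ReduceOp.AVG, since e.g. Gloo does not support AVG
            dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
            param.data /= world_size

In the config this would then be enabled with something along the lines of torch_distributed = {"reduce_type": "custom", "synchronizer": AvgEveryNStepsSynchronizer}. Since make_distributed_model is not overridden, no DistributedDataParallel wrapper is created and only the step callback runs (under torch.no_grad(), per the engine hook above).
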
7 changes: 3 additions & 4 deletions returnn/torch/engine.py
@@ -17,7 +17,6 @@
from torch.utils.data import DataLoader
from torch import autocast
from torch.cuda import amp
from random import random
import math

import returnn
@@ -680,7 +679,7 @@ def _run_step(
if self._use_autocast
else nullcontext()
), rf.set_default_device_ctx(self._device):
sentinel_kw = {"__fwd_compatible_random_arg_%i" % int(random() * 100): None}
sentinel_kw = util.get_fwd_compat_kwargs()
if train_func:
self._train_step_func(model=self._orig_model, extern_data=extern_data, **sentinel_kw)
else:
@@ -846,7 +845,7 @@ def _load_model(self):
if self._use_autocast
else nullcontext()
), rf.set_default_device_ctx(self._device):
sentinel_kw = {"__fwd_compatible_random_arg_%i" % int(random() * 100): None}
sentinel_kw = util.get_fwd_compat_kwargs()
for hook in load_model_post_hooks:
hook(model=self._orig_model, **sentinel_kw)

@@ -876,7 +875,7 @@ def _create_model(self, *, epoch: int, step: int):

get_model_func = self.config.typed_value("get_model")
assert get_model_func, "get_model not defined in config"
sentinel_kw = {"__fwd_compatible_random_arg_%i" % int(random() * 100): None}
sentinel_kw = util.get_fwd_compat_kwargs()
model = get_model_func(epoch=epoch, step=step, **sentinel_kw)
self._orig_model = model
if isinstance(model, rf.Module):
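
Because the engine forwards these helper-generated sentinel kwargs, the config entry points invoked here (get_model, train_step, and any load-model post hooks) should accept a trailing **kwargs. A minimal, hypothetical config sketch:

import torch
from returnn.tensor import TensorDict


def get_model(*, epoch: int, step: int, **_kwargs) -> torch.nn.Module:
    # **_kwargs absorbs fwd_compatible_random_kwarg_<n> and any future keyword arguments.
    return torch.nn.Linear(40, 10)


def train_step(*, model: torch.nn.Module, extern_data: TensorDict, **_kwargs):
    features = extern_data.data["data"].raw_tensor  # data key "data" is an assumption
    logits = model(features)
    # ... compute losses from logits and mark them as usual ...
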
12 changes: 12 additions & 0 deletions returnn/util/basic.py
@@ -4586,3 +4586,15 @@ def override_env_var(var_name: str, value: str):
os.environ[var_name] = cur_val
else:
os.environ.pop(var_name)


_fwd_compat_rng = np.random.default_rng()


def get_fwd_compat_kwargs() -> Dict[str, Any]:
"""
Returns a dictionary suitable for passing as kwargs for any RETURNN userland
function where forwards compatibility wrt. additional arguments must be
ensured.
"""
return {f"fwd_compatible_random_kwarg_{_fwd_compat_rng.integers(0, 100)}": None}
3 changes: 1 addition & 2 deletions tools/torch_export_to_onnx.py
@@ -35,7 +35,6 @@
from typing import Callable, Optional, Dict, List
import argparse
import os
from random import random

import _setup_returnn_env # noqa
from returnn.config import Config
@@ -204,7 +203,7 @@ def main():

get_model_func = config.typed_value("get_model")
assert get_model_func, "get_model() isn't specified in the config passed as a parameter."
sentinel_kw = {"__fwd_compatible_random_arg_%i" % int(random() * 100): None}
sentinel_kw = util.get_fwd_compat_kwargs()
model = get_model_func(epoch=epoch, step=step, **sentinel_kw)

is_rf_module = isinstance(model, rf.Module)