diff --git a/returnn/torch/distributed.py b/returnn/torch/distributed.py
index 3bb2cf825..6c4b06b7c 100644
--- a/returnn/torch/distributed.py
+++ b/returnn/torch/distributed.py
@@ -3,20 +3,72 @@
 """
 
 from __future__ import annotations
-from typing import Optional, Any, Dict
+from abc import abstractmethod, ABC
+import logging
+import numpy
 import os
 import socket
-import logging
+from typing import Callable, Optional, Any, Dict, Type
 
 import torch
 from torch.nn.parallel import DistributedDataParallel
 
-from returnn.config import Config
-from returnn.util.basic import CollectionReadCheckCovered
+from returnn.util.basic import CollectionReadCheckCovered, OptionalNotImplementedError
 
 _logger = logging.getLogger("returnn.torch.distributed")
 
 
+class ParamSynchronizer(ABC):
+    """
+    Custom parameter synchronization primitive.
+
+    Contains a callback that is called after every train step to synchronize model parameters
+    across processes/nodes.
+    """
+
+    @abstractmethod
+    def __init__(self, *, rank: int, size: int, local_rank: int, local_size: int, **kwargs):
+        """
+        `__init__` is called after the default global process group is created.
+        It can be used to initialize any additional custom process (sub)groups.
+
+        Note this function is passed a randomly named kwarg on every invocation to ensure forwards compatibility.
+
+        :param rank: global rank of the current process across all nodes
+        :param size: global world size across all nodes
+        :param local_rank: local rank of the current process on the current node
+        :param local_size: local world size on the current node
+        """
+        super().__init__()
+
+    def make_distributed_model(self, *, module: torch.nn.Module, **kwargs) -> DistributedDataParallel:
+        """
+        Creates an associated `DistributedDataParallel` for the given module to handle gradient synchronization.
+
+        This function can be left unimplemented if no gradient synchronization is done.
+
+        Note this function is passed a randomly named kwarg on every invocation to ensure forwards compatibility.
+        """
+        raise OptionalNotImplementedError
+
+    @abstractmethod
+    def step(self, *, module: torch.nn.Module, train_step_idx: int, **kwargs):
+        """
+        Parameter synchronization callback, called after every train step.
+
+        Note this function is passed a randomly named kwarg on every invocation to ensure forwards compatibility.
+
+        :param module: the NN being trained
+        :param train_step_idx: the current train step
+        :param kwargs: any additional kwargs
+        """
+        raise NotImplementedError
+
+    def __call__(self, *args, **kwargs):
+        """Forwards to :func:`step`."""
+        return self.step(*args, **kwargs)
+
+
 class DistributedContext:
     """
     This class setups some helper functions for torch distributed training
@@ -26,6 +78,9 @@ def __init__(self, options: Dict[str, Any]):
         import torch.distributed as dist
 
         self._opts = CollectionReadCheckCovered(options)
+        # Only used to generate the random kwargs that ensure forwards compatibility,
+        # therefore the seed is not important.
+        self._rng = numpy.random.default_rng()
 
         # when no backend is specified, both gloo and nccl backends will be created
         # the gloo backend will be used for collectives with CPU tensors and
@@ -42,8 +97,11 @@ def __init__(self, options: Dict[str, Any]):
             % (socket.gethostname(), os.getpid(), self._rank, self._size, self._local_rank, self._local_size)
         )
 
+        self._custom_sync_class: Optional[Type[ParamSynchronizer]] = self._opts.get("synchronizer", None)
+        self._custom_sync: Optional[Callable] = None
         self._reduce_type = self._opts.get("reduce_type", "grad")
         self._param_sync_step: Optional[int] = self._opts.get("param_sync_step", None)
+
         if self._reduce_type == "param":
             assert isinstance(self._param_sync_step, int) and self._param_sync_step > 0, (
                 f"reduce_type param: param_sync_step must be a positive int,"
@@ -52,6 +110,23 @@ def __init__(self, options: Dict[str, Any]):
             _logger.info(f"reduce_type param: param_sync_step {self._param_sync_step}")
         elif self._reduce_type == "grad":
             _logger.info("reduce_type grad")
+        elif self._reduce_type == "custom":
+            if isinstance(self._custom_sync_class, type) and issubclass(self._custom_sync_class, ParamSynchronizer):
+                self._custom_sync = self._custom_sync_class(
+                    rank=self._rank,
+                    size=self._size,
+                    local_rank=self._local_rank,
+                    local_size=self._local_size,
+                    **{f"fwd_compatible_random_kwarg_{self._rng.integers(0, 100)}": None},
+                )
+            elif isinstance(self._custom_sync_class, Callable):
+                self._custom_sync = self._custom_sync_class
+            else:
+                raise ValueError(
+                    f"synchronizer must either be a callable or a class inheriting from {ParamSynchronizer.__name__}"
+                )
+
+            _logger.info(f"reduce_type custom: {type(self._custom_sync)}")
         else:
             raise ValueError(f"invalid reduce_type {self._reduce_type!r}")
 
@@ -70,6 +145,8 @@ def _check_no_unknown_opts(self):
             self._opts.get("options")
         if self._reduce_type == "param":
             self._opts.get("sync_on_cpu")
+        if self._reduce_type == "custom":
+            self._opts.get("synchronizer")
 
         self._opts.assert_all_read()
 
@@ -102,7 +179,24 @@ def maybe_make_distributed_module(self, module: torch.nn.Module) -> Optional[Dis
         """
         if self._reduce_type == "param":
             return None
-        assert self._reduce_type == "grad"
+        assert self._reduce_type in ["custom", "grad"]
+
+        if self._reduce_type == "custom":
+            assert isinstance(self._custom_sync, (ParamSynchronizer, Callable))
+
+            if isinstance(self._custom_sync, ParamSynchronizer):
+                try:
+                    return self._custom_sync.make_distributed_model(
+                        module=module, **{f"fwd_compatible_random_kwarg_{self._rng.integers(0, 100)}": None}
+                    )
+                except OptionalNotImplementedError:
+                    pass
+            else:
+                # the callable short form does not support DistributedDataParallel
+                pass
+
+            return None
+
         cls = self._opts.get("class", DistributedDataParallel)
         if cls is not DistributedDataParallel:
             _logger.warning(f"Using custom class {cls} instead of DistributedDataParallel, might be unsupported.")
@@ -115,7 +209,14 @@ def maybe_make_distributed_module(self, module: torch.nn.Module) -> Optional[Dis
 
     def step_after_param_update(self, *, module: torch.nn.Module, epoch_step_idx: int):
         """one train step"""
-        if self._reduce_type == "param" and ((epoch_step_idx % self._param_sync_step) == (self._param_sync_step - 1)):
+        if self._reduce_type == "custom":
+            with torch.no_grad():  # TODO: do we want this for all syncers?
+                self._custom_sync(
+                    module=module,
+                    train_step_idx=epoch_step_idx,
+                    **{f"fwd_compatible_random_kwarg_{self._rng.integers(0, 100)}": None},
+                )
+        elif self._reduce_type == "param" and ((epoch_step_idx % self._param_sync_step) == (self._param_sync_step - 1)):
             _sync_params_avg(module=module, sync_on_cpu=self._opts.get("sync_on_cpu", False))
 
 
@@ -155,7 +256,7 @@ def _sync_params_avg(*, module: torch.nn.Module, sync_on_cpu: bool = False):
 
     if sync_on_cpu:
         for param in module.parameters():
-            # Separately move each param to CPU (instead of the whole module), to safe CPU memory.
+            # Separately move each param to CPU (instead of the whole module), to save CPU memory.
             param_cpu = param.to(torch.device("cpu"))
             # On CPU, we are likely using Gloo, and Gloo does not support AVG
            dist.all_reduce(param_cpu.data, op=dist.ReduceOp.SUM)
@@ -166,12 +267,11 @@ def _sync_params_avg(*, module: torch.nn.Module, sync_on_cpu: bool = False):
     if dist.get_backend() == "gloo":
         # Gloo does not support AVG
         reduce_op = dist.ReduceOp.SUM
+    elif hasattr(dist.ReduceOp, "AVG"):
+        reduce_op = dist.ReduceOp.AVG
     else:
-        if hasattr(dist.ReduceOp, "AVG"):
-            reduce_op = dist.ReduceOp.AVG
-        else:
-            # Older PyTorch versions do not have ReduceOp.AVG.
-            reduce_op = dist.ReduceOp.SUM
+        # Older PyTorch versions do not have ReduceOp.AVG.
+        reduce_op = dist.ReduceOp.SUM
 
     for param in module.parameters():
         dist.all_reduce(param.data, op=reduce_op)