torch distributed: add support for user-specified parameter synchronization
NeoLegends committed Sep 4, 2024
1 parent eb0f22e commit 5e029dd
Showing 1 changed file with 112 additions and 12 deletions.
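
As a rough usage sketch of the new option (names are illustrative, not part of this commit): the "custom" reduce type is selected through the same options dict that DistributedContext reads, presumably the usual torch_distributed config key. A sketch of a matching synchronizer class follows the diff below.

# Hypothetical RETURNN config fragment (assumes the usual torch_distributed config key).
# MyParamAvgSynchronizer is an illustrative ParamSynchronizer subclass, sketched after the diff.
torch_distributed = {
    "reduce_type": "custom",
    "synchronizer": MyParamAvgSynchronizer,  # ParamSynchronizer subclass; a plain callable also works
}
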
124 changes: 112 additions & 12 deletions returnn/torch/distributed.py
@@ -3,20 +3,72 @@
"""

from __future__ import annotations
from typing import Optional, Any, Dict
from abc import abstractmethod, ABC
import logging
import numpy
import os
import socket
import logging
from typing import Callable, Optional, Any, Dict, Type

import torch
from torch.nn.parallel import DistributedDataParallel

from returnn.config import Config
from returnn.util.basic import CollectionReadCheckCovered
from returnn.util.basic import CollectionReadCheckCovered, OptionalNotImplementedError

_logger = logging.getLogger("returnn.torch.distributed")


class ParamSynchronizer(ABC):
"""
Custom parameter synchronization primitive.
Contains a callback that is called after every train step to synchronize model parameters
across processes/nodes.
"""

@abstractmethod
def __init__(self, *, rank: int, size: int, local_rank: int, local_size: int, **kwargs):
"""
`__init__` called after the default global process group is created.
Can be used to initialize any additional custom process (sub)groups.
Note that this function is passed a randomly named kwarg on every invocation to ensure forwards compatibility.
:param rank: global rank of the current process across all nodes
:param size: global world size across all nodes
:param local_rank: local rank of the current process on the current node
:param local_size: local world size on the current node
"""
super().__init__()

def make_distributed_model(self, *, module: torch.nn.Module, **kwargs) -> DistributedDataParallel:
"""
Creates an associated `DistributedDataParallel` for the given module for gradient synchronization.
This function can be left unimplemented if no gradient synchronization is done.
Note that this function is passed a randomly named kwarg on every invocation to ensure forwards compatibility.
"""
raise OptionalNotImplementedError

@abstractmethod
def step(self, *, module: torch.nn.Module, train_step_idx: int, **kwargs):
"""
Parameter synchronization callback called after every train step.
Note that this function is passed a randomly named kwarg on every invocation to ensure forwards compatibility.
:param module: the NN being trained
:param train_step_idx: the current train step
:param kwargs: any additional kwargs.
"""
raise NotImplementedError

def __call__(self, *args, **kwargs):
"""forwards to :func:``step``"""
return self.step(*args, **kwargs)


class DistributedContext:
"""
This class sets up some helper functions for torch distributed training
@@ -26,6 +78,9 @@ def __init__(self, options: Dict[str, Any]):
import torch.distributed as dist

self._opts = CollectionReadCheckCovered(options)
# Only used to generate the random kwargs that ensure forwards compatibility,
# therefore the seed is not important
self._rng = numpy.random.default_rng()

# when no backend is specified, both gloo and nccl backends will be created
# the gloo backend will be used for collectives with CPU tensors and
@@ -42,8 +97,11 @@ def __init__(self, options: Dict[str, Any]):
% (socket.gethostname(), os.getpid(), self._rank, self._size, self._local_rank, self._local_size)
)

self._custom_sync_class: Optional[Type[ParamSynchronizer]] = self._opts.get("synchronizer", None)
self._custom_sync: Optional[Callable] = None
self._reduce_type = self._opts.get("reduce_type", "grad")
self._param_sync_step: Optional[int] = self._opts.get("param_sync_step", None)

if self._reduce_type == "param":
assert isinstance(self._param_sync_step, int) and self._param_sync_step > 0, (
f"reduce_type param: param_sync_step must be a positive int,"
@@ -52,6 +110,23 @@
_logger.info(f"reduce_type param: param_sync_step {self._param_sync_step}")
elif self._reduce_type == "grad":
_logger.info("reduce_type grad")
elif self._reduce_type == "custom":
if isinstance(self._custom_sync_class, type) and issubclass(self._custom_sync_class, ParamSynchronizer):
self._custom_sync = self._custom_sync_class(
rank=self._rank,
size=self._size,
local_rank=self._local_rank,
local_size=self._local_size,
**{f"fwd_compatible_random_kwarg_{self._rng.integers(0, 100)}": None},
)
elif isinstance(self._custom_sync_class, Callable):
self._custom_sync = self._custom_sync_class
else:
raise ValueError(
f"synchronizer must either be a callable or a class inheriting from {ParamSynchronizer.__name__}"
)

_logger.info(f"reduce_type custom: {type(self._custom_sync)}")
else:
raise ValueError(f"invalid reduce_type {self._reduce_type!r}")

@@ -70,6 +145,8 @@ def _check_no_unknown_opts(self):
self._opts.get("options")
if self._reduce_type == "param":
self._opts.get("sync_on_cpu")
if self._reduce_type == "custom":
self._opts.get("synchronizer")

self._opts.assert_all_read()

@@ -102,7 +179,24 @@ def maybe_make_distributed_module(self, module: torch.nn.Module) -> Optional[Dis
"""
if self._reduce_type == "param":
return None
assert self._reduce_type == "grad"
assert self._reduce_type in ["custom", "grad"]

if self._reduce_type == "custom":
assert isinstance(self._custom_sync, (ParamSynchronizer, Callable))

if isinstance(self._custom_sync, ParamSynchronizer):
try:
return self._custom_sync.make_distributed_model(
module=module, **{f"fwd_compatible_random_kwarg_{self._rng.integers(0, 100)}": None}
)
except OptionalNotImplementedError:
pass
else:
# callable short form does not have support for DistributedDataParallel
pass

return None

cls = self._opts.get("class", DistributedDataParallel)
if cls is not DistributedDataParallel:
_logger.warning(f"Using custom class {cls} instead of DistributedDataParallel, might be unsupported.")
@@ -115,7 +209,14 @@ def maybe_make_distributed_module(self, module: torch.nn.Module) -> Optional[Dis

def step_after_param_update(self, *, module: torch.nn.Module, epoch_step_idx: int):
"""one train step"""
if self._reduce_type == "param" and ((epoch_step_idx % self._param_sync_step) == (self._param_sync_step - 1)):
if self._reduce_type == "custom":
with torch.no_grad(): # TODO: do we want this for all syncers?
self._custom_sync(
module=module,
train_step_idx=epoch_step_idx,
**{f"fwd_compatible_random_kwarg_{self._rng.integers(0, 100)}": None},
)
elif self._reduce_type == "param" and ((epoch_step_idx % self._param_sync_step) == (self._param_sync_step - 1)):
_sync_params_avg(module=module, sync_on_cpu=self._opts.get("sync_on_cpu", False))


@@ -155,7 +256,7 @@ def _sync_params_avg(*, module: torch.nn.Module, sync_on_cpu: bool = False):

if sync_on_cpu:
for param in module.parameters():
# Separately move each param to CPU (instead of the whole module), to safe CPU memory.
# Separately move each param to CPU (instead of the whole module), to save CPU memory.
param_cpu = param.to(torch.device("cpu"))
# On CPU, we are likely using Gloo, and Gloo does not support AVG
dist.all_reduce(param_cpu.data, op=dist.ReduceOp.SUM)
@@ -166,12 +267,11 @@ def _sync_params_avg(*, module: torch.nn.Module, sync_on_cpu: bool = False):
if dist.get_backend() == "gloo":
# Gloo does not support AVG
reduce_op = dist.ReduceOp.SUM
elif hasattr(dist.ReduceOp, "AVG"):
reduce_op = dist.ReduceOp.AVG
else:
if hasattr(dist.ReduceOp, "AVG"):
reduce_op = dist.ReduceOp.AVG
else:
# Older PyTorch versions do not have ReduceOp.AVG.
reduce_op = dist.ReduceOp.SUM
# Older PyTorch versions do not have ReduceOp.AVG.
reduce_op = dist.ReduceOp.SUM

for param in module.parameters():
dist.all_reduce(param.data, op=reduce_op)
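For reference, a minimal sketch of a user-defined synchronizer compatible with the ParamSynchronizer interface added above: it averages all model parameters across ranks after every train step. The class name and the SUM-then-divide averaging are illustrative choices, not part of the commit.

import torch
import torch.distributed as dist

from returnn.torch.distributed import ParamSynchronizer


class MyParamAvgSynchronizer(ParamSynchronizer):
    """Averages all model parameters across all ranks after each train step."""

    def __init__(self, *, rank: int, size: int, local_rank: int, local_size: int, **kwargs):
        super().__init__(rank=rank, size=size, local_rank=local_rank, local_size=local_size, **kwargs)
        self._size = size  # global world size, used as the divisor for averaging

    def step(self, *, module: torch.nn.Module, train_step_idx: int, **kwargs):
        # SUM then divide also works on backends without ReduceOp.AVG (e.g. Gloo).
        for param in module.parameters():
            dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
            param.data /= self._size

Since make_distributed_model is not overridden here, maybe_make_distributed_module catches OptionalNotImplementedError and no DistributedDataParallel wrapper is created, which matches a plain parameter-averaging setup.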
