diff --git a/ignite/distributed/utils.py b/ignite/distributed/utils.py
index 000b5c9a685..6c1357203c6 100644
--- a/ignite/distributed/utils.py
+++ b/ignite/distributed/utils.py
@@ -1,9 +1,11 @@
+import itertools
 import socket
 from contextlib import contextmanager
 from functools import wraps
-from typing import Any, Callable, List, Mapping, Optional, Tuple, Union
+from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union
 
 import torch
+from torch import distributed as dist
 
 from ignite.distributed.comp_models import (
     _SerialModel,
@@ -43,6 +45,7 @@
     "one_rank_only",
     "new_group",
     "one_rank_first",
+    "all_gather_tensors_with_shapes",
 ]
 
 _model = _SerialModel()
@@ -350,6 +353,36 @@ def all_reduce(
     return _model.all_reduce(tensor, op, group=group)
 
 
+def all_gather_tensors_with_shapes(
+    tensor: torch.Tensor, shapes: Sequence[Sequence[int]], group: Optional[Union[Any, List[int]]] = None
+) -> List[torch.Tensor]:
+    """Gather tensors of different shapes but with the same number of dimensions from all processes."""
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
+    if isinstance(group, list) and all(isinstance(item, int) for item in group):
+        group = _model.new_group(group)
+
+    if isinstance(_model, _SerialModel) or group == dist.GroupMember.NON_GROUP_MEMBER:
+        return [tensor]
+
+    max_shape = torch.tensor(shapes).amax(dim=0)
+    padding_sizes = (max_shape - torch.tensor(tensor.shape)).tolist()
+    padded_tensor = torch.nn.functional.pad(
+        tensor, tuple(itertools.chain.from_iterable(map(lambda dim_size: (0, dim_size), reversed(padding_sizes))))
+    )
+    all_padded_tensors: torch.Tensor = _model.all_gather(padded_tensor, group=group)
+    return [
+        all_padded_tensors[
+            [
+                slice(rank * max_shape[0] if dim == 0 else 0, rank * max_shape[0] + dim_size if dim == 0 else dim_size)
+                for dim, dim_size in enumerate(shape)
+            ]
+        ]
+        for rank, shape in enumerate(shapes)
+    ]
+
+
 def all_gather(
     tensor: Union[torch.Tensor, float, str, Any], group: Optional[Union[Any, List[int]]] = None
 ) -> Union[torch.Tensor, float, List[float], List[str], List[Any]]:
diff --git a/tests/ignite/distributed/utils/__init__.py b/tests/ignite/distributed/utils/__init__.py
index 7845f0cd1ce..60c14dd9720 100644
--- a/tests/ignite/distributed/utils/__init__.py
+++ b/tests/ignite/distributed/utils/__init__.py
@@ -3,7 +3,7 @@
 import torch.distributed as dist
 
 import ignite.distributed as idist
-from ignite.distributed.utils import sync
+from ignite.distributed.utils import all_gather_tensors_with_shapes, sync
 from ignite.engine import Engine, Events
 
 
@@ -291,6 +291,56 @@ def _test_distrib_all_gather_group(device):
             res = idist.all_gather(t, group="abc")
 
 
+def _test_idist_all_gather_tensors_with_different_shapes(device):
+    rank = idist.get_rank()
+    ws = idist.get_world_size()
+    reference = torch.randn(ws * (ws + 1) // 2, ws * (ws + 3) // 2, ws * (ws + 5) // 2, device=device)
+    rank_tensor = reference[
+        rank * (rank + 1) // 2 : rank * (rank + 1) // 2 + rank + 1,
+        rank * (rank + 3) // 2 : rank * (rank + 3) // 2 + rank + 2,
+        rank * (rank + 5) // 2 : rank * (rank + 5) // 2 + rank + 3,
+    ]
+    tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in range(ws)])
+    for r in range(ws):
+        r_tensor = reference[
+            r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
+            r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
+            r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
+        ]
+        assert (r_tensor == tensors[r]).all()
+
+
+def _test_idist_all_gather_tensors_with_different_shapes_group(device):
+    rank = idist.get_rank()
+    ranks = list(range(idist.get_world_size() - 1, 0, -1))  # [0, 1, 2, 3] -> [3, 2, 1]
+    ws = idist.get_world_size()
+    bnd = idist.backend()
+    if rank in ranks:
+        reference = torch.randn(ws * (ws + 1) // 2, ws * (ws + 3) // 2, ws * (ws + 5) // 2, device=device)
+        rank_tensor = reference[
+            rank * (rank + 1) // 2 : rank * (rank + 1) // 2 + rank + 1,
+            rank * (rank + 3) // 2 : rank * (rank + 3) // 2 + rank + 2,
+            rank * (rank + 5) // 2 : rank * (rank + 5) // 2 + rank + 3,
+        ]
+    else:
+        rank_tensor = torch.tensor([rank], device=device)
+    if bnd in ("horovod",):
+        with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
+            tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in range(ws)], ranks)
+    else:
+        tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in range(ws)], ranks)
+        for r in range(ws):
+            if r in ranks:
+                r_tensor = reference[
+                    r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
+                    r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
+                    r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
+                ]
+                assert (r_tensor == tensors[r]).all()
+            else:
+                assert tensors == [rank_tensor]
+
+
 def _test_distrib_broadcast(device):
     rank = idist.get_rank()
     ws = idist.get_world_size()
diff --git a/tests/ignite/distributed/utils/test_horovod.py b/tests/ignite/distributed/utils/test_horovod.py
index ead6ed4c330..ecf0d2fc3c6 100644
--- a/tests/ignite/distributed/utils/test_horovod.py
+++ b/tests/ignite/distributed/utils/test_horovod.py
@@ -17,6 +17,8 @@
     _test_distrib_new_group,
     _test_distrib_one_rank_only,
     _test_distrib_one_rank_only_with_engine,
+    _test_idist_all_gather_tensors_with_different_shapes,
+    _test_idist_all_gather_tensors_with_different_shapes_group,
     _test_sync,
 )
 
@@ -163,6 +165,8 @@ def test_idist_all_gather_hvd(gloo_hvd_executor):
     np = 4 if not torch.cuda.is_available() else torch.cuda.device_count()
     gloo_hvd_executor(_test_distrib_all_gather, (device,), np=np, do_init=True)
     gloo_hvd_executor(_test_distrib_all_gather_group, (device,), np=np, do_init=True)
+    gloo_hvd_executor(_test_idist_all_gather_tensors_with_different_shapes, (device,), np=np, do_init=True)
+    gloo_hvd_executor(_test_idist_all_gather_tensors_with_different_shapes_group, (device,), np=np, do_init=True)
 
 
 @pytest.mark.distributed
diff --git a/tests/ignite/distributed/utils/test_native.py b/tests/ignite/distributed/utils/test_native.py
index fda3e1126cc..ac644f97752 100644
--- a/tests/ignite/distributed/utils/test_native.py
+++ b/tests/ignite/distributed/utils/test_native.py
@@ -19,6 +19,8 @@
     _test_distrib_new_group,
     _test_distrib_one_rank_only,
     _test_distrib_one_rank_only_with_engine,
+    _test_idist_all_gather_tensors_with_different_shapes,
+    _test_idist_all_gather_tensors_with_different_shapes_group,
     _test_sync,
 )
 
@@ -253,6 +255,23 @@ def test_idist_all_gather_gloo(distributed_context_single_node_gloo):
     _test_distrib_all_gather_group(device)
 
 
+@pytest.mark.distributed
+@pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
+def test_idist_all_gather_tensors_with_different_shapes_nccl(distributed_context_single_node_nccl):
+    device = idist.device()
+    _test_idist_all_gather_tensors_with_different_shapes(device)
+    _test_idist_all_gather_tensors_with_different_shapes_group(device)
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
+def test_idist_all_gather_tensors_with_different_shapes_gloo(distributed_context_single_node_gloo):
+    device = idist.device()
+    _test_idist_all_gather_tensors_with_different_shapes(device)
+    _test_idist_all_gather_tensors_with_different_shapes_group(device)
+
+
 @pytest.mark.distributed
 @pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
diff --git a/tests/ignite/distributed/utils/test_serial.py b/tests/ignite/distributed/utils/test_serial.py
index df2d6742b54..eb63cfbf3d6 100644
--- a/tests/ignite/distributed/utils/test_serial.py
+++ b/tests/ignite/distributed/utils/test_serial.py
@@ -10,6 +10,7 @@
     _test_distrib_barrier,
     _test_distrib_broadcast,
     _test_distrib_new_group,
+    _test_idist_all_gather_tensors_with_different_shapes,
     _test_sync,
 )
 
@@ -70,6 +71,7 @@ def test_idist__model_methods_no_dist():
 def test_idist_collective_ops_no_dist():
     _test_distrib_all_reduce("cpu")
     _test_distrib_all_gather("cpu")
+    _test_idist_all_gather_tensors_with_different_shapes("cpu")
     _test_distrib_barrier("cpu")
     _test_distrib_broadcast("cpu")
     _test_distrib_new_group("cpu")
@@ -77,6 +79,7 @@
     if torch.cuda.device_count() > 1:
         _test_distrib_all_reduce("cuda")
         _test_distrib_all_gather("cuda")
+        _test_idist_all_gather_tensors_with_different_shapes("cuda")
         _test_distrib_barrier("cuda")
         _test_distrib_broadcast("cuda")
         _test_distrib_new_group("cuda")
diff --git a/tests/ignite/distributed/utils/test_xla.py b/tests/ignite/distributed/utils/test_xla.py
index bb109eacdea..de838ce59c9 100644
--- a/tests/ignite/distributed/utils/test_xla.py
+++ b/tests/ignite/distributed/utils/test_xla.py
@@ -15,6 +15,8 @@
     _test_distrib_new_group,
     _test_distrib_one_rank_only,
     _test_distrib_one_rank_only_with_engine,
+    _test_idist_all_gather_tensors_with_different_shapes,
+    _test_idist_all_gather_tensors_with_different_shapes_group,
     _test_sync,
 )
 
@@ -151,6 +153,8 @@ def test_idist_all_gather_xla():
     device = idist.device()
     _test_distrib_all_gather(device)
     _test_distrib_all_gather_group(device)
+    _test_idist_all_gather_tensors_with_different_shapes(device)
+    _test_idist_all_gather_tensors_with_different_shapes_group(device)
 
 
 def _test_idist_all_gather_xla_in_child_proc(index):
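Usage note (illustrative, not part of the patch): all_gather_tensors_with_shapes pads each rank's tensor up to the element-wise maximum of the given shapes, all-gathers the padded tensors, and slices every rank's original region back out, so each process must pass the same shapes list describing all participants. When group is given as a list of ranks or a process group, only its members take part and a non-member simply gets back a one-element list holding its own tensor. A minimal sketch under the assumption that the script runs inside an already-initialized distributed configuration (for example torchrun with the gloo backend); the tensor contents below are made up for illustration:

    import torch

    import ignite.distributed as idist
    from ignite.distributed.utils import all_gather_tensors_with_shapes

    rank = idist.get_rank()
    ws = idist.get_world_size()

    # Rank r contributes (r + 1) rows of 8 columns, so shapes differ across ranks.
    local_result = torch.full((rank + 1, 8), float(rank))

    # Every process passes the full list of per-rank shapes, ordered by rank.
    shapes = [[r + 1, 8] for r in range(ws)]

    gathered = all_gather_tensors_with_shapes(local_result, shapes)
    assert len(gathered) == ws
    assert all(gathered[r].shape == (r + 1, 8) for r in range(ws))
    assert all((gathered[r] == r).all() for r in range(ws))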