Skip to content

Commit

Permalink
Helper function all_gather_tensors_with_shapes() (#3281)
Browse files Browse the repository at this point in the history
* Implement the feature with tests

* Remove comment

* Improve docstring and fig a bug in tests

* Fix tests

* Update ignite/distributed/utils.py

Co-authored-by: vfdev <[email protected]>

* Improve docstring

* Fix test and docstring

---------

Co-authored-by: vfdev <[email protected]>
  • Loading branch information
sadra-barikbin and vfdev-5 authored Sep 5, 2024
1 parent e2f9ac0 commit 680ac7f
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 2 deletions.
59 changes: 58 additions & 1 deletion ignite/distributed/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import itertools
import socket
from contextlib import contextmanager
from functools import wraps
from typing import Any, Callable, List, Mapping, Optional, Tuple, Union
from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union

import torch
from torch import distributed as dist

from ignite.distributed.comp_models import (
_SerialModel,
Expand Down Expand Up @@ -43,6 +45,7 @@
"one_rank_only",
"new_group",
"one_rank_first",
"all_gather_tensors_with_shapes",
]

_model = _SerialModel()
Expand Down Expand Up @@ -350,6 +353,60 @@ def all_reduce(
return _model.all_reduce(tensor, op, group=group)


def all_gather_tensors_with_shapes(
tensor: torch.Tensor, shapes: Sequence[Sequence[int]], group: Optional[Union[Any, List[int]]] = None
) -> List[torch.Tensor]:
"""Helper method to gather tensors of possibly different shapes but with the same number of dimensions
across processes.
This function gets the shapes of participating tensors as input so you should know them beforehand. If your
tensors are of different number of dimensions or you don't know their shapes beforehand, you can use
``torch.distributed.all_gather_object``, otherwise this method is quite faster.
Examples:
.. code-block:: python
import ignite.distributed as idist
rank = idist.get_rank()
ws = idist.get_world_size()
tensor = torch.randn(rank+1, rank+2)
tensors = idist.all_gather_tensors_with_shapes(tensor, [[r+1, r+2] for r in range(ws)])
Args:
tensor: tensor to collect across participating processes.
shapes: A sequence containing the shape of participating processes' ``tensor`` s.
group: list of integer or the process group for each backend. If None, the default process group will be used.
Returns:
List[torch.Tensor]
"""
if _need_to_sync and isinstance(_model, _SerialModel):
sync(temporary=True)

if isinstance(group, list) and all(isinstance(item, int) for item in group):
group = _model.new_group(group)

if isinstance(_model, _SerialModel) or group == dist.GroupMember.NON_GROUP_MEMBER:
return [tensor]

max_shape = torch.tensor(shapes).amax(dim=0)
padding_sizes = (max_shape - torch.tensor(tensor.shape)).tolist()
padded_tensor = torch.nn.functional.pad(
tensor, tuple(itertools.chain.from_iterable(map(lambda dim_size: (0, dim_size), reversed(padding_sizes))))
)
all_padded_tensors: torch.Tensor = _model.all_gather(padded_tensor, group=group)
return [
all_padded_tensors[
[
slice(rank * max_shape[0] if dim == 0 else 0, rank * max_shape[0] + dim_size if dim == 0 else dim_size)
for dim, dim_size in enumerate(shape)
]
]
for rank, shape in enumerate(shapes)
]


def all_gather(
tensor: Union[torch.Tensor, float, str, Any], group: Optional[Union[Any, List[int]]] = None
) -> Union[torch.Tensor, float, List[float], List[str], List[Any]]:
Expand Down
56 changes: 55 additions & 1 deletion tests/ignite/distributed/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch.distributed as dist

import ignite.distributed as idist
from ignite.distributed.utils import sync
from ignite.distributed.utils import all_gather_tensors_with_shapes, sync
from ignite.engine import Engine, Events


Expand Down Expand Up @@ -291,6 +291,60 @@ def _test_distrib_all_gather_group(device):
res = idist.all_gather(t, group="abc")


def _test_idist_all_gather_tensors_with_shapes(device):
torch.manual_seed(41)
rank = idist.get_rank()
ws = idist.get_world_size()
reference = torch.randn(ws * (ws + 1) // 2, ws * (ws + 3) // 2, ws * (ws + 5) // 2, device=device)
rank_tensor = reference[
rank * (rank + 1) // 2 : rank * (rank + 1) // 2 + rank + 1,
rank * (rank + 3) // 2 : rank * (rank + 3) // 2 + rank + 2,
rank * (rank + 5) // 2 : rank * (rank + 5) // 2 + rank + 3,
]
tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in range(ws)])
for r in range(ws):
r_tensor = reference[
r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
]
assert (r_tensor == tensors[r]).all()


def _test_idist_all_gather_tensors_with_shapes_group(device):
if idist.get_world_size() > 1:
torch.manual_seed(41)

rank = idist.get_rank()
ranks = list(range(1, idist.get_world_size()))
ws = idist.get_world_size()
bnd = idist.backend()
if rank in ranks:
reference = torch.randn(ws * (ws + 1) // 2, ws * (ws + 3) // 2, ws * (ws + 5) // 2, device=device)
rank_tensor = reference[
rank * (rank + 1) // 2 : rank * (rank + 1) // 2 + rank + 1,
rank * (rank + 3) // 2 : rank * (rank + 3) // 2 + rank + 2,
rank * (rank + 5) // 2 : rank * (rank + 5) // 2 + rank + 3,
]
else:
rank_tensor = torch.tensor([rank], device=device)
if bnd in ("horovod"):
with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
else:
tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in ranks], ranks)
if rank in ranks:
for r in ranks:
r_tensor = reference[
r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
]
assert (r_tensor == tensors[r - 1]).all()
else:
assert [rank_tensor] == tensors


def _test_distrib_broadcast(device):
rank = idist.get_rank()
ws = idist.get_world_size()
Expand Down
4 changes: 4 additions & 0 deletions tests/ignite/distributed/utils/test_horovod.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
_test_distrib_new_group,
_test_distrib_one_rank_only,
_test_distrib_one_rank_only_with_engine,
_test_idist_all_gather_tensors_with_shapes,
_test_idist_all_gather_tensors_with_shapes_group,
_test_sync,
)

Expand Down Expand Up @@ -163,6 +165,8 @@ def test_idist_all_gather_hvd(gloo_hvd_executor):
np = 4 if not torch.cuda.is_available() else torch.cuda.device_count()
gloo_hvd_executor(_test_distrib_all_gather, (device,), np=np, do_init=True)
gloo_hvd_executor(_test_distrib_all_gather_group, (device,), np=np, do_init=True)
gloo_hvd_executor(_test_idist_all_gather_tensors_with_shapes, (device,), np=np, do_init=True)
gloo_hvd_executor(_test_idist_all_gather_tensors_with_shapes_group, (device,), np=np, do_init=True)


@pytest.mark.distributed
Expand Down
19 changes: 19 additions & 0 deletions tests/ignite/distributed/utils/test_native.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
_test_distrib_new_group,
_test_distrib_one_rank_only,
_test_distrib_one_rank_only_with_engine,
_test_idist_all_gather_tensors_with_shapes,
_test_idist_all_gather_tensors_with_shapes_group,
_test_sync,
)

Expand Down Expand Up @@ -253,6 +255,23 @@ def test_idist_all_gather_gloo(distributed_context_single_node_gloo):
_test_distrib_all_gather_group(device)


@pytest.mark.distributed
@pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
def test_idist_all_gather_tensors_with_shapes_nccl(distributed_context_single_node_nccl):
device = idist.device()
_test_idist_all_gather_tensors_with_shapes(device)
_test_idist_all_gather_tensors_with_shapes_group(device)


@pytest.mark.distributed
@pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
def test_idist_all_gather_tensors_with_shapes_gloo(distributed_context_single_node_gloo):
device = idist.device()
_test_idist_all_gather_tensors_with_shapes(device)
_test_idist_all_gather_tensors_with_shapes_group(device)


@pytest.mark.distributed
@pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
Expand Down
3 changes: 3 additions & 0 deletions tests/ignite/distributed/utils/test_serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
_test_distrib_barrier,
_test_distrib_broadcast,
_test_distrib_new_group,
_test_idist_all_gather_tensors_with_shapes,
_test_sync,
)

Expand Down Expand Up @@ -70,13 +71,15 @@ def test_idist__model_methods_no_dist():
def test_idist_collective_ops_no_dist():
_test_distrib_all_reduce("cpu")
_test_distrib_all_gather("cpu")
_test_idist_all_gather_tensors_with_shapes("cpu")
_test_distrib_barrier("cpu")
_test_distrib_broadcast("cpu")
_test_distrib_new_group("cpu")

if torch.cuda.device_count() > 1:
_test_distrib_all_reduce("cuda")
_test_distrib_all_gather("cuda")
_test_idist_all_gather_tensors_with_shapes("cuda")
_test_distrib_barrier("cuda")
_test_distrib_broadcast("cuda")
_test_distrib_new_group("cuda")
4 changes: 4 additions & 0 deletions tests/ignite/distributed/utils/test_xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
_test_distrib_new_group,
_test_distrib_one_rank_only,
_test_distrib_one_rank_only_with_engine,
_test_idist_all_gather_tensors_with_shapes,
_test_idist_all_gather_tensors_with_shapes_group,
_test_sync,
)

Expand Down Expand Up @@ -151,6 +153,8 @@ def test_idist_all_gather_xla():
device = idist.device()
_test_distrib_all_gather(device)
_test_distrib_all_gather_group(device)
_test_idist_all_gather_tensors_with_shapes(device)
_test_idist_all_gather_tensors_with_shapes_group(device)


def _test_idist_all_gather_xla_in_child_proc(index):
Expand Down

0 comments on commit 680ac7f

Please sign in to comment.