Implement the feature with tests
sadra-barikbin committed Sep 4, 2024
1 parent e2f9ac0 commit 7c66841
Showing 6 changed files with 115 additions and 2 deletions.
35 changes: 34 additions & 1 deletion ignite/distributed/utils.py
@@ -1,9 +1,11 @@
import itertools
import socket
from contextlib import contextmanager
from functools import wraps
from typing import Any, Callable, List, Mapping, Optional, Tuple, Union
from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union

import torch
from torch import distributed as dist

from ignite.distributed.comp_models import (
    _SerialModel,
@@ -43,6 +45,7 @@
"one_rank_only",
"new_group",
"one_rank_first",
"all_gather_tensors_with_shapes",
]

_model = _SerialModel()
@@ -350,6 +353,36 @@ def all_reduce(
    return _model.all_reduce(tensor, op, group=group)


def all_gather_tensors_with_shapes(
    tensor: torch.Tensor, shapes: Sequence[Sequence[int]], group: Optional[Union[Any, List[int]]] = None
) -> List[torch.Tensor]:
    """Gather tensors of possibly different shapes, but with the same number of dimensions, from across processes.

    Args:
        tensor: the tensor contributed by the current process.
        shapes: the shape of each participating process' tensor, indexed by rank.
        group: list of integer ranks or a backend-specific process group. If ``None``, the default process group is used.

    Returns:
        List of gathered tensors.
    """
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)

    if isinstance(group, list) and all(isinstance(item, int) for item in group):
        group = _model.new_group(group)

    if isinstance(_model, _SerialModel) or group == dist.GroupMember.NON_GROUP_MEMBER:
        return [tensor]

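    # Pad every local tensor up to the elementwise maximum of the given shapes, so that the backend
    # all_gather below operates on tensors of identical shape on every process.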
    max_shape = torch.tensor(shapes).amax(dim=0)
    padding_sizes = (max_shape - torch.tensor(tensor.shape)).tolist()
    padded_tensor = torch.nn.functional.pad(
        tensor, tuple(itertools.chain.from_iterable(map(lambda dim_size: (0, dim_size), reversed(padding_sizes))))
    )
    all_padded_tensors: torch.Tensor = _model.all_gather(padded_tensor, group=group)
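    # The gathered result is concatenated along dim 0 in chunks of max_shape[0] rows per rank;
    # slice each rank's chunk back down to its original (un-padded) shape.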
    return [
        all_padded_tensors[
            [
                slice(rank * max_shape[0] if dim == 0 else 0, rank * max_shape[0] + dim_size if dim == 0 else dim_size)
                for dim, dim_size in enumerate(shape)
            ]
        ]
        for rank, shape in enumerate(shapes)
    ]

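For context, a minimal usage sketch of the new helper, assuming a 2-process gloo run; the shapes, backend choice, and the `main` helper below are illustrative and not part of this commit:

import torch

import ignite.distributed as idist
from ignite.distributed.utils import all_gather_tensors_with_shapes


def main(local_rank: int) -> None:
    rank = idist.get_rank()
    # Rank 0 holds a (1, 2) tensor, rank 1 a (2, 3) tensor: same number of dimensions, different shapes.
    shapes = [[1, 2], [2, 3]]
    local = torch.full(tuple(shapes[rank]), float(rank))
    gathered = all_gather_tensors_with_shapes(local, shapes)
    # Every process gets back one tensor per rank, each restored to its original shape and values.
    assert [t.shape for t in gathered] == [torch.Size(s) for s in shapes]
    assert all((t == r).all() for r, t in enumerate(gathered))


if __name__ == "__main__":
    with idist.Parallel(backend="gloo", nproc_per_node=2) as parallel:
        parallel.run(main)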

def all_gather(
tensor: Union[torch.Tensor, float, str, Any], group: Optional[Union[Any, List[int]]] = None
) -> Union[torch.Tensor, float, List[float], List[str], List[Any]]:
52 changes: 51 additions & 1 deletion tests/ignite/distributed/utils/__init__.py
@@ -3,7 +3,7 @@
import torch.distributed as dist

import ignite.distributed as idist
from ignite.distributed.utils import sync
from ignite.distributed.utils import all_gather_tensors_with_shapes, sync
from ignite.engine import Engine, Events


@@ -291,6 +291,56 @@ def _test_distrib_all_gather_group(device):
res = idist.all_gather(t, group="abc")


def _test_idist_all_gather_tensors_with_different_shapes(device):
    torch.manual_seed(41)  # identical seed on every process so that all processes build the same reference tensor
    rank = idist.get_rank()
    ws = idist.get_world_size()
    reference = torch.randn(ws * (ws + 1) // 2, ws * (ws + 3) // 2, ws * (ws + 5) // 2, device=device)
    rank_tensor = reference[
        rank * (rank + 1) // 2 : rank * (rank + 1) // 2 + rank + 1,
        rank * (rank + 3) // 2 : rank * (rank + 3) // 2 + rank + 2,
        rank * (rank + 5) // 2 : rank * (rank + 5) // 2 + rank + 3,
    ]
    tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in range(ws)])
    for r in range(ws):
        r_tensor = reference[
            r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
            r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
            r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
        ]
        assert (r_tensor == tensors[r]).all()


def _test_idist_all_gather_tensors_with_different_shapes_group(device):
    torch.manual_seed(41)  # identical seed on every process so that all group members build the same reference tensor
    rank = idist.get_rank()
    ranks = list(range(idist.get_world_size() - 1, 0, -1))  # [0, 1, 2, 3] -> [3, 2, 1]
    ws = idist.get_world_size()
    bnd = idist.backend()
    if rank in ranks:
        reference = torch.randn(ws * (ws + 1) // 2, ws * (ws + 3) // 2, ws * (ws + 5) // 2, device=device)
        rank_tensor = reference[
            rank * (rank + 1) // 2 : rank * (rank + 1) // 2 + rank + 1,
            rank * (rank + 3) // 2 : rank * (rank + 3) // 2 + rank + 2,
            rank * (rank + 5) // 2 : rank * (rank + 5) // 2 + rank + 3,
        ]
    else:
        rank_tensor = torch.tensor([rank], device=device)
    if bnd in ("horovod",):
        with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
            tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in range(ws)], ranks)
    else:
        tensors = all_gather_tensors_with_shapes(rank_tensor, [[r + 1, r + 2, r + 3] for r in range(ws)], ranks)
        for r in range(ws):
            if r in ranks:
                r_tensor = reference[
                    r * (r + 1) // 2 : r * (r + 1) // 2 + r + 1,
                    r * (r + 3) // 2 : r * (r + 3) // 2 + r + 2,
                    r * (r + 5) // 2 : r * (r + 5) // 2 + r + 3,
                ]
                assert (r_tensor == tensors[r]).all()
            else:
                # The rank excluded from the group only gets its own tensor back, as a one-element list.
                assert tensors == [rank_tensor]


def _test_distrib_broadcast(device):
    rank = idist.get_rank()
    ws = idist.get_world_size()
4 changes: 4 additions & 0 deletions tests/ignite/distributed/utils/test_horovod.py
@@ -17,6 +17,8 @@
    _test_distrib_new_group,
    _test_distrib_one_rank_only,
    _test_distrib_one_rank_only_with_engine,
    _test_idist_all_gather_tensors_with_different_shapes,
    _test_idist_all_gather_tensors_with_different_shapes_group,
    _test_sync,
)

@@ -163,6 +165,8 @@ def test_idist_all_gather_hvd(gloo_hvd_executor):
    np = 4 if not torch.cuda.is_available() else torch.cuda.device_count()
    gloo_hvd_executor(_test_distrib_all_gather, (device,), np=np, do_init=True)
    gloo_hvd_executor(_test_distrib_all_gather_group, (device,), np=np, do_init=True)
    gloo_hvd_executor(_test_idist_all_gather_tensors_with_different_shapes, (device,), np=np, do_init=True)
    gloo_hvd_executor(_test_idist_all_gather_tensors_with_different_shapes_group, (device,), np=np, do_init=True)


@pytest.mark.distributed
19 changes: 19 additions & 0 deletions tests/ignite/distributed/utils/test_native.py
@@ -19,6 +19,8 @@
    _test_distrib_new_group,
    _test_distrib_one_rank_only,
    _test_distrib_one_rank_only_with_engine,
    _test_idist_all_gather_tensors_with_different_shapes,
    _test_idist_all_gather_tensors_with_different_shapes_group,
    _test_sync,
)

@@ -253,6 +255,23 @@ def test_idist_all_gather_gloo(distributed_context_single_node_gloo):
    _test_distrib_all_gather_group(device)


@pytest.mark.distributed
@pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
def test_idist_all_gather_tensors_with_different_shapes_nccl(distributed_context_single_node_nccl):
    device = idist.device()
    _test_idist_all_gather_tensors_with_different_shapes(device)
    _test_idist_all_gather_tensors_with_different_shapes_group(device)


@pytest.mark.distributed
@pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
def test_idist_all_gather_tensors_with_different_shapes_gloo(distributed_context_single_node_gloo):
    device = idist.device()
    _test_idist_all_gather_tensors_with_different_shapes(device)
    _test_idist_all_gather_tensors_with_different_shapes_group(device)

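As a rough illustration (the exact invocation depends on the repository's pytest configuration and markers, so treat this as an assumption rather than a documented command), the new gloo-backed test above can be selected directly:

python -m pytest -m distributed -k tensors_with_different_shapes_gloo tests/ignite/distributed/utils/test_native.py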

@pytest.mark.distributed
@pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
3 changes: 3 additions & 0 deletions tests/ignite/distributed/utils/test_serial.py
@@ -10,6 +10,7 @@
    _test_distrib_barrier,
    _test_distrib_broadcast,
    _test_distrib_new_group,
    _test_idist_all_gather_tensors_with_different_shapes,
    _test_sync,
)

@@ -70,13 +71,15 @@ def test_idist__model_methods_no_dist():
def test_idist_collective_ops_no_dist():
    _test_distrib_all_reduce("cpu")
    _test_distrib_all_gather("cpu")
    _test_idist_all_gather_tensors_with_different_shapes("cpu")
    _test_distrib_barrier("cpu")
    _test_distrib_broadcast("cpu")
    _test_distrib_new_group("cpu")

    if torch.cuda.device_count() > 1:
        _test_distrib_all_reduce("cuda")
        _test_distrib_all_gather("cuda")
        _test_idist_all_gather_tensors_with_different_shapes("cuda")
        _test_distrib_barrier("cuda")
        _test_distrib_broadcast("cuda")
        _test_distrib_new_group("cuda")
4 changes: 4 additions & 0 deletions tests/ignite/distributed/utils/test_xla.py
@@ -15,6 +15,8 @@
    _test_distrib_new_group,
    _test_distrib_one_rank_only,
    _test_distrib_one_rank_only_with_engine,
    _test_idist_all_gather_tensors_with_different_shapes,
    _test_idist_all_gather_tensors_with_different_shapes_group,
    _test_sync,
)

@@ -151,6 +153,8 @@ def test_idist_all_gather_xla():
    device = idist.device()
    _test_distrib_all_gather(device)
    _test_distrib_all_gather_group(device)
    _test_idist_all_gather_tensors_with_different_shapes(device)
    _test_idist_all_gather_tensors_with_different_shapes_group(device)


def _test_idist_all_gather_xla_in_child_proc(index):
