add toggler to disable using the nccl base collectives (#799)
* add toggler to disable using the nccl base collectives

* added todo to remove the toggle when the issue is resolved.
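The flag is read once at import time (see the diffs below), so it must be set in the environment before any fairscale module is imported. A minimal usage sketch, with the import order being the point; the surrounding script is hypothetical:

    # Set the toggle before any fairscale import; the modules read it once at import time.
    import os
    os.environ["ENABLE_NCCL_BASE_COLLECTIVES"] = "0"  # "0" disables the NCCL base collectives

    from fairscale.nn.data_parallel import FullyShardedDataParallel  # noqa: E402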
tmarkstrum authored Sep 17, 2021
1 parent 180ab8c commit 086402d
Showing 2 changed files with 16 additions and 3 deletions.
8 changes: 7 additions & 1 deletion fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -9,6 +9,7 @@
 import functools
 import logging
 from math import inf
+import os
 import time
 import traceback
 import typing
@@ -54,6 +55,11 @@

 if TYPE_CHECKING:
     from collections import OrderedDict  # noqa: F401
+# TODO: Remove the toggle here when github open issue #801 is resolved.
+if os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") == "0":
+    enable_nccl_base_collectives = False
+else:
+    enable_nccl_base_collectives = True


 class TrainingState(Enum):
@@ -1599,7 +1605,7 @@ def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
                 output_tensor = p._full_param_padded

             # Fill output_tensor with (p.data for each shard in self.world_size)
-            if hasattr(dist, "_all_gather_base"):
+            if hasattr(dist, "_all_gather_base") and enable_nccl_base_collectives:
                 # New version of PyTorch has all_gather_base, which is faster than chunk and then all_gather.
                 dist._all_gather_base(output_tensor, p_data, group=self.process_group)
             else:
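To make the guarded branch above concrete, here is a minimal standalone sketch of the two all-gather paths; the helper name and its parameters (output_tensor, shard, world_size, group, use_base_collectives) are illustrative stand-ins for the FSDP internals, not part of the commit:

    import torch
    import torch.distributed as dist

    def gather_full_param(output_tensor, shard, world_size, group, use_base_collectives):
        # Sketch of the branch added above: prefer the fused NCCL base collective.
        if hasattr(dist, "_all_gather_base") and use_base_collectives:
            # Single fused all-gather straight into the flat padded buffer.
            dist._all_gather_base(output_tensor, shard, group=group)
        else:
            # Fallback: view the buffer as world_size chunks and all-gather into them.
            chunks = list(output_tensor.chunk(world_size))
            dist.all_gather(chunks, shard, group=group)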
11 changes: 9 additions & 2 deletions fairscale/utils/reduce_scatter_bucketer.py
@@ -4,13 +4,20 @@
 # LICENSE file in the root directory of this source tree.

 import functools
+import os
 from typing import Callable, Dict, List, Optional, Tuple

 import torch
 from torch import Tensor
 import torch.distributed as dist
 from torch.distributed import ProcessGroup

+# TODO: Remove the toggle-enable_nccl_base_collectives when github open issue #801 is resolved.
+if os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") == "0":
+    enable_nccl_base_collectives = False
+else:
+    enable_nccl_base_collectives = True
+

 class Bucket:
     def __init__(self, data: Tensor, group: ProcessGroup):
@@ -26,7 +33,7 @@ def flush(self) -> None:
             assert len(self.callbacks) == 0
             return
         # reduce-scatter bucket
-        if hasattr(dist, "_reduce_scatter_base"):
+        if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives:
             dist._reduce_scatter_base(
                 self.output_shard[: self.offset], self.data[:, : self.offset].contiguous(), group=self.group
             )
@@ -130,7 +137,7 @@ def reduce_scatter_async(
             # TODO: investigate how to avoid using torch.cat (because it seems to be slow for CPU tensors)
             # input is too big to fit in the bucket, reduce-scatter directly
             output = torch.zeros_like(input_list[0])
-            if hasattr(dist, "_reduce_scatter_base"):
+            if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives:
                 input_flattened = torch.cat(input_list)
                 dist._reduce_scatter_base(output, input_flattened, group=group)
             else:
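The same guard appears in both reduce-scatter call sites in this file; a condensed sketch of the pattern, with illustrative names (output, input_list, group, use_base_collectives) rather than the bucketer's internals:

    import torch
    import torch.distributed as dist

    def reduce_scatter_shards(output, input_list, group, use_base_collectives):
        # Each rank contributes world_size input tensors and receives one reduced shard.
        if hasattr(dist, "_reduce_scatter_base") and use_base_collectives:
            # Fused NCCL reduce-scatter over a single flattened input tensor.
            dist._reduce_scatter_base(output, torch.cat(input_list), group=group)
        else:
            # Fallback: list-based reduce_scatter, one input tensor per rank.
            dist.reduce_scatter(output, input_list, group=group)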
