[Core][Distributed] fix _is_full_nvlink detection (vllm-project#4233)
youkaichao authored Apr 22, 2024
1 parent 95e5b08 commit 747b1a7
Showing 1 changed file with 30 additions and 18 deletions.
48 changes (30 additions & 18 deletions) in vllm/distributed/device_communicators/custom_all_reduce.py
@@ -1,5 +1,6 @@
+import os
 from contextlib import contextmanager
-from typing import Optional
+from typing import List, Optional

 import torch
 import torch.distributed as dist
@@ -53,14 +54,20 @@ def init_custom_ar() -> None:
         return False
     # test nvlink first, this will filter out most of the cases
     # where custom allreduce is not supported
-    full_nvlink = _is_full_nvlink(rank, world_size)
+    if "CUDA_VISIBLE_DEVICES" in os.environ:
+        device_ids = list(
+            map(int, os.environ["CUDA_VISIBLE_DEVICES"].split(",")))
+    else:
+        device_ids = list(range(num_dev))
+    # this checks hardware and driver support for NVLink
+    full_nvlink = _is_full_nvlink(device_ids)
     if world_size > 2 and not full_nvlink:
         logger.warn(
             "Custom allreduce is disabled because it's not supported on more"
             " than two PCIe-only GPUs. To silence this warning, specify"
             " disable_custom_all_reduce=True explicitly.")
         return
-    # test P2P capability
+    # test P2P capability, this checks software/cudaruntime support
+    # this is expensive to compute at the first time
+    # then we cache the result
     if not _can_p2p(rank, world_size):
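
(For illustration only, not part of the commit: a minimal standalone sketch of the device-id mapping this hunk introduces. The helper name get_physical_device_ids is hypothetical, and it mirrors the hunk's integer parsing of CUDA_VISIBLE_DEVICES; num_dev in the hunk is assumed to come from torch.cuda.device_count().)

import os
from typing import List

import torch


def get_physical_device_ids() -> List[int]:
    # When CUDA_VISIBLE_DEVICES is set, logical CUDA device 0 may map to a
    # different physical GPU, so recover the real ids from the variable.
    if "CUDA_VISIBLE_DEVICES" in os.environ:
        return [int(x) for x in os.environ["CUDA_VISIBLE_DEVICES"].split(",")]
    # Otherwise logical and physical device ids coincide.
    return list(range(torch.cuda.device_count()))
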
@@ -138,23 +145,28 @@ def _nvml():
         pynvml.nvmlShutdown()


-# query if the set of gpus are fully connected by nvlink (1 hop)
 @_nvml()
-def _is_full_nvlink(rank, world_size):
-    handle = pynvml.nvmlDeviceGetHandleByIndex(rank)
-    for i in range(world_size):
-        if i != rank:
-            try:
-                peer_handle = pynvml.nvmlDeviceGetHandleByIndex(i)
-                p2p_status = pynvml.nvmlDeviceGetP2PStatus(
-                    handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
-                if p2p_status != pynvml.NVML_P2P_STATUS_OK:
+def _is_full_nvlink(device_ids: List[int]) -> bool:
+    """
+    query if the set of gpus are fully connected by nvlink (1 hop)
+    Note that `pynvml` is not affected by `CUDA_VISIBLE_DEVICES`,
+    so it works on real physical device ids.
+    """
+    handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
+    for i, handle in enumerate(handles):
+        for j, peer_handle in enumerate(handles):
+            if i < j:
+                try:
+                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
+                        handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
+                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
+                        return False
+                except pynvml.NVMLError as error:
+                    logger.error(
+                        "NVLink detection failed. This is normal if your"
+                        " machine has no NVLink equipped.",
+                        exc_info=error)
                     return False
-            except pynvml.NVMLError as error:
-                logger.info(
-                    f"NVLink detection failed with message \"{str(error)}\". "
-                    "This is normal if your machine has no NVLink equipped")
-                return False
     return True


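(For illustration only, not part of the commit: a minimal standalone sketch of the pairwise NVLink query the new _is_full_nvlink performs. The function name is_full_nvlink and the explicit nvmlInit/nvmlShutdown handling are assumptions made to keep the example self-contained; the ids passed in must be physical device ids, e.g. from the get_physical_device_ids sketch above.)

from typing import List

import pynvml


def is_full_nvlink(physical_device_ids: List[int]) -> bool:
    # NVML indexes physical GPUs and ignores CUDA_VISIBLE_DEVICES, so the
    # caller must pass physical ids, not remapped logical ids.
    pynvml.nvmlInit()
    try:
        handles = [
            pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids
        ]
        # Require NVLink P2P between every pair of GPUs (1 hop).
        for i, a in enumerate(handles):
            for b in handles[i + 1:]:
                status = pynvml.nvmlDeviceGetP2PStatus(
                    a, b, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
                if status != pynvml.NVML_P2P_STATUS_OK:
                    return False
        return True
    except pynvml.NVMLError:
        # NVML query failed; treat as no NVLink (normal on machines without it).
        return False
    finally:
        pynvml.nvmlShutdown()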
