AMD (MI250X) support #1775

Merged · 7 commits · Oct 10, 2024
Changes from all commits
34 changes: 12 additions & 22 deletions litgpt/model.py
@@ -445,11 +445,8 @@ def build_rope_cache(
condense_ratio: int = 1,
extra_config: Optional[dict] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Enhanced Transformer with Rotary Position Embedding.

Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
transformers/rope/__init__.py. MIT License:
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
"""
Enhanced Transformer with Rotary Position Embedding.

Args:
seq_len (int): Sequence length.
@@ -463,7 +460,7 @@ def build_rope_cache(
Tuple[torch.Tensor, torch.Tensor]: Cosine and sine caches for RoPE.
"""

    # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
    assert n_elem % 2 == 0, "n_elem (head dimension) must be even"
# Compute the inverse frequencies theta
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))

if extra_config is not None:
@@ -480,26 +477,19 @@ def build_rope_cache(
# Compute wavelengths corresponding to the inverse frequencies
wavelen = 2 * torch.pi / theta

# Initialize adjusted inverse frequencies
adjusted_theta = theta.clone()

# Low Frequency Region: wavelen > low_freq_wavelen
mask_low_freq = wavelen > low_freq_wavelen
adjusted_theta[mask_low_freq] = theta[mask_low_freq] / factor
# Compute ratio across all elements
ratio = orig_context_len / wavelen

# Medium Frequency Region: high_freq_wavelen ≤ wavelen ≤ low_freq_wavelen
mask_medium_freq = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
# Compute smooth factor for medium frequencies
ratio = orig_context_len / wavelen[mask_medium_freq]
# Compute smooth_factor and clamp between 0 and 1
smooth_factor = (ratio - low_freq_factor) / (high_freq_factor - low_freq_factor)
# Interpolate inverse frequencies
adjusted_theta[mask_medium_freq] = (
(1 - smooth_factor) * (theta[mask_medium_freq] / factor)
+ smooth_factor * theta[mask_medium_freq]
)
smooth_factor = torch.clamp(smooth_factor, min=0.0, max=1.0)

# Compute adjusted_theta without masked indexing
adjusted_theta = (1 - smooth_factor) * (theta / factor) + smooth_factor * theta

theta = adjusted_theta

# Create position indexes `[0, 1, ..., seq_len - 1]`
# Create position indices `[0, 1, ..., seq_len - 1]`
seq_idx = torch.arange(seq_len, device=device) / condense_ratio

# Calculate the product of position index and $\theta_i$
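
For reference, the clamped interpolation added above computes the same piecewise adjustment as the removed mask-based version: $\theta_i' = (1 - s_i)\,\theta_i/\text{factor} + s_i\,\theta_i$ with $s_i = \mathrm{clamp}\left(\frac{L_\text{orig}/\lambda_i - f_\text{low}}{f_\text{high} - f_\text{low}}, 0, 1\right)$, where the clamp saturates to 0 in the low-frequency band (wavelen > low_freq_wavelen) and to 1 in the high-frequency band. A minimal self-contained sketch checking the equivalence; the scaling values are illustrative Llama-3.1-style numbers, assumed for this sketch rather than taken from the diff:

```python
import torch

# Assumed, illustrative scaling parameters (in litgpt these arrive via `extra_config`).
factor = 8.0
low_freq_factor = 1.0
high_freq_factor = 4.0
orig_context_len = 8192
base, n_elem = 500_000, 128

theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
wavelen = 2 * torch.pi / theta
low_freq_wavelen = orig_context_len / low_freq_factor
high_freq_wavelen = orig_context_len / high_freq_factor

# Removed formulation: per-band masked indexing.
masked = theta.clone()
low = wavelen > low_freq_wavelen
masked[low] = theta[low] / factor
medium = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
smooth = (orig_context_len / wavelen[medium] - low_freq_factor) / (high_freq_factor - low_freq_factor)
masked[medium] = (1 - smooth) * (theta[medium] / factor) + smooth * theta[medium]

# Added formulation: one clamped interpolation over all elements; outside the
# medium band the clamp saturates, so the boolean masks become unnecessary.
ratio = orig_context_len / wavelen
smooth_all = torch.clamp((ratio - low_freq_factor) / (high_freq_factor - low_freq_factor), 0.0, 1.0)
adjusted = (1 - smooth_all) * (theta / factor) + smooth_all * theta

assert torch.allclose(masked, adjusted)
```

Avoiding masked indexing also keeps the whole computation as dense elementwise tensor ops, which plausibly makes it more portable across backends such as ROCm.
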
131 changes: 98 additions & 33 deletions litgpt/utils.py
@@ -637,47 +637,112 @@ def auto_download_checkpoint(model_name, access_token=None, ignore_tokenizer_fil


def check_nvlink_connectivity(fabric=None):
"""Checks GPU connectivity for both NVIDIA and AMD GPUs.

This function delegates to vendor-specific implementations based on
the detected GPU vendor.
"""
if fabric is not None:
custom_print = fabric.print
else:
custom_print = print

if os.getenv("RANK", "0") == "0":
try:
result = subprocess.run(["nvidia-smi", "topo", "-m"], stdout=subprocess.PIPE, text=True)

if result.returncode != 0:
custom_print("Failed to run nvidia-smi")
return

lines = result.stdout.split('\n')
gpu_matrix = []

start_index = next((i for i, line in enumerate(lines) if "GPU0" in line), None) + 1
headers_line = lines[start_index - 1]
headers = headers_line.split()
# The regex is to avoid counting the "GPU NUMA ID" header as a GPU
# in headers like ['\x1b[4mGPU0', 'GPU1', 'GPU2', 'GPU3', 'GPU4', 'GPU5', 'GPU6', 'GPU7', 'NIC0', 'NIC1', 'NIC2', 'NIC3', 'NIC4', 'NIC5', 'NIC6', 'NIC7', 'NIC8', 'NIC9', 'CPU', 'Affinity', 'NUMA', 'Affinity', 'GPU', 'NUMA', 'ID\x1b[0m']
gpu_regex = re.compile(r'^GPU\d+$')
gpu_count = len([header for header in headers if gpu_regex.match(header)])

all_nvlink = True
for line in lines[start_index:start_index + gpu_count]:
gpu_matrix.append(line.strip())
connections = line.split()[1:1 + gpu_count]
if not all("NV" in conn for conn in connections if conn != "X"):
all_nvlink = False
break

if all_nvlink:
custom_print("All GPUs are fully connected via NVLink.")
if torch.cuda.is_available():
device_properties = torch.cuda.get_device_properties(0)
gpu_name = device_properties.name.lower()
if "nvidia" in gpu_name:
_check_nvidia_connectivity(custom_print)
elif "advanced micro devices" in gpu_name or "amd" in gpu_name:
_check_amd_connectivity(custom_print)
else:
custom_print(f"Unrecognized GPU vendor: {device_properties.name}")
else:
custom_print(
"Warning: Not all GPUs are fully connected via NVLink. Some GPUs are connected via slower interfaces. "
"It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance."
)

custom_print("No GPUs available")
except Exception as e:
custom_print(f"An error occurred: {e}")
custom_print(f"An error occurred while checking GPU connectivity: {e}")


def _check_nvidia_connectivity(custom_print):
"""Checks NVLink connectivity on NVIDIA GPUs."""
result = subprocess.run(["nvidia-smi", "topo", "-m"], stdout=subprocess.PIPE, text=True)
if result.returncode != 0:
custom_print("Failed to run nvidia-smi")
return

lines = result.stdout.strip().split("\n")
start_index = next((i for i, line in enumerate(lines) if "GPU0" in line), None)
if start_index is None:
custom_print("Failed to parse nvidia-smi output")
return

headers_line = lines[start_index]
headers = headers_line.split()
gpu_regex = re.compile(r"^GPU\d+$")
gpu_count = len([header for header in headers if gpu_regex.match(header)])

all_nvlink = True
for line in lines[start_index + 1 : start_index + 1 + gpu_count]:
columns = line.split()
connections = columns[1 : 1 + gpu_count]
if not all("NV" in conn for conn in connections if conn != "X"):
all_nvlink = False
break

if all_nvlink:
custom_print("All GPUs are fully connected via NVLink.")
else:
custom_print(
"Warning: Not all GPUs are fully connected via NVLink. Some GPUs are connected via slower interfaces. "
"It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance."
)


def _check_amd_connectivity(custom_print):
"""Checks XGMI connectivity on AMD GPUs."""
result = subprocess.run(["rocm-smi", "--showtopotype"], stdout=subprocess.PIPE, text=True)
if result.returncode != 0:
custom_print("Failed to run rocm-smi")
return

lines = result.stdout.strip().split("\n")
    # The header row lists at least two GPU labels ("GPU0  GPU1 ..."); matching on that
    # pattern avoids picking up the data rows, whose first column is also "GPU0".
    gpu_header_index = next((i for i, line in enumerate(lines) if re.match(r"^\s*GPU0\s+GPU1", line)), None)
    if gpu_header_index is None:
        custom_print("Failed to parse rocm-smi output (no GPU headers found)")
        return

    header_line = lines[gpu_header_index]
headers = header_line.strip().split()
gpu_regex = re.compile(r"^GPU\d+$")
gpu_count = len([header for header in headers if gpu_regex.match(header)])

gpu_lines = []
    for line in lines[gpu_header_index + 1 : gpu_header_index + 1 + gpu_count]:
if re.match(r"^\s*GPU\d+", line):
gpu_lines.append(line.strip())
if len(gpu_lines) != gpu_count:
custom_print("Mismatch in GPU count when parsing rocm-smi output")
return

all_xgmi = True
for line in gpu_lines:
columns = line.split()
connections = columns[1 : 1 + gpu_count]
for conn in connections:
if conn not in ("XGMI", "0"):
all_xgmi = False
break
if not all_xgmi:
break

if all_xgmi:
custom_print("All GPUs are fully connected via XGMI.")
else:
custom_print(
"Warning: Not all GPUs are fully connected via XGMI. Some GPUs are connected via slower interfaces. "
"It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance."
)
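
For context, a minimal sketch of how the vendor-dispatching check might be invoked from a training setup; the `Fabric` import path, its arguments, and the surrounding scaffolding are illustrative assumptions, not part of this diff:

```python
import torch
from lightning.fabric import Fabric  # assumed import path

from litgpt.utils import check_nvlink_connectivity

if torch.cuda.is_available():
    fabric = Fabric(accelerator="cuda", devices=torch.cuda.device_count())
    fabric.launch()
    # Rank 0 prints "All GPUs are fully connected via NVLink." on NVIDIA,
    # "All GPUs are fully connected via XGMI." on AMD, or a warning when
    # slower inter-GPU links are detected.
    check_nvlink_connectivity(fabric)
else:
    check_nvlink_connectivity()  # falls back to plain print(); reports "No GPUs available"
```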


def fix_and_load_json(s):
120 changes: 116 additions & 4 deletions tests/test_utils.py
@@ -465,6 +465,29 @@ def test_file_size_above_limit_on_gpu():
assert size == 4_600_000_000


@pytest.fixture
def mock_cuda_is_available_true(monkeypatch):
"""Fixture to mock torch.cuda.is_available() to return True."""
monkeypatch.setattr(torch.cuda, "is_available", lambda: True)


@pytest.fixture
def mock_nvidia_device_properties(monkeypatch):
"""Fixture to mock torch.cuda.get_device_properties() for NVIDIA GPUs."""
mock_device_properties = mock.MagicMock(name="GPU Device", spec=["name"])
mock_device_properties.name = "NVIDIA RTX A6000"
monkeypatch.setattr(torch.cuda, "get_device_properties", lambda idx: mock_device_properties)


@pytest.fixture
def mock_amd_device_properties(monkeypatch):
"""Fixture to mock torch.cuda.get_device_properties() for AMD GPUs."""
mock_device_properties = mock.MagicMock(name="GPU Device", spec=["name"])
mock_device_properties.name = "AMD Instinct MI250X"
monkeypatch.setattr(torch.cuda, "get_device_properties", lambda idx: mock_device_properties)


@pytest.fixture
def all_nvlink_connected_output():
return mock.MagicMock(stdout=""" GPU0 GPU1 GPU2 GPU3
@@ -475,7 +498,7 @@ def all_nvlink_connected_output():


@mock.patch("subprocess.run")
def test_all_nvlink_connected(mock_run, all_nvlink_connected_output):
def test_all_nvlink_connected(mock_run, all_nvlink_connected_output, mock_cuda_is_available_true, mock_nvidia_device_properties):
mock_run.return_value = all_nvlink_connected_output
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
@@ -497,7 +520,7 @@ def nvlink_partially_connected_output():


@mock.patch("subprocess.run")
def test_nvlink_partially_connected_output(mock_run, nvlink_partially_connected_output):
def test_nvlink_partially_connected_output(mock_run, nvlink_partially_connected_output, mock_cuda_is_available_true, mock_nvidia_device_properties):
mock_run.return_value = nvlink_partially_connected_output
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
@@ -527,7 +550,7 @@ def nvlink_not_connected_output():


@mock.patch("subprocess.run")
def test_nvlink_not_connected_output(mock_run, nvlink_not_connected_output):
def test_nvlink_not_connected_output(mock_run, nvlink_not_connected_output, mock_cuda_is_available_true, mock_nvidia_device_properties):
mock_run.return_value = nvlink_not_connected_output
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
@@ -586,13 +609,102 @@ def nvlink_all_gpu_connected_but_other_connected_output():


@mock.patch("subprocess.run")
def test_nvlink_all_gpu_connected_but_other_connected_output(mock_run, nvlink_all_gpu_connected_but_other_connected_output):
def test_nvlink_all_gpu_connected_but_other_connected_output(
mock_run,
nvlink_all_gpu_connected_but_other_connected_output,
mock_cuda_is_available_true,
mock_nvidia_device_properties,
):
mock_run.return_value = nvlink_all_gpu_connected_but_other_connected_output
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call("All GPUs are fully connected via NVLink.")


@pytest.fixture
def nvidia_smi_nvlink_output_dual_gpu_no_numa():
return mock.MagicMock(
stdout="""
GPU0 GPU1 CPU Affinity NUMA Affinity GPU NUMA ID
GPU0 X NV1 0-15 0 N/A
GPU1 NV1 X 0-15 0 N/A

Legend:

X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
""",
returncode=0,
)


@mock.patch("subprocess.run")
def test_check_nvlink_connectivity__returns_fully_connected_when_nvidia_all_nvlink_two_gpus(
mock_run, nvidia_smi_nvlink_output_dual_gpu_no_numa, mock_cuda_is_available_true, mock_nvidia_device_properties
):
mock_run.return_value = nvidia_smi_nvlink_output_dual_gpu_no_numa
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call("All GPUs are fully connected via NVLink.")


@pytest.fixture
def rocm_smi_xgmi_output_multi_gpu():
"""
rocm-smi --showtopotype on ROCm 6.0.3+
"""
return mock.MagicMock(
stdout="""
=============================== ROCm System Management Interface ============================
=============================== Link Type between two GPUs ===============================
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7
GPU0 0 XGMI XGMI XGMI XGMI XGMI XGMI XGMI
GPU1 XGMI 0 XGMI XGMI XGMI XGMI XGMI XGMI
GPU2 XGMI XGMI 0 XGMI XGMI XGMI XGMI XGMI
GPU3 XGMI XGMI XGMI 0 XGMI XGMI XGMI XGMI
GPU4 XGMI XGMI XGMI XGMI 0 XGMI XGMI XGMI
GPU5 XGMI XGMI XGMI XGMI XGMI 0 XGMI XGMI
GPU6 XGMI XGMI XGMI XGMI XGMI XGMI 0 XGMI
GPU7 XGMI XGMI XGMI XGMI XGMI XGMI XGMI 0
================================== End of ROCm SMI Log ===================================
""",
returncode=0,
)


@mock.patch("subprocess.run")
def test_check_nvlink_connectivity__returns_fully_connected_when_amd_all_xgmi_8_gpus(
mock_run, rocm_smi_xgmi_output_multi_gpu, mock_cuda_is_available_true, mock_amd_device_properties
):
mock_run.return_value = rocm_smi_xgmi_output_multi_gpu
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call("All GPUs are fully connected via XGMI.")


@mock.patch("subprocess.run")
def test_check_nvlink_connectivity__returns_no_gpus_when_no_gpus(mock_run, monkeypatch):
monkeypatch.setattr(torch.cuda, "is_available", lambda: False)
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call("No GPUs available")


@mock.patch("subprocess.run")
def test_check_nvlink_connectivity__returns_unrecognized_vendor_when_unrecognized_vendor(mock_run, monkeypatch, mock_cuda_is_available_true):
mock_device_properties = mock.MagicMock(name="GPU Device", spec=["name"])
mock_device_properties.name = "GARAGE DIY HYPERSCALER GPU"
monkeypatch.setattr(torch.cuda, "get_device_properties", lambda idx: mock_device_properties)
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call("Unrecognized GPU vendor: GARAGE DIY HYPERSCALER GPU")

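Since every connectivity test name above contains "nvlink", a keyword filter is enough to exercise just these tests; a tiny helper sketch (the filename and filter pattern are assumptions):

```python
# run_connectivity_tests.py (hypothetical); equivalent to: pytest tests/test_utils.py -k nvlink
import pytest

raise SystemExit(pytest.main(["tests/test_utils.py", "-k", "nvlink"]))
```
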

def test_fix_and_load_json():
# Test 1: Invalid JSON string with a trailing comma
invalid_json_trailing_comma = '''