Skip to content

Commit

Permalink
Remove warning print, revert import sorting to previous style
Browse files Browse the repository at this point in the history
  • Loading branch information
TensorTemplar committed Oct 9, 2024
1 parent 89aa792 commit 3a6462c
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 25 deletions.
27 changes: 5 additions & 22 deletions litgpt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,17 @@
import subprocess
import sys
import warnings
from dataclasses import asdict
from dataclasses import is_dataclass
from dataclasses import asdict, is_dataclass
from io import BytesIO
from pathlib import Path
from typing import Any
from typing import Dict
from typing import Iterable
from typing import List
from typing import Literal
from typing import Mapping
from typing import Optional
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Mapping, Optional, TypeVar, Union

import lightning as L
import torch
import torch.nn as nn
import torch.utils._device
import yaml
from lightning.fabric.loggers import CSVLogger
from lightning.fabric.loggers import TensorBoardLogger
from lightning.fabric.loggers import CSVLogger, TensorBoardLogger
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities.load import _lazy_load as lazy_load
from lightning.pytorch.cli import instantiate_class
Expand All @@ -43,8 +32,7 @@


if TYPE_CHECKING:
from litgpt import Config
from litgpt import GPT
from litgpt import GPT, Config


def init_out_dir(out_dir: Path) -> Path:
Expand Down Expand Up @@ -472,9 +460,7 @@ def copy_config_files(source_dir: Path, out_dir: Path) -> None:


def CLI(*args: Any, **kwargs: Any) -> Any:
from jsonargparse import CLI
from jsonargparse import set_config_read_mode
from jsonargparse import set_docstring_parse_options
from jsonargparse import CLI, set_config_read_mode, set_docstring_parse_options

set_docstring_parse_options(attribute_docstrings=True)
set_config_read_mode(urls_enabled=True)
Expand Down Expand Up @@ -673,7 +659,6 @@ def check_nvlink_connectivity(fabric=None):
else:
custom_print = print

# Only execute on the primary process
if os.getenv("RANK", "0") == "0":
try:
if torch.cuda.is_available():
Expand Down Expand Up @@ -734,7 +719,6 @@ def _check_amd_connectivity(custom_print):
return

lines = result.stdout.strip().split("\n")
# Find the line that starts with "GPU0"
gpu_header_index = next((i for i, line in enumerate(lines) if re.match(r"^\s*GPU0", line)), None)
if gpu_header_index is None or gpu_header_index == 0:
custom_print("Failed to parse rocm-smi output (no GPU headers found)")
Expand All @@ -745,7 +729,6 @@ def _check_amd_connectivity(custom_print):
gpu_regex = re.compile(r"^GPU\d+$")
gpu_count = len([header for header in headers if gpu_regex.match(header)])

# Collect GPU connection lines
gpu_lines = []
for line in lines[gpu_header_index : gpu_header_index + gpu_count]:
if re.match(r"^\s*GPU\d+", line):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,7 @@ def rocm_smi_xgmi_output_multi_gpu():


@mock.patch("subprocess.run")
def test_check_nvlink_connectivity_returns_fully_connected_when_amd_all_xgmi_8_gpus(
def test_check_nvlink_connectivity__returns_fully_connected_when_amd_all_xgmi_8_gpus(
mock_run, rocm_smi_xgmi_output_multi_gpu, mock_cuda_is_available_true, mock_amd_device_properties
):
mock_run.return_value = rocm_smi_xgmi_output_multi_gpu
Expand All @@ -723,15 +723,15 @@ def test_check_nvlink_connectivity_returns_fully_connected_when_amd_all_xgmi_8_g


@mock.patch("subprocess.run")
def test_check_nvlink_connectivity_returns_no_gpus_when_no_gpus(mock_run, monkeypatch):
def test_check_nvlink_connectivity__returns_no_gpus_when_no_gpus(mock_run, monkeypatch):
monkeypatch.setattr(torch.cuda, "is_available", lambda: False)
with mock.patch("builtins.print") as mock_print:
check_nvlink_connectivity()
mock_print.assert_any_call("No GPUs available")


@mock.patch("subprocess.run")
def test_check_nvlink_connectivity_returns_unrecognized_vendor_when_unrecognized_vendor(mock_run, monkeypatch, mock_cuda_is_available_true):
def test_check_nvlink_connectivity__returns_unrecognized_vendor_when_unrecognized_vendor(mock_run, monkeypatch, mock_cuda_is_available_true):
mock_device_properties = mock.MagicMock(name="GPU Device", spec=["name"])
mock_device_properties.name = "GARAGE DIY HYPERSCALER GPU"
monkeypatch.setattr(torch.cuda, "get_device_properties", lambda idx: mock_device_properties)
Expand Down

0 comments on commit 3a6462c

Please sign in to comment.