From 3a6462c72bdb6e5eaa0c78626241946bc375f833 Mon Sep 17 00:00:00 2001
From: Tensor Templar
Date: Wed, 9 Oct 2024 11:31:38 +0300
Subject: [PATCH] Remove warning print, revert import sorting to previous style

---
 litgpt/utils.py     | 27 +++++----------------------
 tests/test_utils.py |  6 +++---
 2 files changed, 8 insertions(+), 25 deletions(-)

diff --git a/litgpt/utils.py b/litgpt/utils.py
index 535b702fdf..82b7fd7bbf 100644
--- a/litgpt/utils.py
+++ b/litgpt/utils.py
@@ -11,28 +11,17 @@
 import subprocess
 import sys
 import warnings
-from dataclasses import asdict
-from dataclasses import is_dataclass
+from dataclasses import asdict, is_dataclass
 from io import BytesIO
 from pathlib import Path
-from typing import Any
-from typing import Dict
-from typing import Iterable
-from typing import List
-from typing import Literal
-from typing import Mapping
-from typing import Optional
-from typing import TYPE_CHECKING
-from typing import TypeVar
-from typing import Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Mapping, Optional, TypeVar, Union
 
 import lightning as L
 import torch
 import torch.nn as nn
 import torch.utils._device
 import yaml
-from lightning.fabric.loggers import CSVLogger
-from lightning.fabric.loggers import TensorBoardLogger
+from lightning.fabric.loggers import CSVLogger, TensorBoardLogger
 from lightning.fabric.strategies import FSDPStrategy
 from lightning.fabric.utilities.load import _lazy_load as lazy_load
 from lightning.pytorch.cli import instantiate_class
@@ -43,8 +32,7 @@
 
 
 if TYPE_CHECKING:
-    from litgpt import Config
-    from litgpt import GPT
+    from litgpt import GPT, Config
 
 
 def init_out_dir(out_dir: Path) -> Path:
@@ -472,9 +460,7 @@ def copy_config_files(source_dir: Path, out_dir: Path) -> None:
 
 
 def CLI(*args: Any, **kwargs: Any) -> Any:
-    from jsonargparse import CLI
-    from jsonargparse import set_config_read_mode
-    from jsonargparse import set_docstring_parse_options
+    from jsonargparse import CLI, set_config_read_mode, set_docstring_parse_options
 
     set_docstring_parse_options(attribute_docstrings=True)
     set_config_read_mode(urls_enabled=True)
@@ -673,7 +659,6 @@ def check_nvlink_connectivity(fabric=None):
     else:
         custom_print = print
 
-    # Only execute on the primary process
     if os.getenv("RANK", "0") == "0":
         try:
             if torch.cuda.is_available():
@@ -734,7 +719,6 @@ def _check_amd_connectivity(custom_print):
         return
 
     lines = result.stdout.strip().split("\n")
-    # Find the line that starts with "GPU0"
     gpu_header_index = next((i for i, line in enumerate(lines) if re.match(r"^\s*GPU0", line)), None)
     if gpu_header_index is None or gpu_header_index == 0:
         custom_print("Failed to parse rocm-smi output (no GPU headers found)")
@@ -745,7 +729,6 @@ def _check_amd_connectivity(custom_print):
     gpu_regex = re.compile(r"^GPU\d+$")
     gpu_count = len([header for header in headers if gpu_regex.match(header)])
 
-    # Collect GPU connection lines
     gpu_lines = []
     for line in lines[gpu_header_index : gpu_header_index + gpu_count]:
         if re.match(r"^\s*GPU\d+", line):
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 6895d3a71f..16cbeaa75c 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -713,7 +713,7 @@ def rocm_smi_xgmi_output_multi_gpu():
 
 
 @mock.patch("subprocess.run")
-def test_check_nvlink_connectivity_returns_fully_connected_when_amd_all_xgmi_8_gpus(
+def test_check_nvlink_connectivity__returns_fully_connected_when_amd_all_xgmi_8_gpus(
     mock_run, rocm_smi_xgmi_output_multi_gpu, mock_cuda_is_available_true, mock_amd_device_properties
 ):
     mock_run.return_value = rocm_smi_xgmi_output_multi_gpu
@@ -723,7 +723,7 @@ def test_check_nvlink_connectivity_returns_fully_connected_when_amd_all_xgmi_8_g
 
 
 @mock.patch("subprocess.run")
-def test_check_nvlink_connectivity_returns_no_gpus_when_no_gpus(mock_run, monkeypatch):
+def test_check_nvlink_connectivity__returns_no_gpus_when_no_gpus(mock_run, monkeypatch):
    monkeypatch.setattr(torch.cuda, "is_available", lambda: False)
     with mock.patch("builtins.print") as mock_print:
         check_nvlink_connectivity()
@@ -731,7 +731,7 @@ def test_check_nvlink_connectivity_returns_no_gpus_when_no_gpus(mock_run, monkey
 
 
 @mock.patch("subprocess.run")
-def test_check_nvlink_connectivity_returns_unrecognized_vendor_when_unrecognized_vendor(mock_run, monkeypatch, mock_cuda_is_available_true):
+def test_check_nvlink_connectivity__returns_unrecognized_vendor_when_unrecognized_vendor(mock_run, monkeypatch, mock_cuda_is_available_true):
     mock_device_properties = mock.MagicMock(name="GPU Device", spec=["name"])
     mock_device_properties.name = "GARAGE DIY HYPERSCALER GPU"
     monkeypatch.setattr(torch.cuda, "get_device_properties", lambda idx: mock_device_properties)