Commit
Give up and rewrite build_rope_cache with vectorization instead. Add clarification comment in test

Revert failover to cpu in build_rope_cache when device is None

Add test fixture for amd multigpu xgmi and nvidia dualgpu nvlink. Update tests to use fixtures

Use fixtures for device properties. Follow existing style in fixture order. Mock subprocess.run for new tests

Use real device names in mocks

Remove redundant mocks

Remove warning print, revert import sorting to previous style

TensorTemplar committed Oct 9, 2024
1 parent 3942bf1 commit 232e4fa
Showing 3 changed files with 171 additions and 155 deletions.
49 changes: 14 additions & 35 deletions litgpt/model.py
@@ -445,11 +445,8 @@ def build_rope_cache(
condense_ratio: int = 1,
extra_config: Optional[dict] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Enhanced Transformer with Rotary Position Embedding.
Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
transformers/rope/__init__.py. MIT License:
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
"""
Enhanced Transformer with Rotary Position Embedding.
Args:
seq_len (int): Sequence length.
@@ -462,9 +459,8 @@ def build_rope_cache(
Returns:
Tuple[torch.Tensor, torch.Tensor]: Cosine and sine caches for RoPE.
"""
if device is None:
print("warning: build_rope_cache called without device, meta device custom ops may fail")
# $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
assert n_elem % 2 == 0, "n_elem (head dimension) must be even"

# Compute the inverse frequencies theta
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))

if extra_config is not None:
@@ -481,39 +477,22 @@ def build_rope_cache(
# Compute wavelengths corresponding to the inverse frequencies
wavelen = 2 * torch.pi / theta

# Initialize adjusted inverse frequencies
adjusted_theta = theta.clone()
# Compute ratio across all elements
ratio = orig_context_len / wavelen

# Low Frequency Region: wavelen > low_freq_wavelen
mask_low_freq = wavelen > low_freq_wavelen
# avoid NotImplementedError: aten::nonzero: attempted to run this operator with Meta tensors
if device is not None:
adjusted_theta[mask_low_freq] = theta[mask_low_freq] / factor
else:
adjusted_theta = torch.where(
mask_low_freq,
theta / factor,
adjusted_theta
)
print(f"theta device: {theta.device}")
print(f"mask_low_freq device: {mask_low_freq.device}")

# Medium Frequency Region: high_freq_wavelen ≤ wavelen ≤ low_freq_wavelen
mask_medium_freq = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
# Compute smooth factor for medium frequencies
ratio = orig_context_len / wavelen[mask_medium_freq]
# Compute smooth_factor and clamp between 0 and 1
smooth_factor = (ratio - low_freq_factor) / (high_freq_factor - low_freq_factor)
# Interpolate inverse frequencies
adjusted_theta[mask_medium_freq] = (
(1 - smooth_factor) * (theta[mask_medium_freq] / factor)
+ smooth_factor * theta[mask_medium_freq]
)
smooth_factor = torch.clamp(smooth_factor, min=0.0, max=1.0)

# Compute adjusted_theta without masked indexing
adjusted_theta = (1 - smooth_factor) * (theta / factor) + smooth_factor * theta

theta = adjusted_theta

# Create position indexes `[0, 1, ..., seq_len - 1]`
# Create position indices `[0, 1, ..., seq_len - 1]`
seq_idx = torch.arange(seq_len, device=device) / condense_ratio

# Calculate the product of position index and $\theta_i$
# Calculate the product of position index and θ_i
idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)

return torch.cos(idx_theta), torch.sin(idx_theta)
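For reference, here is a minimal standalone sketch of the vectorized frequency adjustment this commit switches to. The helper name and the illustrative base and factor values below are assumptions, not the repository's exact code; the point is that a single clamp on smooth_factor covers the low-, medium-, and high-frequency regions at once, so no boolean-mask indexing (and hence no aten::nonzero) is needed on meta tensors.

import torch


def adjust_rope_theta(
    theta: torch.Tensor,      # inverse frequencies, shape (n_elem // 2,)
    factor: float,            # context-extension factor
    low_freq_factor: float,
    high_freq_factor: float,
    orig_context_len: int,
) -> torch.Tensor:
    # Wavelength corresponding to each inverse frequency
    wavelen = 2 * torch.pi / theta
    # Ratio of the original context length to each wavelength
    ratio = orig_context_len / wavelen
    # Clamped interpolation weight: 0 in the low-frequency region,
    # 1 in the high-frequency region, linear in between
    smooth_factor = (ratio - low_freq_factor) / (high_freq_factor - low_freq_factor)
    smooth_factor = torch.clamp(smooth_factor, min=0.0, max=1.0)
    # Blend scaled and unscaled frequencies element-wise
    return (1 - smooth_factor) * (theta / factor) + smooth_factor * theta


# Illustrative usage; the same code also runs under the meta device, since
# no data-dependent indexing is involved.
theta = 1.0 / (500000 ** (torch.arange(0, 128, 2).float() / 128))
adjusted = adjust_rope_theta(theta, factor=8.0, low_freq_factor=1.0,
                             high_freq_factor=4.0, orig_context_len=8192)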
27 changes: 5 additions & 22 deletions litgpt/utils.py
@@ -11,28 +11,17 @@
import subprocess
import sys
import warnings
from dataclasses import asdict
from dataclasses import is_dataclass
from dataclasses import asdict, is_dataclass
from io import BytesIO
from pathlib import Path
from typing import Any
from typing import Dict
from typing import Iterable
from typing import List
from typing import Literal
from typing import Mapping
from typing import Optional
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Mapping, Optional, TypeVar, Union

import lightning as L
import torch
import torch.nn as nn
import torch.utils._device
import yaml
from lightning.fabric.loggers import CSVLogger
from lightning.fabric.loggers import TensorBoardLogger
from lightning.fabric.loggers import CSVLogger, TensorBoardLogger
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities.load import _lazy_load as lazy_load
from lightning.pytorch.cli import instantiate_class
@@ -43,8 +32,7 @@


if TYPE_CHECKING:
from litgpt import Config
from litgpt import GPT
from litgpt import GPT, Config


def init_out_dir(out_dir: Path) -> Path:
@@ -472,9 +460,7 @@ def copy_config_files(source_dir: Path, out_dir: Path) -> None:


def CLI(*args: Any, **kwargs: Any) -> Any:
from jsonargparse import CLI
from jsonargparse import set_config_read_mode
from jsonargparse import set_docstring_parse_options
from jsonargparse import CLI, set_config_read_mode, set_docstring_parse_options

set_docstring_parse_options(attribute_docstrings=True)
set_config_read_mode(urls_enabled=True)
@@ -673,7 +659,6 @@ def check_nvlink_connectivity(fabric=None):
else:
custom_print = print

# Only execute on the primary process
if os.getenv("RANK", "0") == "0":
try:
if torch.cuda.is_available():
@@ -734,7 +719,6 @@ def _check_amd_connectivity(custom_print):
return

lines = result.stdout.strip().split("\n")
# Find the line that starts with "GPU0"
gpu_header_index = next((i for i, line in enumerate(lines) if re.match(r"^\s*GPU0", line)), None)
if gpu_header_index is None or gpu_header_index == 0:
custom_print("Failed to parse rocm-smi output (no GPU headers found)")
Expand All @@ -745,7 +729,6 @@ def _check_amd_connectivity(custom_print):
gpu_regex = re.compile(r"^GPU\d+$")
gpu_count = len([header for header in headers if gpu_regex.match(header)])

# Collect GPU connection lines
gpu_lines = []
for line in lines[gpu_header_index : gpu_header_index + gpu_count]:
if re.match(r"^\s*GPU\d+", line):
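For context, a hedged sketch of the subprocess.run mocking pattern the commit message describes for the AMD XGMI fixture. The fixture name, the canned rocm-smi text, and the --showtopotype flag are assumptions for illustration only, not the repository's actual test code; the parsing below just mirrors the header-detection logic visible in the diff above.

import re
import subprocess
from unittest import mock

import pytest


@pytest.fixture
def rocm_smi_xgmi_two_gpus() -> str:
    # Illustrative two-GPU XGMI topology matrix; real rocm-smi output has
    # additional banner lines and columns.
    return (
        "============ ROCm System Management Interface ============\n"
        "             GPU0   GPU1\n"
        "GPU0         0      XGMI\n"
        "GPU1         XGMI   0\n"
    )


def test_two_xgmi_gpus_are_detected(rocm_smi_xgmi_two_gpus):
    fake_result = mock.Mock(returncode=0, stdout=rocm_smi_xgmi_two_gpus)
    with mock.patch("subprocess.run", return_value=fake_result):
        # The call never reaches a real rocm-smi binary; the flag is illustrative.
        result = subprocess.run(["rocm-smi", "--showtopotype"], capture_output=True, text=True)

    lines = result.stdout.strip().split("\n")
    # Same header detection as the diff: first line that starts with "GPU0"
    gpu_header_index = next((i for i, line in enumerate(lines) if re.match(r"^\s*GPU0", line)), None)
    assert gpu_header_index is not None and gpu_header_index != 0
    headers = lines[gpu_header_index].split()
    gpu_count = len([h for h in headers if re.match(r"^GPU\d+$", h)])
    assert gpu_count == 2
    assert "XGMI" in result.stdout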