Commit

Improve rope (#1745)

rasbt authored Oct 4, 2024
1 parent a8aa4ba commit c03f3f0
Showing 3 changed files with 288 additions and 6 deletions.
8 changes: 7 additions & 1 deletion litgpt/config.py
@@ -62,6 +62,7 @@ class Config:
intermediate_size: Optional[int] = None
rope_condense_ratio: int = 1
rope_base: int = 10000
rope_adjustments: Optional[dict] = None
n_expert: int = 0
n_expert_per_token: int = 0
attention_logit_softcapping: Optional[float] = None
@@ -893,6 +894,7 @@ def norm_class(self) -> Type:
mlp_class_name="LLaMAMLP",
intermediate_size=14336,
rope_base=500000,
rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192)
),
# https://huggingface.co/meta-llama/Meta-Llama-3-70B/blob/main/config.json
dict(
@@ -931,6 +933,7 @@ def norm_class(self) -> Type:
mlp_class_name="LLaMAMLP",
intermediate_size=28672,
rope_base=500000,
rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192)
),
# https://huggingface.co/meta-llama/Meta-Llama-3.1-405B/blob/main/config.json
dict(
@@ -950,8 +953,9 @@ def norm_class(self) -> Type:
mlp_class_name="LLaMAMLP",
intermediate_size=53248,
rope_base=500000,
rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192)
),
# https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json
# https://huggingface.co/meta-llama/Llama-3.2-1B/blob/main/config.json
dict(
name="Llama-3.2-1B{}",
hf_config=dict(org="meta-llama", name="Llama-3.2-1B{}"),
@@ -969,6 +973,7 @@ def norm_class(self) -> Type:
mlp_class_name="LLaMAMLP",
intermediate_size=8192,
rope_base=500000,
rope_adjustments=dict(factor=32.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192)
),
# https://huggingface.co/meta-llama/Llama-3.2-3B/blob/main/config.json
dict(
@@ -988,6 +993,7 @@ def norm_class(self) -> Type:
mlp_class_name="LLaMAMLP",
intermediate_size=8192,
rope_base=500000,
rope_adjustments=dict(factor=32.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192)
),
]
for c in llama_3:
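As a quick sanity check of the new field (a hedged sketch, not part of the commit; it assumes the Llama 3.2 preset name resolves through Config.from_name as usual), the adjustments should surface alongside rope_base:

from litgpt.config import Config

# Sketch only: inspect the new field on one of the presets defined above.
config = Config.from_name("Llama-3.2-1B")
print(config.rope_base)         # 500000
print(config.rope_adjustments)  # {'factor': 32.0, 'low_freq_factor': 1.0, 'high_freq_factor': 4.0, 'original_max_seq_len': 8192}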
81 changes: 79 additions & 2 deletions litgpt/model.py
@@ -108,12 +108,39 @@ def from_name(cls, name: str, **kwargs: Any) -> Self:
return cls(Config.from_name(name, **kwargs))

def rope_cache(self, device: Optional[torch.device] = None) -> Tuple[torch.Tensor, torch.Tensor]:

if self.config.rope_adjustments is None:
extra_config = None

else:
adjusted_params_required = ["factor", "low_freq_factor", "high_freq_factor", "original_max_seq_len"]
params_present = [param in self.config.rope_adjustments for param in adjusted_params_required]
num_params_present = sum(params_present)

if num_params_present == 0:
extra_config = None # uses standard RoPE
elif num_params_present == 4:
# These parameters should always be used together so that we don't interfere with standard RoPE
extra_config = {
"original_max_seq_len": self.config.rope_adjustments["original_max_seq_len"],
"factor": self.config.rope_adjustments["factor"],
"low_freq_factor": self.config.rope_adjustments["low_freq_factor"],
"high_freq_factor": self.config.rope_adjustments["high_freq_factor"],
}
else:
# Some but not all parameters are specified; raise an error listing the missing ones
missing_params = [param for param, present in zip(adjusted_params_required, params_present) if not present]
raise ValueError(
f"The following adjusted RoPE parameters are missing in rope_adjustments: {', '.join(missing_params)}. "
"All adjusted RoPE parameters must be specified together."
)

return build_rope_cache(
seq_len=self.max_seq_length,
n_elem=self.config.rope_n_elem,
device=device,
condense_ratio=self.config.rope_condense_ratio,
base=self.config.rope_base,
extra_config=extra_config,
)

def set_kv_cache(
@@ -410,17 +437,67 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


def build_rope_cache(
seq_len: int, n_elem: int, device: Optional[torch.device] = None, base: int = 10000, condense_ratio: int = 1
seq_len: int,
n_elem: int,
device: Optional[torch.device] = None,
base: int = 10000,
condense_ratio: int = 1,
extra_config: Optional[dict] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Enhanced Transformer with Rotary Position Embedding.
Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
transformers/rope/__init__.py. MIT License:
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
Args:
seq_len (int): Sequence length.
n_elem (int): Number of elements (head dimension).
device (torch.device, optional): Device for tensor allocations.
base (int, optional): Base for computing inverse frequencies.
condense_ratio (int, optional): Ratio to condense the position indices.
extra_config (dict, optional): Configuration parameters for frequency adjustments (used by Llama 3.1 and 3.2).
Returns:
Tuple[torch.Tensor, torch.Tensor]: Cosine and sine caches for RoPE.
"""
# $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
assert n_elem % 2 == 0, "n_elem (head dimension) must be even"
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))

if extra_config is not None:
# Extract configuration parameters
orig_context_len = extra_config["original_max_seq_len"]
factor = extra_config["factor"]
low_freq_factor = extra_config["low_freq_factor"]
high_freq_factor = extra_config["high_freq_factor"]

# Compute wavelength thresholds
low_freq_wavelen = orig_context_len / low_freq_factor
high_freq_wavelen = orig_context_len / high_freq_factor

# Compute wavelengths corresponding to the inverse frequencies
wavelen = 2 * torch.pi / theta

# Initialize adjusted inverse frequencies
adjusted_theta = theta.clone()

# Low Frequency Region: wavelen > low_freq_wavelen
mask_low_freq = wavelen > low_freq_wavelen
adjusted_theta[mask_low_freq] = theta[mask_low_freq] / factor

# Medium Frequency Region: high_freq_wavelen ≤ wavelen ≤ low_freq_wavelen
mask_medium_freq = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
# Compute smooth factor for medium frequencies
ratio = orig_context_len / wavelen[mask_medium_freq]
smooth_factor = (ratio - low_freq_factor) / (high_freq_factor - low_freq_factor)
# Interpolate inverse frequencies
adjusted_theta[mask_medium_freq] = (
(1 - smooth_factor) * (theta[mask_medium_freq] / factor)
+ smooth_factor * theta[mask_medium_freq]
)
theta = adjusted_theta

# Create position indexes `[0, 1, ..., seq_len - 1]`
seq_idx = torch.arange(seq_len, device=device) / condense_ratio

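To make the three frequency bands concrete, here is a small standalone sketch (illustrative only, not part of the diff) that reproduces the wavelength thresholds with the Llama 3.1 settings from config.py (factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192) and a 128-dimensional head:

import torch

# Illustrative sketch of the band logic in build_rope_cache above.
n_elem, base = 128, 500_000
factor, low_freq_factor, high_freq_factor, orig_context_len = 8.0, 1.0, 4.0, 8192

theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
wavelen = 2 * torch.pi / theta

low_freq_wavelen = orig_context_len / low_freq_factor    # 8192.0
high_freq_wavelen = orig_context_len / high_freq_factor  # 2048.0

num_low = (wavelen > low_freq_wavelen).sum().item()    # pairs divided by `factor`
num_high = (wavelen < high_freq_wavelen).sum().item()  # pairs left unchanged
num_medium = n_elem // 2 - num_low - num_high          # pairs blended via smooth_factor
print(num_low, num_medium, num_high)

The medium band is where the interpolation above blends theta / factor into the unscaled theta, so the rotation frequencies change smoothly instead of jumping at a single cutoff.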
205 changes: 202 additions & 3 deletions tests/test_rope.py
@@ -1,13 +1,17 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

import torch
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXRotaryEmbedding, apply_rotary_pos_emb
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXRotaryEmbedding
from transformers.models.gpt_neox.modeling_gpt_neox import apply_rotary_pos_emb as apply_rotary_pos_emb_gptneo
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb as apply_rotary_pos_emb_llama
from transformers.models.llama.configuration_llama import LlamaConfig

from litgpt.model import apply_rope, build_rope_cache


@torch.inference_mode()
def test_rope():
def test_rope_gptneox():
bs, seq_len, n_head, n_embed = 1, 6, 2, 8
head_size = n_embed // n_head
x = torch.randint(0, 10000, size=(bs, n_head, seq_len, head_size)).float()
@@ -22,5 +26,200 @@ def test_rope():
torch.testing.assert_close(ours_sin_cached, theirs_sin.squeeze())

ours_x_rope = apply_rope(x, ours_cos_cached, ours_sin_cached)
theirs_x_rope, _ = apply_rotary_pos_emb(x, x, theirs_cos, theirs_sin, position_ids)
theirs_x_rope, _ = apply_rotary_pos_emb_gptneo(x, x, theirs_cos, theirs_sin, position_ids)
torch.testing.assert_close(ours_x_rope, theirs_x_rope)


@torch.inference_mode()
def test_rope_llama_2():
head_dim = 64
rope_theta = 10_000

##################################
# Compare cos and sin
##################################
# transformer rope
rot_emb = LlamaRotaryEmbedding(head_dim, scaling_factor=None, base=rope_theta)
batch_size, seq_len = 1, 10
qk_tensor = torch.randn(batch_size, seq_len, head_dim)
position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)

# our rope
ours_cos, ours_sin = build_rope_cache(seq_len, n_elem=head_dim, base=rope_theta)
torch.testing.assert_close(theirs_cos.squeeze(0), ours_cos)
torch.testing.assert_close(theirs_sin.squeeze(0), ours_sin)

##################################
# Compare rotated tensors
##################################
# Settings
num_heads = 4

# Dummy query and key tensors
torch.manual_seed(123)
queries = torch.randn(batch_size, num_heads, seq_len, head_dim)
keys = torch.randn(batch_size, num_heads, seq_len, head_dim)

ours_q_rot = apply_rope(queries, ours_cos, ours_sin)
ours_k_rot = apply_rope(keys, ours_cos, ours_sin)
theirs_q_rot, theirs_k_rot = apply_rotary_pos_emb_llama(queries, keys, theirs_cos, theirs_sin)
torch.testing.assert_close(theirs_q_rot, ours_q_rot)
torch.testing.assert_close(theirs_k_rot, ours_k_rot)


# See https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/config.json for settings
@torch.inference_mode()
def test_rope_llama_3():
head_dim = 64
rope_theta = 50_000

##################################
# Compare cos and sin
##################################
# transformer rope
rot_emb = LlamaRotaryEmbedding(head_dim, scaling_factor=None, base=rope_theta)
batch_size, seq_len = 1, 10
qk_tensor = torch.randn(batch_size, seq_len, head_dim)
position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)

# our rope
ours_cos, ours_sin = build_rope_cache(seq_len, n_elem=head_dim, base=rope_theta)
torch.testing.assert_close(theirs_cos.squeeze(0), ours_cos)
torch.testing.assert_close(theirs_sin.squeeze(0), ours_sin)

##################################
# Compare rotated tensors
##################################
# Settings
num_heads = 4

# Dummy query and key tensors
torch.manual_seed(123)
queries = torch.randn(batch_size, num_heads, seq_len, head_dim)
keys = torch.randn(batch_size, num_heads, seq_len, head_dim)

ours_q_rot = apply_rope(queries, ours_cos, ours_sin)
ours_k_rot = apply_rope(keys, ours_cos, ours_sin)
theirs_q_rot, theirs_k_rot = apply_rotary_pos_emb_llama(queries, keys, theirs_cos, theirs_sin)
torch.testing.assert_close(theirs_q_rot, ours_q_rot)
torch.testing.assert_close(theirs_k_rot, ours_k_rot)


# See https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json for settings
@torch.inference_mode()
def test_rope_llama_3_1():
head_dim = 128
rope_theta = 50_000

their_rope_config = {
"factor": 8.0,
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3"
}

our_rope_config = {
"factor": 8.0,
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"original_max_seq_len": 8192
}

config = LlamaConfig(
rope_theta=rope_theta,
rope_scaling=their_rope_config
)

##################################
# Compare cos and sin
##################################
# transformer rope
rot_emb = LlamaRotaryEmbedding(head_dim, base=rope_theta, config=config, rope_type="llama3")
batch_size, seq_len = 1, 10
qk_tensor = torch.randn(batch_size, seq_len, head_dim)
position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)

# our rope
ours_cos, ours_sin = build_rope_cache(seq_len, n_elem=head_dim, base=rope_theta, extra_config=our_rope_config)
torch.testing.assert_close(theirs_cos.squeeze(0), ours_cos)
torch.testing.assert_close(theirs_sin.squeeze(0), ours_sin)

##################################
# Compare rotated tensors
##################################
# Settings
num_heads = 4

# Dummy query and key tensors
torch.manual_seed(123)
queries = torch.randn(batch_size, num_heads, seq_len, head_dim)
keys = torch.randn(batch_size, num_heads, seq_len, head_dim)

ours_q_rot = apply_rope(queries, ours_cos, ours_sin)
ours_k_rot = apply_rope(keys, ours_cos, ours_sin)
theirs_q_rot, theirs_k_rot = apply_rotary_pos_emb_llama(queries, keys, theirs_cos, theirs_sin)
torch.testing.assert_close(theirs_q_rot, ours_q_rot)
torch.testing.assert_close(theirs_k_rot, ours_k_rot)


# See https://huggingface.co/meta-llama/Llama-3.2-3B/blob/main/config.json for settings
@torch.inference_mode()
def test_rope_llama_3_2():
head_dim = 128
rope_theta = 50_000

their_rope_config = {
"factor": 32.0,
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3"
}

our_rope_config = {
"factor": 32.0,
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"original_max_seq_len": 8192
}

config = LlamaConfig(
rope_theta=rope_theta,
rope_scaling=their_rope_config
)

##################################
# Compare cos and sin
##################################
# transformer rope
rot_emb = LlamaRotaryEmbedding(head_dim, base=rope_theta, config=config, rope_type="llama3")
batch_size, seq_len = 1, 10
qk_tensor = torch.randn(batch_size, seq_len, head_dim)
position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)

# our rope
ours_cos, ours_sin = build_rope_cache(seq_len, n_elem=head_dim, base=rope_theta, extra_config=our_rope_config)
torch.testing.assert_close(theirs_cos.squeeze(0), ours_cos)
torch.testing.assert_close(theirs_sin.squeeze(0), ours_sin)

##################################
# Compare rotated tensors
##################################
# Settings
num_heads = 4

# Dummy query and key tensors
torch.manual_seed(123)
queries = torch.randn(batch_size, num_heads, seq_len, head_dim)
keys = torch.randn(batch_size, num_heads, seq_len, head_dim)

ours_q_rot = apply_rope(queries, ours_cos, ours_sin)
ours_k_rot = apply_rope(keys, ours_cos, ours_sin)
theirs_q_rot, theirs_k_rot = apply_rotary_pos_emb_llama(queries, keys, theirs_cos, theirs_sin)
torch.testing.assert_close(theirs_q_rot, ours_q_rot)
torch.testing.assert_close(theirs_k_rot, ours_k_rot)
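
The tests above exercise the happy path against Hugging Face's implementation; a hedged sketch of a companion test for the all-or-nothing validation added in GPT.rope_cache might look as follows (it assumes, as in the current model code, that the RoPE cache is built when the model is instantiated, and uses the small pythia-14m preset purely for speed):

import pytest
import torch

from litgpt import GPT, Config


@torch.inference_mode()
def test_rope_adjustments_require_all_params():
    # Sketch only: a partial rope_adjustments dict should trigger the
    # ValueError raised in GPT.rope_cache, since the four parameters
    # are only meaningful together.
    config = Config.from_name("pythia-14m", rope_adjustments=dict(factor=8.0))
    with pytest.raises(ValueError, match="adjusted RoPE parameters"):
        GPT(config)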
