diff --git a/litgpt/config.py b/litgpt/config.py
index e047252d74..4ecbccca6c 100644
--- a/litgpt/config.py
+++ b/litgpt/config.py
@@ -62,6 +62,7 @@ class Config:
     intermediate_size: Optional[int] = None
     rope_condense_ratio: int = 1
     rope_base: int = 10000
+    rope_adjustments: Optional[dict] = None
     n_expert: int = 0
     n_expert_per_token: int = 0
     attention_logit_softcapping: Optional[float] = None
@@ -893,6 +894,7 @@ def norm_class(self) -> Type:
         mlp_class_name="LLaMAMLP",
         intermediate_size=14336,
         rope_base=500000,
+        rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192)
     ),
     # https://huggingface.co/meta-llama/Meta-Llama-3-70B/blob/main/config.json
     dict(
@@ -931,6 +933,7 @@ def norm_class(self) -> Type:
         mlp_class_name="LLaMAMLP",
         intermediate_size=28672,
         rope_base=500000,
+        rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192)
     ),
     # https://huggingface.co/meta-llama/Meta-Llama-3.1-405B/blob/main/config.json
     dict(
@@ -950,8 +953,9 @@ def norm_class(self) -> Type:
         mlp_class_name="LLaMAMLP",
         intermediate_size=53248,
         rope_base=500000,
+        rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192)
     ),
-    # https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json
+    # https://huggingface.co/meta-llama/Llama-3.2-1B/blob/main/config.json
     dict(
         name="Llama-3.2-1B{}",
         hf_config=dict(org="meta-llama", name="Llama-3.2-1B{}"),
@@ -969,6 +973,7 @@ def norm_class(self) -> Type:
         mlp_class_name="LLaMAMLP",
         intermediate_size=8192,
         rope_base=500000,
+        rope_adjustments=dict(factor=32.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192)
     ),
     # https://huggingface.co/meta-llama/Llama-3.2-3B/blob/main/config.json
     dict(
@@ -988,6 +993,7 @@ def norm_class(self) -> Type:
         mlp_class_name="LLaMAMLP",
         intermediate_size=8192,
         rope_base=500000,
+        rope_adjustments=dict(factor=32.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192)
     ),
 ]
 for c in llama_3:
diff --git a/litgpt/model.py b/litgpt/model.py
index 3768b58708..4f85da805a 100644
--- a/litgpt/model.py
+++ b/litgpt/model.py
@@ -108,12 +108,39 @@ def from_name(cls, name: str, **kwargs: Any) -> Self:
         return cls(Config.from_name(name, **kwargs))
 
     def rope_cache(self, device: Optional[torch.device] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+
+        if self.config.rope_adjustments is None:
+            extra_config = None
+
+        else:
+            adjusted_params_required = ["factor", "low_freq_factor", "high_freq_factor", "original_max_seq_len"]
+            params_present = [param in self.config.rope_adjustments for param in adjusted_params_required]
+            num_params_present = sum(params_present)
+
+            if num_params_present == 0:
+                extra_config = None  # uses standard RoPE
+            elif num_params_present == 4:
+                # These parameters should always be used together so that we don't interfere with standard rope
+                extra_config = {
+                    "original_max_seq_len": self.config.rope_adjustments["original_max_seq_len"],
+                    "factor": self.config.rope_adjustments["factor"],
+                    "low_freq_factor": self.config.rope_adjustments["low_freq_factor"],
+                    "high_freq_factor": self.config.rope_adjustments["high_freq_factor"],
+                }
+            else:
+                # Some but not all parameters are specified; raise an error listing the missing ones
+                missing_params = [param for param, present in zip(adjusted_params_required, params_present) if not present]
+                raise ValueError(
+                    f"The following adjusted RoPE parameters are missing in rope_adjustments: {', '.join(missing_params)}. "
+                    "All adjusted RoPE parameters must be specified together."
+                )
+
         return build_rope_cache(
             seq_len=self.max_seq_length,
             n_elem=self.config.rope_n_elem,
             device=device,
             condense_ratio=self.config.rope_condense_ratio,
             base=self.config.rope_base,
+            extra_config=extra_config,
         )
 
     def set_kv_cache(
@@ -410,17 +437,67 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 def build_rope_cache(
-    seq_len: int, n_elem: int, device: Optional[torch.device] = None, base: int = 10000, condense_ratio: int = 1
+    seq_len: int,
+    n_elem: int,
+    device: Optional[torch.device] = None,
+    base: int = 10000,
+    condense_ratio: int = 1,
+    extra_config: Optional[dict] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Enhanced Transformer with Rotary Position Embedding.
 
     Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
    transformers/rope/__init__.py. MIT License:
    https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
+
+    Args:
+        seq_len (int): Sequence length.
+        n_elem (int): Number of elements (head dimension).
+        device (torch.device, optional): Device for tensor allocations.
+        base (int, optional): Base for computing inverse frequencies.
+        condense_ratio (int, optional): Ratio to condense the position indices.
+        extra_config (dict, optional): Configuration parameters for frequency adjustments (used by Llama 3.1 and 3.2)
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: Cosine and sine caches for RoPE.
     """
-    # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
+
+    # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
     assert n_elem % 2 == 0, "n_elem (head dimension) must be even"
     theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))
 
+    if extra_config is not None:
+        # Extract configuration parameters
+        orig_context_len = extra_config["original_max_seq_len"]
+        factor = extra_config["factor"]
+        low_freq_factor = extra_config["low_freq_factor"]
+        high_freq_factor = extra_config["high_freq_factor"]
+
+        # Compute wavelength thresholds
+        low_freq_wavelen = orig_context_len / low_freq_factor
+        high_freq_wavelen = orig_context_len / high_freq_factor
+
+        # Compute wavelengths corresponding to the inverse frequencies
+        wavelen = 2 * torch.pi / theta
+
+        # Initialize adjusted inverse frequencies
+        adjusted_theta = theta.clone()
+
+        # Low Frequency Region: wavelen > low_freq_wavelen
+        mask_low_freq = wavelen > low_freq_wavelen
+        adjusted_theta[mask_low_freq] = theta[mask_low_freq] / factor
+
+        # Medium Frequency Region: high_freq_wavelen ≤ wavelen ≤ low_freq_wavelen
+        mask_medium_freq = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
+        # Compute smooth factor for medium frequencies
+        ratio = orig_context_len / wavelen[mask_medium_freq]
+        smooth_factor = (ratio - low_freq_factor) / (high_freq_factor - low_freq_factor)
+        # Interpolate inverse frequencies
+        adjusted_theta[mask_medium_freq] = (
+            (1 - smooth_factor) * (theta[mask_medium_freq] / factor)
+            + smooth_factor * theta[mask_medium_freq]
+        )
+        theta = adjusted_theta
+
     # Create position indexes `[0, 1, ..., seq_len - 1]`
     seq_idx = torch.arange(seq_len, device=device) / condense_ratio
diff --git a/tests/test_rope.py b/tests/test_rope.py
index 6bed37dc21..14ea33c0aa 100644
--- a/tests/test_rope.py
+++ b/tests/test_rope.py
@@ -1,13 +1,17 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
 
 import torch
-from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXRotaryEmbedding, apply_rotary_pos_emb
+from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXRotaryEmbedding
+from transformers.models.gpt_neox.modeling_gpt_neox import apply_rotary_pos_emb as apply_rotary_pos_emb_gptneo
+from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb as apply_rotary_pos_emb_llama
+from transformers.models.llama.configuration_llama import LlamaConfig
 
 from litgpt.model import apply_rope, build_rope_cache
 
 
 @torch.inference_mode()
-def test_rope():
+def test_rope_gptneox():
     bs, seq_len, n_head, n_embed = 1, 6, 2, 8
     head_size = n_embed // n_head
     x = torch.randint(0, 10000, size=(bs, n_head, seq_len, head_size)).float()
@@ -22,5 +26,200 @@ def test_rope():
     torch.testing.assert_close(ours_sin_cached, theirs_sin.squeeze())
 
     ours_x_rope = apply_rope(x, ours_cos_cached, ours_sin_cached)
-    theirs_x_rope, _ = apply_rotary_pos_emb(x, x, theirs_cos, theirs_sin, position_ids)
+    theirs_x_rope, _ = apply_rotary_pos_emb_gptneo(x, x, theirs_cos, theirs_sin, position_ids)
     torch.testing.assert_close(ours_x_rope, theirs_x_rope)
+
+
+@torch.inference_mode()
+def test_rope_llama_2():
+    head_dim = 64
+    rope_theta = 10_000
+
+    ##################################
+    # Compare cos and sin
+    ##################################
+    # transformer rope
+    rot_emb = LlamaRotaryEmbedding(head_dim, scaling_factor=None, base=rope_theta)
+    batch_size, seq_len = 1, 10
+    qk_tensor = torch.randn(batch_size, seq_len, head_dim)
+    position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
+    theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)
+
+    # our rope
+    ours_cos, ours_sin = build_rope_cache(seq_len, n_elem=head_dim, base=rope_theta)
+    torch.testing.assert_close(theirs_cos.squeeze(0), ours_cos)
+    torch.testing.assert_close(theirs_sin.squeeze(0), ours_sin)
+
+    ##################################
+    # Compare rotated tensors
+    ##################################
+    # Settings
+    num_heads = 4
+
+    # Dummy query and key tensors
+    torch.manual_seed(123)
+    queries = torch.randn(batch_size, num_heads, seq_len, head_dim)
+    keys = torch.randn(batch_size, num_heads, seq_len, head_dim)
+
+    ours_q_rot = apply_rope(queries, ours_cos, ours_sin)
+    ours_k_rot = apply_rope(keys, ours_cos, ours_sin)
+    theirs_q_rot, theirs_k_rot = apply_rotary_pos_emb_llama(queries, keys, theirs_cos, theirs_sin)
+    torch.testing.assert_close(theirs_q_rot, ours_q_rot)
+    torch.testing.assert_close(theirs_k_rot, ours_k_rot)
+
+
+# See https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/config.json for settings
+@torch.inference_mode()
+def test_rope_llama_3():
+    head_dim = 64
+    rope_theta = 50_000
+
+    ##################################
+    # Compare cos and sin
+    ##################################
+    # transformer rope
+    rot_emb = LlamaRotaryEmbedding(head_dim, scaling_factor=None, base=rope_theta)
+    batch_size, seq_len = 1, 10
+    qk_tensor = torch.randn(batch_size, seq_len, head_dim)
+    position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
+    theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)
+
+    # our rope
+    ours_cos, ours_sin = build_rope_cache(seq_len, n_elem=head_dim, base=rope_theta)
+    torch.testing.assert_close(theirs_cos.squeeze(0), ours_cos)
+    torch.testing.assert_close(theirs_sin.squeeze(0), ours_sin)
+
+    ##################################
+    # Compare rotated tensors
+    ##################################
+    # Settings
+    num_heads = 4
+
+    # Dummy query and key tensors
+    torch.manual_seed(123)
+    queries = torch.randn(batch_size, num_heads, seq_len, head_dim)
+    keys = torch.randn(batch_size, num_heads, seq_len, head_dim)
+
+    ours_q_rot = apply_rope(queries, ours_cos, ours_sin)
+    ours_k_rot = apply_rope(keys, ours_cos, ours_sin)
+    theirs_q_rot, theirs_k_rot = apply_rotary_pos_emb_llama(queries, keys, theirs_cos, theirs_sin)
+    torch.testing.assert_close(theirs_q_rot, ours_q_rot)
+    torch.testing.assert_close(theirs_k_rot, ours_k_rot)
+
+
+# See https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json for settings
+@torch.inference_mode()
+def test_rope_llama_3_1():
+    head_dim = 128
+    rope_theta = 50_000
+
+    their_rope_config = {
+        "factor": 8.0,
+        "low_freq_factor": 1.0,
+        "high_freq_factor": 4.0,
+        "original_max_position_embeddings": 8192,
+        "rope_type": "llama3"
+    }
+
+    our_rope_config = {
+        "factor": 8.0,
+        "low_freq_factor": 1.0,
+        "high_freq_factor": 4.0,
+        "original_max_seq_len": 8192
+    }
+
+    config = LlamaConfig(
+        rope_theta=rope_theta,
+        rope_scaling=their_rope_config
+    )
+
+    ##################################
+    # Compare cos and sin
+    ##################################
+    # transformer rope
+    rot_emb = LlamaRotaryEmbedding(head_dim, base=rope_theta, config=config, rope_type="llama3")
+    batch_size, seq_len = 1, 10
+    qk_tensor = torch.randn(batch_size, seq_len, head_dim)
+    position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
+    theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)
+
+    # our rope
+    ours_cos, ours_sin = build_rope_cache(seq_len, n_elem=head_dim, base=rope_theta, extra_config=our_rope_config)
+    torch.testing.assert_close(theirs_cos.squeeze(0), ours_cos)
+    torch.testing.assert_close(theirs_sin.squeeze(0), ours_sin)
+
+    ##################################
+    # Compare rotated tensors
+    ##################################
+    # Settings
+    num_heads = 4
+
+    # Dummy query and key tensors
+    torch.manual_seed(123)
+    queries = torch.randn(batch_size, num_heads, seq_len, head_dim)
+    keys = torch.randn(batch_size, num_heads, seq_len, head_dim)
+
+    ours_q_rot = apply_rope(queries, ours_cos, ours_sin)
+    ours_k_rot = apply_rope(keys, ours_cos, ours_sin)
+    theirs_q_rot, theirs_k_rot = apply_rotary_pos_emb_llama(queries, keys, theirs_cos, theirs_sin)
+    torch.testing.assert_close(theirs_q_rot, ours_q_rot)
+    torch.testing.assert_close(theirs_k_rot, ours_k_rot)
+
+
+# See https://huggingface.co/meta-llama/Llama-3.2-3B/blob/main/config.json for settings
+@torch.inference_mode()
+def test_rope_llama_3_2():
+    head_dim = 128
+    rope_theta = 50_000
+
+    their_rope_config = {
+        "factor": 32.0,
+        "low_freq_factor": 1.0,
+        "high_freq_factor": 4.0,
+        "original_max_position_embeddings": 8192,
+        "rope_type": "llama3"
+    }
+
+    our_rope_config = {
+        "factor": 32.0,
+        "low_freq_factor": 1.0,
+        "high_freq_factor": 4.0,
+        "original_max_seq_len": 8192
+    }
+
+    config = LlamaConfig(
+        rope_theta=rope_theta,
+        rope_scaling=their_rope_config
+    )
+
+    ##################################
+    # Compare cos and sin
+    ##################################
+    # transformer rope
+    rot_emb = LlamaRotaryEmbedding(head_dim, base=rope_theta, config=config, rope_type="llama3")
+    batch_size, seq_len = 1, 10
+    qk_tensor = torch.randn(batch_size, seq_len, head_dim)
+    position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
+    theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)
+
+    # our rope
+    ours_cos, ours_sin = build_rope_cache(seq_len, n_elem=head_dim, base=rope_theta, extra_config=our_rope_config)
+    torch.testing.assert_close(theirs_cos.squeeze(0), ours_cos)
+    torch.testing.assert_close(theirs_sin.squeeze(0), ours_sin)
+
+    ##################################
+    # Compare rotated tensors
+    ##################################
+    # Settings
+    num_heads = 4
+
+    # Dummy query and key tensors
+    torch.manual_seed(123)
+    queries = torch.randn(batch_size, num_heads, seq_len, head_dim)
+    keys = torch.randn(batch_size, num_heads, seq_len, head_dim)
+
+    ours_q_rot = apply_rope(queries, ours_cos, ours_sin)
+    ours_k_rot = apply_rope(keys, ours_cos, ours_sin)
+    theirs_q_rot, theirs_k_rot = apply_rotary_pos_emb_llama(queries, keys, theirs_cos, theirs_sin)
+    torch.testing.assert_close(theirs_q_rot, ours_q_rot)
+    torch.testing.assert_close(theirs_k_rot, ours_k_rot)
\ No newline at end of file
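
Note (not part of the patch): the wavelength partition implemented by the new `extra_config` branch of `build_rope_cache` can be sanity-checked in isolation. The sketch below mirrors that branch with plain PyTorch, using illustrative Llama 3.1-style values (head size 128, base 500000, and the factor/low_freq_factor/high_freq_factor/original_max_seq_len added to the configs above): rotations with wavelengths longer than original_max_seq_len / low_freq_factor are slowed down by factor, those shorter than original_max_seq_len / high_freq_factor are left untouched, and the band in between is linearly interpolated.

import torch

# Illustrative values only; they mirror the Llama 3.1-style rope_adjustments added in config.py above.
n_elem, base = 128, 500000
factor, low_freq_factor, high_freq_factor, orig_context_len = 8.0, 1.0, 4.0, 8192

# Same inverse-frequency schedule as build_rope_cache
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
wavelen = 2 * torch.pi / theta
low_freq_wavelen = orig_context_len / low_freq_factor    # 8192: longer wavelengths are fully rescaled
high_freq_wavelen = orig_context_len / high_freq_factor  # 2048: shorter wavelengths stay unchanged

adjusted = theta.clone()
adjusted[wavelen > low_freq_wavelen] = theta[wavelen > low_freq_wavelen] / factor
mid = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
smooth = (orig_context_len / wavelen[mid] - low_freq_factor) / (high_freq_factor - low_freq_factor)
adjusted[mid] = (1 - smooth) * (theta[mid] / factor) + smooth * theta[mid]

# High-frequency rotations are untouched; the longest-wavelength ones are slowed down by `factor`.
assert torch.equal(adjusted[wavelen < high_freq_wavelen], theta[wavelen < high_freq_wavelen])
assert torch.allclose(adjusted[wavelen > low_freq_wavelen], theta[wavelen > low_freq_wavelen] / factor)

The asserts only pin down the two extreme regions; the medium band interpolates smoothly between them, which is what the mask_medium_freq branch in the patch computes.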