feat: modify rope for llama-3 and support llama-3.2 #131

Merged · 7 commits, Oct 10, 2024 (diff shown from 6 commits)
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -16,7 +16,7 @@ dependencies = [
"typing-extensions",
"torch>=2",
"transformers>=4",
"litgpt[all]==0.4.10",
"litgpt[all]==0.5.0",
"syne-tune[moo]>=0.13",
"torchvision>=0.18",
]
18 changes: 18 additions & 0 deletions test/test_model.py
@@ -209,6 +209,24 @@ def test_llama_3_1():
assert torch.allclose(whittle_out, lit_out, atol=1e-3)


def test_llama_3_2():
    config_llama = Config.from_name(
        "Llama-3.2-1B",
        n_layer=2,
        n_embd=32,
        intermediate_size=86,
        padded_vocab_size=10000,
    )
    config_llama.fix_head_size = True
    lit_model = LitGPT(config_llama)
    whittle_model = GPT(config_llama)
    copy_weights(lit_model, whittle_model)
    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32)
    whittle_out = whittle_model(x)
    lit_out = lit_model(x)
    assert torch.allclose(whittle_out, lit_out, atol=1e-3)


def test_gemma_2():
config_gemma = Config.from_name(
"gemma-2-9b",
91 changes: 73 additions & 18 deletions whittle/models/gpt/model.py
@@ -14,15 +14,15 @@
import torch.nn as nn
from litgpt import Config
from litgpt.model import build_rope_cache

from litgpt.model import batched_index_select
from whittle.models.gpt.blocks import Block
from whittle.modules.embedding import Embedding
from whittle.modules.layernorm import LayerNorm
from whittle.modules.linear import Linear
from whittle.modules.rmsnorm import RMSNorm


class GPT(nn.Module):
class GPT(torch.nn.Module):
Collaborator: we import torch.nn above, so we can use nn.Module here

Collaborator (Author): Fixed this

def __init__(self, config: Config) -> None:
super().__init__()
assert config.padded_vocab_size is not None
@@ -82,18 +82,26 @@ def max_seq_length(self, value: int) -> None:
self._max_seq_length = value
if not hasattr(self, "cos"):
# first call
cos, sin = self.rope_cache()
cos, sin = self.rope_cache(self._max_seq_length, self.config.rope_n_elem)
self.register_buffer("cos", cos, persistent=False)
self.register_buffer("sin", sin, persistent=False)
# override
elif value != self.cos.size(0):
self.cos, self.sin = self.rope_cache(device=self.cos.device)
self.cos, self.sin = self.rope_cache(
seq_len=self._max_seq_length,
n_elem=self.config.rope_n_elem,
device=self.cos.device,
)
# the mask and kv cache size will get updated on `set_kv_cache`. we cannot update it here because we don't know
# if the kv cache is expected

def reset_parameters(self) -> None:
# Trigger resetting the rope-cache
self.cos, self.sin = self.rope_cache(device=self.cos.device)
self.cos, self.sin = self.rope_cache(
seq_len=self._max_seq_length,
n_elem=self.config.rope_n_elem,
device=self.cos.device,
)

def _init_weights(self, module: nn.Module) -> None:
"""Meant to be used with `gpt.apply(gpt._init_weights)`."""
@@ -109,14 +117,57 @@ def tie_weights(self) -> None:
self.transformer.wte.weight = self.lm_head.weight

def rope_cache(
self, device: torch.device | None = None
self, seq_len: int, n_elem: int, device: torch.device | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
if self.config.rope_adjustments is None:
extra_config = None

else:
adjusted_params_required = [
"factor",
"low_freq_factor",
"high_freq_factor",
"original_max_seq_len",
]
params_present = [
param in self.config.rope_adjustments
for param in adjusted_params_required
]
num_params_present = sum(params_present)

if num_params_present == 0:
extra_config = None # uses standard RoPE
elif num_params_present == 4:
# These parameters should always be used together so that we don't interfere with standard rope
extra_config = {
"original_max_seq_len": self.config.rope_adjustments[
"original_max_seq_len"
],
"factor": self.config.rope_adjustments["factor"],
"low_freq_factor": self.config.rope_adjustments["low_freq_factor"],
"high_freq_factor": self.config.rope_adjustments[
"high_freq_factor"
],
}
else:
# Some but not all parameters are specified; raise an error
missing_params = [
param
for param, present in zip(adjusted_params_required, params_present)
if not present
]
raise ValueError(
f"The following adjusted RoPE parameters are missing in rope_adjustments: {', '.join(missing_params)}. "
"All adjusted RoPE parameters must be specified together."
)

return build_rope_cache(
seq_len=self.max_seq_length,
n_elem=self.config.rope_n_elem,
seq_len=seq_len,
n_elem=n_elem,
device=device,
condense_ratio=self.config.rope_condense_ratio,
base=self.config.rope_base,
extra_config=extra_config,
)

def set_sub_network(
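For background, the four `rope_adjustments` entries collected into `extra_config` parameterize the Llama-3.1/3.2-style rescaling of the RoPE inverse frequencies. Below is a minimal sketch of that adjustment following the reference Llama 3.1 recipe; litgpt's `build_rope_cache` applies it internally, so this is only an illustration, and the default values are merely typical examples rather than anything read from a checkpoint:

```python
import math

import torch


def llama3_scale_inv_freq(
    inv_freq: torch.Tensor,
    factor: float = 32.0,              # illustrative; e.g. 8.0 for Llama 3.1, 32.0 for Llama 3.2 1B/3B
    low_freq_factor: float = 1.0,
    high_freq_factor: float = 4.0,
    original_max_seq_len: int = 8192,
) -> torch.Tensor:
    """Rescale RoPE inverse frequencies in the Llama 3.1/3.2 style (sketch)."""
    low_freq_wavelen = original_max_seq_len / low_freq_factor
    high_freq_wavelen = original_max_seq_len / high_freq_factor
    wavelen = 2 * math.pi / inv_freq

    # Low-frequency components (long wavelengths) are slowed down by `factor`;
    # high-frequency components (short wavelengths) are left untouched.
    scaled = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)

    # Components between the two wavelength thresholds are smoothly interpolated.
    smooth = (original_max_seq_len / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    smoothed = (1 - smooth) * inv_freq / factor + smooth * inv_freq
    in_between = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
    return torch.where(in_between, smoothed, scaled)
```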
@@ -152,8 +203,8 @@ def set_sub_network(
block = self.transformer.h[j]
block.set_sub_network(
sub_network_n_embd,
sub_network_intermediate_size[i],
sub_network_num_heads[i],
sub_network_intermediate_size[j],
sub_network_num_heads[j],
sub_network_query_groups,
sub_network_head_size,
sample_random_indices,
@@ -185,11 +236,15 @@ def reset_super_network(self):

def process_rope_cache(self, cos, sin, input_pos, T):
if input_pos is not None: # use the kv cache
cos = cos.index_select(0, input_pos)
sin = sin.index_select(0, input_pos)
cos = batched_index_select(self.cos, 0, input_pos)
sin = batched_index_select(self.sin, 0, input_pos)
if self.mask_cache is None:
raise TypeError("You need to call `gpt.set_kv_cache()`")
mask = self.mask_cache.index_select(2, input_pos)
mask = batched_index_select(self.mask_cache, 2, input_pos)
if mask.dim() > 4:
# the mask cache has a batch dim of 1 in addition to the one
# we get if input_pos has a batch dimension
mask = mask.squeeze(1)
else:
cos = cos[:T]
sin = sin[:T]
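The switch from `Tensor.index_select` to litgpt's `batched_index_select` lets `input_pos` (and therefore the RoPE and mask lookups above) carry a leading batch dimension, which is also why the extra singleton mask dimension is squeezed away. A toy illustration of per-sample selection, assuming that is the behavior `batched_index_select` provides (sketch only, not the litgpt implementation):

```python
import torch

cos = torch.arange(10.0).unsqueeze(-1).repeat(1, 4)  # toy RoPE cache, shape (max_seq_len=10, n_elem=4)
input_pos = torch.tensor([[3], [7]])                 # shape (batch=2, 1): each sample at a different position

# A plain index_select only accepts a 1-D index, so the per-sample grouping is lost:
flat = cos.index_select(0, input_pos.view(-1))       # shape (2, 4)

# Per-sample ("batched") selection keeps one index vector per sample:
batched = torch.stack([cos.index_select(0, idx) for idx in input_pos])  # shape (2, 1, 4)

print(flat.shape, batched.shape)  # torch.Size([2, 4]) torch.Size([2, 1, 4])
```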
@@ -212,16 +267,16 @@ def forward(
block = self.transformer.h[j]
if not self.config.fix_head_size:
if isinstance(self.sub_network_num_heads, list):
cos, sin = build_rope_cache(
cos, sin = self.rope_cache(
seq_len=self.max_seq_length,
n_elem=int(
self.config.rotary_percentage
* (self.sub_network_n_embd // self.sub_network_num_heads[i])
* (self.sub_network_n_embd // self.sub_network_num_heads[j])
),
device=self.device,
)
else:
cos, sin = build_rope_cache(
cos, sin = self.rope_cache(
seq_len=self.max_seq_length,
n_elem=int(
self.config.rotary_percentage
@@ -231,15 +286,15 @@
)
else:
if self.sub_network_head_size is None:
cos, sin = build_rope_cache(
cos, sin = self.rope_cache(
seq_len=self.max_seq_length,
n_elem=int(
self.config.rotary_percentage * (self.config.head_size)
),
device=self.device,
)
else:
cos, sin = build_rope_cache(
cos, sin = self.rope_cache(
seq_len=self.max_seq_length,
n_elem=int(
self.config.rotary_percentage * (self.sub_network_head_size)
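For reference, the `n_elem` handed to `rope_cache` in these branches is simply the number of channels per head that RoPE rotates: `rotary_percentage` times the (sub-network) head size. A tiny worked example with made-up sub-network sizes; the `rotary_percentage` of 1.0 is an assumption, so use whatever the actual config specifies:

```python
rotary_percentage = 1.0                                   # assumed; fraction of each head that RoPE rotates
sub_network_n_embd = 256                                  # illustrative sub-network width
sub_network_num_heads = 4                                 # illustrative head count
head_size = sub_network_n_embd // sub_network_num_heads   # 64 channels per head
n_elem = int(rotary_percentage * head_size)               # 64 channels rotated by RoPE
```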