Add LoRA-FA support #860

Open · wants to merge 3 commits into main
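The PR carries no description, so here is a short summary with a usage sketch reconstructed from the new integration test further down. LoRA-FA ("frozen-A" LoRA) keeps the standard low-rank adapters but freezes lora_A at its random initialization and trains only lora_B, which reduces adapter optimizer state and activation memory. The sketch reuses the toy model dimensions from the test and assumes, as the test does, that wrap_lora lives in fairseq2.nn.lora; it is an illustration, not part of the diff.

import torch

from fairseq2.data import VocabularyInfo
from fairseq2.models.llama import (
    LLaMAConfig,
    create_llama_model,
    get_llama_lora_fa_config,
)
from fairseq2.nn.lora import freeze_non_lora, wrap_lora

# Toy-sized LLaMA config, copied from the integration test in this PR.
llama_config = LLaMAConfig(
    model_dim=1024,
    max_seq_len=2048,
    vocab_info=VocabularyInfo(size=32000, unk_idx=0, bos_idx=1, eos_idx=2, pad_idx=None),
    num_layers=16,
    num_attn_heads=8,
    num_key_value_heads=8,
    ffn_inner_dim=1024 * 4,
    ffn_inner_dim_to_multiple=1,
    dropout_p=0.1,
)
model = create_llama_model(llama_config, device=torch.device("cpu"))

# Identical to get_llama_lora_config() except that trainable_A is False.
lora_config = get_llama_lora_fa_config()
model = wrap_lora(model, lora_config)
freeze_non_lora(model, unfreeze_bias="none")

# Only the lora_B projections should still require gradients.
trainable = [name for name, param in model.named_parameters() if param.requires_grad]
assert trainable and all("lora_B" in name for name in trainable)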
3 changes: 3 additions & 0 deletions src/fairseq2/models/llama/__init__.py
@@ -14,6 +14,9 @@
 from fairseq2.models.llama.factory import LLaMAConfig as LLaMAConfig
 from fairseq2.models.llama.factory import create_llama_model as create_llama_model
 from fairseq2.models.llama.factory import get_llama_lora_config as get_llama_lora_config
+from fairseq2.models.llama.factory import (
+    get_llama_lora_fa_config as get_llama_lora_fa_config,
+)
 from fairseq2.models.llama.factory import llama_arch as llama_arch
 from fairseq2.models.llama.factory import llama_archs as llama_archs
 from fairseq2.models.llama.loader import load_llama_config as load_llama_config
7 changes: 7 additions & 0 deletions src/fairseq2/models/llama/factory.py
@@ -324,3 +324,10 @@ def get_llama_lora_config() -> LoRAConfig:
         dropout_p=0.05,
         keys=[".*decoder.layers.*.self_attn.*(q_proj|v_proj)$"],
     )
+
+
+def get_llama_lora_fa_config() -> LoRAConfig:
+    config = get_llama_lora_config()
+    config.trainable_A = False
+
+    return config
13 changes: 8 additions & 5 deletions src/fairseq2/nn/lora.py
@@ -30,6 +30,7 @@ class LoRAConfig:
     alpha: float
     dropout_p: float
     keys: list[str]
+    trainable_A: bool = True


 class LoRALayer(ABC):
@@ -112,8 +113,8 @@ def forward(self, x: Tensor) -> Tensor:

     def reset_lora_parameters(self) -> None:
         """Reset the parameters and buffers of the module."""
-        nn.init.zeros_(self.lora_A)
-        nn.init.normal_(self.lora_B)
+        nn.init.normal_(self.lora_A)
+        nn.init.zeros_(self.lora_B)

     @property
     def wrapped_module(self) -> Embedding:
@@ -182,11 +183,13 @@ def __init__(
             self.register_parameter("bias", None)

         self.lora_A = Parameter(
-            torch.empty((self.r, self.input_dim), device=device, dtype=dtype)
+            torch.empty((self.r, self.input_dim), device=device, dtype=dtype),
+            requires_grad=config.trainable_A,
         )

         self.lora_B = Parameter(
-            torch.empty((self.output_dim, self.r), device=device, dtype=dtype)
+            torch.empty((self.output_dim, self.r), device=device, dtype=dtype),
+            requires_grad=True,
         )

         if self.dropout_p > 0.0:
@@ -346,7 +349,7 @@ def freeze_non_lora(
         if isinstance(submodule, LoRALayer):
             for param_name, param in submodule.named_parameters(recurse=False):
                 if param_name in ["lora_A", "lora_B"]:
-                    param.requires_grad = True
+                    continue
                 elif param_name == "bias" and unfreeze_bias in ["all", "lora_only"]:
                     param.requires_grad = True
                 else:
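A note on the one-line freeze_non_lora change above: the old code forced requires_grad = True on both lora_A and lora_B, which would silently re-enable the lora_A weights that LoRA-FA constructs with requires_grad=False; skipping them with continue preserves whatever flag was set at construction time. A minimal sketch of the configuration side, assuming the pre-existing r field of LoRAConfig and purely illustrative r/alpha values:

from fairseq2.nn.lora import LoRAConfig

# LoRA-FA-style config: identical to a plain LoRA config except trainable_A=False.
lora_fa_config = LoRAConfig(
    r=8,
    alpha=16.0,
    dropout_p=0.05,
    keys=[".*decoder.layers.*.self_attn.*(q_proj|v_proj)$"],
    trainable_A=False,
)

# A LoRALinear built from this config starts with lora_A.requires_grad == False
# and lora_B.requires_grad == True; with the `continue` above, freeze_non_lora()
# leaves both flags untouched instead of forcing them back to True.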
34 changes: 31 additions & 3 deletions tests/integration/models/test_llama_lora.py
@@ -7,7 +7,12 @@
 import torch

 from fairseq2.data import VocabularyInfo
-from fairseq2.models.llama import LLaMAConfig, create_llama_model, get_llama_lora_config
+from fairseq2.models.llama import (
+    LLaMAConfig,
+    create_llama_model,
+    get_llama_lora_config,
+    get_llama_lora_fa_config,
+)
 from fairseq2.nn.lora import (
     freeze_non_lora,
     merge_lora,
@@ -77,5 +82,28 @@ def test_lora_wrappers_llama_works() -> None:
     freeze_non_lora(model, unfreeze_bias="none")

     for name, param in model.named_parameters():
-        if param.requires_grad:
-            assert "lora_" in name
+        assert param.requires_grad == ("lora_" in name)
+
+
+def test_lora_fa_freezes_llama_properly() -> None:
+    llama_config = LLaMAConfig(
+        model_dim=1024,
+        max_seq_len=2048,
+        vocab_info=VocabularyInfo(
+            size=32000, unk_idx=0, bos_idx=1, eos_idx=2, pad_idx=None
+        ),
+        num_layers=16,
+        num_attn_heads=8,
+        num_key_value_heads=8,
+        ffn_inner_dim=1024 * 4,
+        ffn_inner_dim_to_multiple=1,
+        dropout_p=0.1,
+    )
+    model = create_llama_model(llama_config, device=torch.device("cpu"))
+    lora_config = get_llama_lora_fa_config()
+
+    model = wrap_lora(model, lora_config)  # type: ignore[assignment]
+    freeze_non_lora(model, unfreeze_bias="none")
+
+    for name, param in model.named_parameters():
+        assert param.requires_grad == ("lora_B" in name)