Merge branch 'main' into llama3.2-tokenizer
rasbt authored Oct 4, 2024
2 parents 2d6865f + ae798ab commit 4b5fb99
Showing 6 changed files with 28 additions and 3 deletions.
8 changes: 7 additions & 1 deletion litgpt/finetune/adapter.py
@@ -172,9 +172,15 @@ def main(
fabric.print(f"Number of non-trainable parameters: {num_parameters(model, requires_grad=False):,}")

model = fabric.setup_module(model)

if isinstance(fabric.strategy.precision, BitsandbytesPrecision):
optimizer = instantiate_bnb_optimizer(optimizer, model.parameters())

from bitsandbytes.nn import StableEmbedding
old_embedding = model.transformer.wte
model.transformer.wte = StableEmbedding(old_embedding.num_embeddings, old_embedding.embedding_dim)
with torch.no_grad():
model.transformer.wte.weight.copy_(old_embedding.weight)
model.transformer.wte = model.transformer.wte.to(device=old_embedding.weight.device, dtype=old_embedding.weight.dtype)
else:
optimizer = instantiate_torch_optimizer(optimizer, model.parameters())

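The same change appears in all three finetuning scripts in this commit. It follows the pattern bitsandbytes recommends for its 8-bit and paged optimizers: replace the model's regular nn.Embedding with bitsandbytes.nn.StableEmbedding and copy the trained weights across. Below is a minimal standalone sketch of that swap, assuming bitsandbytes is installed; the layer sizes are illustrative and not taken from the commit.

import torch
from bitsandbytes.nn import StableEmbedding

# Illustrative stand-in for model.transformer.wte; the sizes are arbitrary.
old_embedding = torch.nn.Embedding(num_embeddings=64, embedding_dim=16)

# Recreate the layer as a StableEmbedding of the same shape, copy the weights,
# and match the original device and dtype, mirroring the diff above.
new_embedding = StableEmbedding(old_embedding.num_embeddings, old_embedding.embedding_dim)
with torch.no_grad():
    new_embedding.weight.copy_(old_embedding.weight)
new_embedding = new_embedding.to(device=old_embedding.weight.device, dtype=old_embedding.weight.dtype)

tokens = torch.randint(0, 64, (2, 8))
assert new_embedding(tokens).shape == (2, 8, 16)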
8 changes: 7 additions & 1 deletion litgpt/finetune/adapter_v2.py
@@ -172,9 +172,15 @@ def main(
fabric.print(f"Number of non-trainable parameters: {num_parameters(model, requires_grad=False):,}")

model = fabric.setup_module(model)

if isinstance(fabric.strategy.precision, BitsandbytesPrecision):
optimizer = instantiate_bnb_optimizer(optimizer, model.parameters())

from bitsandbytes.nn import StableEmbedding
old_embedding = model.transformer.wte
model.transformer.wte = StableEmbedding(old_embedding.num_embeddings, old_embedding.embedding_dim)
with torch.no_grad():
model.transformer.wte.weight.copy_(old_embedding.weight)
model.transformer.wte = model.transformer.wte.to(device=old_embedding.weight.device, dtype=old_embedding.weight.dtype)
else:
optimizer = instantiate_torch_optimizer(optimizer, model.parameters())

8 changes: 7 additions & 1 deletion litgpt/finetune/lora.py
@@ -202,9 +202,15 @@ def main(
fabric.print(f"Number of non-trainable parameters: {num_parameters(model, requires_grad=False):,}")

model = fabric.setup_module(model)

if isinstance(fabric.strategy.precision, BitsandbytesPrecision):
optimizer = instantiate_bnb_optimizer(optimizer, model.parameters())

from bitsandbytes.nn import StableEmbedding
old_embedding = model.transformer.wte
model.transformer.wte = StableEmbedding(old_embedding.num_embeddings, old_embedding.embedding_dim)
with torch.no_grad():
model.transformer.wte.weight.copy_(old_embedding.weight)
model.transformer.wte = model.transformer.wte.to(device=old_embedding.weight.device, dtype=old_embedding.weight.dtype)
else:
optimizer = instantiate_torch_optimizer(optimizer, model.parameters())

2 changes: 2 additions & 0 deletions tests/test_adapter.py
@@ -182,6 +182,8 @@ def test_adapter_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca
    assert dtype_to_name == {
        "torch.float16": {
            "transformer.wte.weight",
            "transformer.wte.norm.weight",
            "transformer.wte.norm.bias",
            "transformer.h.0.norm_1.weight",
            "transformer.h.0.norm_1.bias",
            "transformer.h.0.attn.gating_factor",
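The test updates in this and the following files track that swap: StableEmbedding applies a LayerNorm after the lookup, so once transformer.wte is replaced its state dict contains norm.weight and norm.bias entries next to weight, which is what the new expected keys assert. A quick way to confirm, assuming bitsandbytes is installed:

from bitsandbytes.nn import StableEmbedding

# StableEmbedding bundles a LayerNorm, so it carries two extra parameters
# besides the embedding matrix itself.
emb = StableEmbedding(64, 16)
print(sorted(emb.state_dict().keys()))
# -> ['norm.bias', 'norm.weight', 'weight']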
2 changes: 2 additions & 0 deletions tests/test_adapter_v2.py
@@ -405,6 +405,8 @@ def test_adapter_v2_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alp
"transformer.h.1.mlp.fc.adapter_scale",
"transformer.h.1.attn.attn.linear.bias",
"transformer.wte.weight",
"transformer.wte.norm.weight",
"transformer.wte.norm.bias",
"transformer.h.0.norm_2.weight",
"transformer.h.1.mlp.proj.linear.bias",
"transformer.h.0.attn.gating_factor",
3 changes: 3 additions & 0 deletions tests/test_lora.py
@@ -724,6 +724,7 @@ def test_lora_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_pa

    args, kwargs = train_mock.call_args
    fabric, model, optimizer, *_ = args
    model.transformer.wte = model.transformer.wte.half()
    assert isinstance(fabric.strategy.precision, BitsandbytesPrecision)
    assert isinstance(optimizer, _FabricOptimizer)
    assert isinstance(optimizer._optimizer, PagedAdamW)
@@ -748,6 +749,8 @@ def test_lora_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_pa
"transformer.h.0.attn.attn.lora_B",
"transformer.h.0.norm_2.weight",
"transformer.wte.weight",
"transformer.wte.norm.weight",
"transformer.wte.norm.bias",
"transformer.h.1.mlp.fc.linear.bias",
"transformer.ln_f.bias",
"transformer.h.1.attn.attn.lora_B",
