Skip to content

Commit

Permalink
Add Llama2 7B recipe
Browse files Browse the repository at this point in the history
  • Loading branch information
suiyoubi committed Dec 18, 2024
1 parent 0133deb commit 3134cd0
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
12 changes: 6 additions & 6 deletions nemo/collections/llm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
io,
)
from nemo.lightning.base import NEMO_MODELS_CACHE
from nemo.lightning.pytorch.callbacks import PEFT, JitTransform, ModelTransform
from nemo.lightning.pytorch.callbacks import PEFT, ModelTransform
from nemo.utils import logging
from nemo.utils.get_rank import is_global_rank_zero

Expand Down Expand Up @@ -558,7 +558,7 @@ def import_ckpt(
Python Usage:
```python
model = Mistral7BModel()
model = Mistral7BModelL940()
imported_path = import_ckpt(model, "hf://mistralai/Mistral-7B-v0.1")
```
Expand Down Expand Up @@ -877,10 +877,10 @@ def _setup(
trainer.callbacks.append(ModelTransform())
# Move jit callback at the end ensure it's applied on top of any model transformations (peft)
jit_cb = None
for i, cb in enumerate(trainer.callbacks):
if isinstance(cb, JitTransform):
assert jit_cb is None
jit_cb = trainer.callbacks.pop(i)
# for i, cb in enumerate(trainer.callbacks):
# if isinstance(cb, JitTransform):
# assert jit_cb is None
# jit_cb = trainer.callbacks.pop(i)
if jit_cb is not None:
trainer.callbacks.append(jit_cb)
return app_state
Expand Down
2 changes: 2 additions & 0 deletions nemo/collections/llm/recipes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
gemma_7b,
gpt3_175b,
hf_auto_model_for_causal_lm,
llama2_7b,
llama3_8b,
llama3_8b_16k,
llama3_8b_64k,
Expand Down Expand Up @@ -86,6 +87,7 @@
"chatglm3_6b",
"gemma_2b",
"gemma_7b",
"llama2_7b",
"llama3_8b",
"llama3_8b_16k",
"llama3_8b_64k",
Expand Down

0 comments on commit 3134cd0

Please sign in to comment.