Fix peft inference (#11568)
* fix peft inference (trainer not attached)

Signed-off-by: Chen Cui <[email protected]>

* enable greedy generation

Signed-off-by: Chen Cui <[email protected]>

* add ci test for PEFT inference

Signed-off-by: Chen Cui <[email protected]>

* Apply isort and black reformatting

Signed-off-by: cuichenx <[email protected]>

* typo

Signed-off-by: Chen Cui <[email protected]>

* fix test

Signed-off-by: Chen Cui <[email protected]>

* handle remove_special_tokens

Signed-off-by: Chen Cui <[email protected]>

* move llama3configci to common file

Signed-off-by: Chen Cui <[email protected]>

* Apply isort and black reformatting

Signed-off-by: cuichenx <[email protected]>

* incoming commit

Signed-off-by: Chen Cui <[email protected]>

* address comment

Signed-off-by: Chen Cui <[email protected]>

---------

Signed-off-by: Chen Cui <[email protected]>
Signed-off-by: cuichenx <[email protected]>
Co-authored-by: cuichenx <[email protected]>
cuichenx and cuichenx authored Dec 18, 2024
1 parent b7478b6 commit b74b6ec
Showing 8 changed files with 87 additions and 52 deletions.
72 changes: 45 additions & 27 deletions .github/workflows/cicd-main.yml
@@ -4254,7 +4254,7 @@ jobs:
SCRIPT: |
python tests/collections/llm/gpt_finetuning.py \
- --restore_path /home/TestData/nemo2_ckpt/llama_68M \
+ --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \
--devices 2 \
--max_steps 3 \
--experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
@@ -4264,7 +4264,7 @@
--mbs 1
python tests/collections/llm/gpt_finetuning.py \
- --restore_path /home/TestData/nemo2_ckpt/llama_68M \
+ --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \
--devices 2 \
--max_steps 6 \
--experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
@@ -4283,7 +4283,7 @@
SCRIPT: |
python tests/collections/llm/gpt_finetuning.py \
- --restore_path /home/TestData/nemo2_ckpt/llama_68M \
+ --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \
--devices 2 \
--max_steps 3 \
--experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
@@ -4293,7 +4293,7 @@
--mbs 2
python tests/collections/llm/gpt_finetuning.py \
- --restore_path /home/TestData/nemo2_ckpt/llama_68M \
+ --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \
--devices 2 \
--max_steps 6 \
--experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
@@ -4312,7 +4312,7 @@
SCRIPT: |
python tests/collections/llm/gpt_finetuning.py \
- --restore_path /home/TestData/nemo2_ckpt/llama_68M \
+ --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \
--devices 2 \
--max_steps 3 \
--experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
@@ -4322,7 +4322,7 @@
--mbs 2
python tests/collections/llm/gpt_finetuning.py \
- --restore_path /home/TestData/nemo2_ckpt/llama_68M \
+ --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \
--devices 2 \
--max_steps 6 \
--experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
@@ -4341,7 +4341,7 @@
SCRIPT: |
python tests/collections/llm/gpt_finetuning.py \
- --restore_path /home/TestData/nemo2_ckpt/llama_68M \
+ --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \
--devices 2 \
--max_steps 3 \
--experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
@@ -4351,7 +4351,7 @@
--mbs 2
python tests/collections/llm/gpt_finetuning.py \
- --restore_path /home/TestData/nemo2_ckpt/llama_68M \
+ --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \
--devices 2 \
--max_steps 6 \
--experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
@@ -4370,7 +4370,7 @@
SCRIPT: |
python tests/collections/llm/gpt_finetuning.py \
- --restore_path /home/TestData/nemo2_ckpt/llama_68M \
+ --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \
--devices 2 \
--max_steps 3 \
--experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
@@ -4380,7 +4380,7 @@
--mbs 1 --packed
python tests/collections/llm/gpt_finetuning.py \
- --restore_path /home/TestData/nemo2_ckpt/llama_68M \
+ --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \
--devices 2 \
--max_steps 6 \
--experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
@@ -4399,7 +4399,7 @@
SCRIPT: |
python tests/collections/llm/gpt_finetuning.py \
- --restore_path /home/TestData/nemo2_ckpt/llama_68M \
+ --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \
--devices 2 \
--max_steps 3 \
--experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
@@ -4409,7 +4409,7 @@
--mbs 1
python tests/collections/llm/gpt_finetuning.py \
- --restore_path /home/TestData/nemo2_ckpt/llama_68M \
+ --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \
--devices 2 \
--max_steps 6 \
--experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
@@ -4428,7 +4428,7 @@
SCRIPT: |
python tests/collections/llm/gpt_finetuning.py \
- --restore_path /home/TestData/nemo2_ckpt/llama_68M \
+ --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \
--devices 2 \
--max_steps 3 \
--experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
@@ -4438,7 +4438,7 @@
--mbs 2
python tests/collections/llm/gpt_finetuning.py \
- --restore_path /home/TestData/nemo2_ckpt/llama_68M \
+ --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \
--devices 2 \
--max_steps 6 \
--experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
@@ -4457,7 +4457,7 @@
SCRIPT: |
python tests/collections/llm/gpt_finetuning.py \
- --restore_path /home/TestData/nemo2_ckpt/llama_68M \
+ --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \
--devices 2 \
--max_steps 3 \
--experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
@@ -4467,7 +4467,7 @@
--mbs 2
python tests/collections/llm/gpt_finetuning.py \
- --restore_path /home/TestData/nemo2_ckpt/llama_68M \
+ --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \
--devices 2 \
--max_steps 6 \
--experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
@@ -4486,7 +4486,7 @@
SCRIPT: |
python tests/collections/llm/gpt_finetuning.py \
- --restore_path /home/TestData/nemo2_ckpt/llama_68M \
+ --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \
--devices 2 \
--max_steps 3 \
--experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
@@ -4496,7 +4496,7 @@
--mbs 2
python tests/collections/llm/gpt_finetuning.py \
- --restore_path /home/TestData/nemo2_ckpt/llama_68M \
+ --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \
--devices 2 \
--max_steps 6 \
--experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
@@ -4514,7 +4514,7 @@
SCRIPT: |
python tests/collections/llm/gpt_finetuning.py \
- --restore_path /home/TestData/nemo2_ckpt/llama_68M \
+ --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \
--devices 2 \
--max_steps 3 \
--experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
@@ -4524,7 +4524,7 @@
--mbs 1 --packed
python tests/collections/llm/gpt_finetuning.py \
- --restore_path /home/TestData/nemo2_ckpt/llama_68M \
+ --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \
--devices 2 \
--max_steps 6 \
--experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
@@ -4542,7 +4542,7 @@
SCRIPT: |
python tests/collections/llm/gpt_finetuning.py \
- --restore_path /home/TestData/nemo2_ckpt/llama_68M \
+ --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \
--devices 2 \
--max_steps 3 \
--experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
@@ -4552,7 +4552,7 @@
--mbs 1 --packed
python tests/collections/llm/gpt_finetuning.py \
- --restore_path /home/TestData/nemo2_ckpt/llama_68M \
+ --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \
--devices 2 \
--max_steps 6 \
--experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
@@ -4569,7 +4569,7 @@
RUNNER: self-hosted-azure
SCRIPT: |
python tests/collections/llm/gpt_finetuning.py \
- --restore_path /home/TestData/nemo2_ckpt/llama_68M \
+ --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \
--devices 2 \
--max_steps 3 \
--experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
@@ -4579,7 +4579,7 @@
--mbs 1 --packed
python tests/collections/llm/gpt_finetuning.py \
- --restore_path /home/TestData/nemo2_ckpt/llama_68M \
+ --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \
--devices 2 \
--max_steps 6 \
--experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
@@ -4597,7 +4597,7 @@
SCRIPT: |
python tests/collections/llm/gpt_finetuning.py \
- --restore_path /home/TestData/nemo2_ckpt/llama_68M \
+ --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \
--devices 2 \
--max_steps 3 \
--experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
@@ -4608,7 +4608,7 @@
--chat_dataset_path /home/TestData/nemo2_data/chat
python tests/collections/llm/gpt_finetuning.py \
- --restore_path /home/TestData/nemo2_ckpt/llama_68M \
+ --restore_path /home/TestData/nemo2_ckpt/llama_68M_v2 \
--devices 2 \
--max_steps 6 \
--experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
@@ -4702,9 +4702,26 @@ jobs:
SCRIPT: |
python tests/collections/llm/peft/lora_merge.py \
- --lora_checkpoint_path=/home/TestData/nemo2_ckpt/llama_lora_ci_checkpoint/ \
+ --lora_checkpoint_path=/home/TestData/nemo2_ckpt/llama_lora_ci_checkpoint_v2/ \
--output_path=/tmp/nemo2_lora_merge/${{ github.run_id }}
L2_NEMO_2_LoRA_Inference:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NEMO_2_LoRA_Inference') || needs.cicd-test-container-setup.outputs.all == 'true'
with:
RUNNER: self-hosted-azure-gpus-1
SCRIPT: |
python scripts/llm/generate.py \
--model_path /home/TestData/nemo2_ckpt/llama_lora_ci_checkpoint_v2/ \
--tp 1 \
--pp 1 \
--devices 1 \
--top_p 0.0 \
--top_k 1 \
--num_tokens_to_generate 3
L2_NeMo_2_NeMo_Mcore_Mixtral_bitexact:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
@@ -4900,6 +4917,7 @@ jobs:
- L2_NeMo_2_Mistral_LoRA_TP1PP1_MBS1
- L2_NeMo_2_Mistral_LoRA_TP2PP1_MBS1
- L2_NEMO_2_LoRA_MERGE
- L2_NEMO_2_LoRA_Inference
- L2_NeMo_2_Mixtral_Pretraining
- L2_PTQ_Llama2_FP8
- L2_Community_LLM_Checkpoints_tests_Llama3
7 changes: 5 additions & 2 deletions nemo/collections/llm/inference/base.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import json
from pathlib import Path
from typing import Optional, Union
@@ -61,7 +61,7 @@ def detokenize(self, tokens, remove_special_tokens=False):
Returns:
str: The detokenized string.
"""
- return self.tokenizer.ids_to_text(tokens, remove_special_tokens)
+ if 'remove_special_tokens' in inspect.signature(self.tokenizer.ids_to_text).parameters:
+     return self.tokenizer.ids_to_text(tokens, remove_special_tokens)
+ else:
+     return self.tokenizer.ids_to_text(tokens)

def tokenize(self, prompt):
"""
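The guard added above makes `detokenize` tolerant of tokenizers whose `ids_to_text` does not accept a `remove_special_tokens` argument. A minimal, self-contained sketch of the same signature-inspection pattern; both tokenizer classes here are hypothetical stand-ins, not NeMo classes:

import inspect


class LegacyTokenizer:
    """Hypothetical tokenizer whose ids_to_text only accepts the token ids."""

    def ids_to_text(self, tokens):
        return " ".join(str(t) for t in tokens)


class NewTokenizer:
    """Hypothetical tokenizer that also understands remove_special_tokens."""

    def ids_to_text(self, tokens, remove_special_tokens=False):
        pieces = [str(t) for t in tokens]
        if remove_special_tokens:
            pieces = [p for p in pieces if p not in ("<s>", "</s>")]
        return " ".join(pieces)


def detokenize(tokenizer, tokens, remove_special_tokens=False):
    # Forward the flag only if the underlying tokenizer's signature supports it.
    if "remove_special_tokens" in inspect.signature(tokenizer.ids_to_text).parameters:
        return tokenizer.ids_to_text(tokens, remove_special_tokens)
    return tokenizer.ids_to_text(tokens)


print(detokenize(LegacyTokenizer(), [1, 2, 3]))                                       # "1 2 3"
print(detokenize(NewTokenizer(), ["<s>", "hi", "</s>"], remove_special_tokens=True))  # "hi"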
5 changes: 3 additions & 2 deletions nemo/lightning/pytorch/callbacks/peft.py
@@ -32,6 +32,7 @@
from nemo.lightning.megatron_parallel import MegatronParallel
from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform
from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule
from nemo.lightning.pytorch.utils import is_trainer_attached
from nemo.utils import logging
from nemo.utils.callbacks.dist_ckpt_io import AsyncCompatibleCheckpointIO

@@ -105,7 +106,7 @@ def __call__(self, model: nn.Module) -> nn.Module:
else:
model.walk(self.transform)

- if hasattr(model, "trainer") and model.trainer.state.fn != TrainerFn.FITTING:
+ if is_trainer_attached(model) and model.trainer.state.fn != TrainerFn.FITTING:
self.freeze_model(model)
return model

@@ -128,7 +129,7 @@ def freeze_model(self, model: nn.Module) -> None:
model.module.freeze()
else:
model.freeze()
- if hasattr(model, "trainer") and model.trainer.state.fn == TrainerFn.FITTING:
+ if is_trainer_attached(model) and model.trainer.state.fn == TrainerFn.FITTING:
model.train(mode=True)

def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None:
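For context, the two guards above decide when the adapter-wrapped model is frozen: when a trainer is attached and it is not in the FITTING state (pure inference or evaluation), the whole model is frozen after the transform, while during fitting it is switched back to train mode. A self-contained sketch of that decision, using toy stand-ins rather than NeMo's or Lightning's classes:

from enum import Enum

import torch.nn as nn


class TrainerFn(Enum):
    # Toy stand-in for Lightning's TrainerFn states.
    FITTING = "fit"
    PREDICTING = "predict"


def maybe_freeze_for_inference(model: nn.Module, trainer_attached: bool, fn: TrainerFn) -> nn.Module:
    # Mirrors the guard in PEFT.__call__: freeze only when a trainer is attached
    # and that trainer is not fitting, i.e. the model is being used for inference.
    if trainer_attached and fn is not TrainerFn.FITTING:
        for p in model.parameters():
            p.requires_grad = False
    return model


model = maybe_freeze_for_inference(nn.Linear(4, 4), trainer_attached=True, fn=TrainerFn.PREDICTING)
print(all(not p.requires_grad for p in model.parameters()))  # True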
12 changes: 12 additions & 0 deletions nemo/lightning/pytorch/utils.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import lightning.pytorch as pl
import torch


@@ -55,3 +56,14 @@ def dtype_from_hf(config):
return dtype_from_str(torch_dtype)
else:
raise ValueError("torch_dtype is not of type str/torch.dtype")


def is_trainer_attached(model: pl.LightningModule):
"""
Returns true if trainer is attached to a model
"""
try:
trainer = model.trainer
return True
except (AttributeError, RuntimeError):
return False
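The `try/except` matters because `hasattr` only swallows `AttributeError`: a LightningModule's `trainer` property can raise a `RuntimeError` when no trainer is attached (which is what the except clause above guards against), so the previous `hasattr(model, "trainer")` check could itself blow up during standalone inference. A self-contained illustration; the `DetachedModel` class below is a toy stand-in for that behaviour:

class DetachedModel:
    """Toy stand-in: mimics a LightningModule whose .trainer raises when detached."""

    @property
    def trainer(self):
        raise RuntimeError("DetachedModel is not attached to a Trainer")


def is_trainer_attached(model) -> bool:
    # Same pattern as above: hasattr() would let the RuntimeError escape,
    # because hasattr() only suppresses AttributeError.
    try:
        _ = model.trainer
        return True
    except (AttributeError, RuntimeError):
        return False


print(is_trainer_attached(DetachedModel()))  # False
# hasattr(DetachedModel(), "trainer") would raise RuntimeError instead of returning False.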
11 changes: 10 additions & 1 deletion scripts/llm/generate.py
@@ -72,6 +72,12 @@ def get_args():
default=0.95,
help="""top_p to be used in megatron.core.inference.common_inference_params.CommonInferenceParams""",
)
parser.add_argument(
"--top_k",
type=float,
default=0,
help="""top_k to be used in megatron.core.inference.common_inference_params.CommonInferenceParams""",
)
parser.add_argument(
"--num_tokens_to_generate",
type=int,
@@ -118,7 +124,10 @@ def get_args():
prompts=prompts,
trainer=trainer,
inference_params=CommonInferenceParams(
- temperature=args.temperature, top_p=args.top_p, num_tokens_to_generate=args.num_tokens_to_generate
+ temperature=args.temperature,
+ top_p=args.top_p,
+ top_k=args.top_k,
+ num_tokens_to_generate=args.num_tokens_to_generate,
),
text_only=True,
)
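Exposing `--top_k` is what lets the new CI job above request effectively greedy, deterministic decoding (`--top_p 0.0 --top_k 1`): with `top_k=1` only the single most likely token survives, and `top_p=0.0` disables nucleus sampling. A rough sketch of how the flags map onto the sampling parameters; the dataclass below is an illustrative stand-in for megatron.core's `CommonInferenceParams`, not the real class:

from dataclasses import dataclass


@dataclass
class CommonInferenceParams:
    # Illustrative stand-in mirroring the fields used in the diff above.
    temperature: float = 1.0
    top_k: int = 0
    top_p: float = 0.0
    num_tokens_to_generate: int = 30


# Greedy decoding as configured by the CI job: keep only the argmax token.
greedy = CommonInferenceParams(temperature=1.0, top_k=1, top_p=0.0, num_tokens_to_generate=3)
print(greedy)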
10 changes: 10 additions & 0 deletions tests/collections/llm/common.py
@@ -13,6 +13,7 @@
# limitations under the License.

import os
from dataclasses import dataclass

import lightning.pytorch as pl
import nemo_run as run
@@ -191,3 +192,12 @@ def verify_precision(tensor: torch.Tensor) -> None:
assert tensor.dtype == precision

return verify_precision


@dataclass
class Llama3ConfigCI(llm.Llama3Config8B):
seq_length: int = 2048
num_layers: int = 2
hidden_size: int = 768
ffn_hidden_size: int = 3072
num_attention_heads: int = 8
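Moving `Llama3ConfigCI` into the shared test helpers lets the finetuning and the new inference tests build the same tiny model. The dataclass simply overrides a handful of `Llama3Config8B` defaults; a self-contained sketch of that pattern, where the base-class field values below are illustrative stand-ins rather than NeMo's actual defaults:

from dataclasses import dataclass, fields


@dataclass
class Llama3Config8B:
    # Illustrative stand-in for nemo.collections.llm's Llama3Config8B defaults.
    seq_length: int = 8192
    num_layers: int = 32
    hidden_size: int = 4096
    ffn_hidden_size: int = 14336
    num_attention_heads: int = 32


@dataclass
class Llama3ConfigCI(Llama3Config8B):
    # Shrunk dimensions so the CI finetuning/inference jobs fit on a single small GPU.
    seq_length: int = 2048
    num_layers: int = 2
    hidden_size: int = 768
    ffn_hidden_size: int = 3072
    num_attention_heads: int = 8


cfg = Llama3ConfigCI()
print({f.name: getattr(cfg, f.name) for f in fields(cfg)})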
