Skip to content

Commit

Permalink
[usability] Add hymba lora target (#924) (#925)
Browse files Browse the repository at this point in the history
* [usability] add hymba lora target

* [usability] add hymba lora target

* typo fix

---------

Co-authored-by: Yizhen Jia <[email protected]>
Co-authored-by: YizhenJia <[email protected]>
  • Loading branch information
3 people authored Dec 24, 2024
1 parent 4dffd02 commit a6c5ae2
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
11 changes: 9 additions & 2 deletions src/lmflow/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,8 @@ class ModelArguments:
metadata={
"help": "Merging ratio between the fine-tuned model and the original. This is controlled by a parameter called alpha in the paper."},
)
lora_target_modules: List[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name",}
lora_target_modules: str = field(
default=None, metadata={"help": "Model modules to apply LoRA to. Use comma to separate multiple modules."}
)
lora_dropout: float = field(
default=0.1,
Expand Down Expand Up @@ -364,6 +364,9 @@ def __post_init__(self):
if not is_flash_attn_available():
self.use_flash_attention = False
logger.warning("Flash attention is not available in the current environment. Disabling flash attention.")

if self.lora_target_modules is not None:
self.lora_target_modules: List[str] = split_args(self.lora_target_modules)


@dataclass
Expand Down Expand Up @@ -1464,3 +1467,7 @@ class AutoArguments:

def get_pipeline_args_class(pipeline_name: str):
    """Return the argument dataclass registered for *pipeline_name*.

    Raises KeyError if the pipeline name is not registered in
    ``PIPELINE_ARGUMENT_MAPPING``.
    """
    args_class = PIPELINE_ARGUMENT_MAPPING[pipeline_name]
    return args_class


def split_args(args):
    """Split a comma-separated string into a list of stripped names.

    Non-string inputs (e.g. an already-split list, or None) are returned
    unchanged, so callers may pass either form.

    Empty elements produced by stray or trailing commas (e.g.
    "q_proj, v_proj,") are dropped so they cannot turn into empty
    module names downstream when resolving LoRA targets.
    """
    if not isinstance(args, str):
        return args
    return [elem.strip() for elem in args.split(",") if elem.strip()]
11 changes: 8 additions & 3 deletions src/lmflow/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,12 +386,17 @@
DEFAULT_IM_END_TOKEN = "<im_end>"

# Lora
# NOTE: Be careful, when passing lora_target_modules through arg parser, the
# value should be like '--lora_target_modules q_proj, v_proj \', while specifying
# here, it should be in list format.
# NOTE: This works as a mapping for those models that the `peft` library doesn't support yet, and will be
# overwritten by peft.utils.constants.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
# if the model is supported (see hf_model_mixin.py).
# NOTE: When passing lora_target_modules through arg parser, the
# value should be a string. Using commas to separate the module names, e.g.
# "--lora_target_modules 'q_proj, v_proj'".
# However, when specifying here, they should be lists.
# Fallback mapping: model-type key -> list of module names to apply LoRA to.
# Consulted only for models the installed `peft` version has no entry for
# (see the NOTE above); values here must be lists, not comma strings.
LMFLOW_LORA_TARGET_MODULES_MAPPING = {
    'qwen2': ["q_proj", "v_proj"],
    'internlm2': ["wqkv"],  # NOTE(review): fused qkv projection, presumably — confirm against model code
    'hymba': ["x_proj.0", "in_proj", "out_proj", "dt_proj.0"]
}

# vllm inference
Expand Down

0 comments on commit a6c5ae2

Please sign in to comment.