diff --git a/llm/argument.py b/llm/argument.py
index fcec69a93dea..aabf4a5aebda 100644
--- a/llm/argument.py
+++ b/llm/argument.py
@@ -126,6 +126,8 @@ class ModelArgument:
     lora: bool = field(default=False, metadata={"help": "Whether to use LoRA technique"})
     lora_path: str = field(default=None, metadata={"help": "Initialize lora state dict."})
     lora_rank: int = field(default=8, metadata={"help": "Lora attention dimension"})
+    rslora: bool = field(default=False, metadata={"help": "Whether to use RsLoRA"})
+    lora_plus_scale: float = field(default=1.0, metadata={"help": "Lora B scale in LoRA+ technique"})

     # prefix tuning related parameters
     prefix_tuning: bool = field(default=False, metadata={"help": "Whether to use Prefix technique"})
diff --git a/llm/finetune_generation.py b/llm/finetune_generation.py
index d9a54a0e6226..920aabbbcd76 100644
--- a/llm/finetune_generation.py
+++ b/llm/finetune_generation.py
@@ -412,7 +412,9 @@ def neft_post_hook(module, input, output):
             lora_config = LoRAConfig(
                 target_modules=target_modules,
                 r=model_args.lora_rank,
-                lora_alpha=2 * model_args.lora_rank,
+                lora_alpha=2 * model_args.lora_rank if not model_args.rslora else 4,
+                rslora=model_args.rslora,
+                lora_plus_scale=model_args.lora_plus_scale,
                 merge_weights=False,
                 tensor_parallel_degree=training_args.tensor_parallel_degree,
                 dtype=dtype,
diff --git a/paddlenlp/peft/lora/lora_config.py b/paddlenlp/peft/lora/lora_config.py
index d9952a5f02d9..2434d369da5e 100644
--- a/paddlenlp/peft/lora/lora_config.py
+++ b/paddlenlp/peft/lora/lora_config.py
@@ -72,6 +72,8 @@ class LoRAConfig:
         },
     )
     do_qat: bool = field(default=False, metadata={"help": "Whether the lora model would do quant-aware training"})
+    rslora: bool = field(default=False, metadata={"help": "Whether to use RsLoRA"})
+    lora_plus_scale: float = field(default=1.0, metadata={"help": "Lora B scale in LoRA+"})
     base_model_name_or_path: Optional[str] = field(
         default=None, metadata={"help": "The name of the base model to use."}
     )
diff --git a/paddlenlp/peft/lora/lora_layers.py b/paddlenlp/peft/lora/lora_layers.py
index 5ef19eacf817..ae38f47825e4 100644
--- a/paddlenlp/peft/lora/lora_layers.py
+++ b/paddlenlp/peft/lora/lora_layers.py
@@ -35,6 +35,8 @@ def __init__(
         lora_alpha: int = 1,
         lora_dropout: float = 0.0,
         merge_weights: bool = True,
+        rslora: bool = False,
+        lora_plus_scale: float = 1.0,
         **kwargs
     ):
         nn.Linear.__init__(self, in_features, out_features, **kwargs)
@@ -62,9 +64,16 @@ def __init__(
             shape=[r, out_features],
             dtype=self._dtype,
             is_bias=False,
-            default_initializer=nn.initializer.Constant(value=0.0),
+            attr=paddle.ParamAttr(
+                initializer=paddle.nn.initializer.Constant(value=0.0),
+                learning_rate=lora_plus_scale,
+            ),
         )
-        self.scaling = self.lora_alpha / self.r
+
+        if not rslora:
+            self.scaling = self.lora_alpha / self.r
+        else:
+            self.scaling = self.lora_alpha / math.sqrt(self.r)

         # Freezing the pre-trained weight matrix
         self.weight.stop_gradient = True
@@ -104,6 +113,8 @@ def __init__(
         r: int = 0,
         lora_alpha: int = 1,
         lora_dropout: float = 0.0,
+        rslora: bool = False,
+        lora_plus_scale: float = 1.0,
         merge_weights: bool = True,
         **kwargs
     ):
@@ -137,12 +148,19 @@ def __init__(
             shape=[r, self.out_features],
             dtype=self._dtype,
             is_bias=False,
-            default_initializer=nn.initializer.Constant(value=0.0),
+            attr=paddle.ParamAttr(
+                initializer=paddle.nn.initializer.Constant(value=0.0),
+                learning_rate=lora_plus_scale,
+            ),
         )
+
         self.lora_A.is_distributed = True
         self.lora_A.split_axis = 0
         self.lora_B.is_distributed = False
-        self.scaling = self.lora_alpha / self.r
+        if not rslora:
+            self.scaling = self.lora_alpha / self.r
+        else:
+            self.scaling = self.lora_alpha / math.sqrt(self.r)

         # Freezing the pre-trained weight matrix
         self.weight.stop_gradient = True
@@ -208,6 +226,8 @@ def __init__(
         r: int = 0,
         lora_alpha: int = 1,
         lora_dropout: float = 0.0,
+        rslora: bool = False,
+        lora_plus_scale: float = 1.0,
         merge_weights: bool = True,
         lora_A_weight_attr: Optional[paddle.ParamAttr] = None,
         **kwargs
@@ -241,11 +261,18 @@ def __init__(
             shape=[r, self.output_size_per_partition],
             dtype=self._dtype,
             is_bias=False,
-            default_initializer=nn.initializer.Constant(value=0.0),
+            attr=paddle.ParamAttr(
+                initializer=paddle.nn.initializer.Constant(value=0.0),
+                learning_rate=lora_plus_scale,
+            ),
         )
+
         self.lora_B.is_distributed = True
         self.lora_B.split_axis = 1
-        self.scaling = self.lora_alpha / self.r
+        if not rslora:
+            self.scaling = self.lora_alpha / self.r
+        else:
+            self.scaling = self.lora_alpha / math.sqrt(self.r)

         # Freezing the pre-trained weight matrix
         self.weight.stop_gradient = True
diff --git a/paddlenlp/peft/lora/lora_model.py b/paddlenlp/peft/lora/lora_model.py
index ccea3c006a0f..83c87433c662 100644
--- a/paddlenlp/peft/lora/lora_model.py
+++ b/paddlenlp/peft/lora/lora_model.py
@@ -297,6 +297,8 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
                 lora_alpha=lora_config.lora_alpha,
                 lora_dropout=lora_config.lora_dropout,
                 merge_weights=lora_config.merge_weights,
+                rslora=lora_config.rslora,
+                lora_plus_scale=lora_config.lora_plus_scale,
                 bias_attr=False if module.bias is None else None,
             )
         if isinstance(module, nn.Conv2D):
@@ -327,6 +329,8 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
                 r=lora_config.r,
                 lora_alpha=lora_config.lora_alpha,
                 lora_dropout=lora_config.lora_dropout,
+                rslora=lora_config.rslora,
+                lora_plus_scale=lora_config.lora_plus_scale,
                 merge_weights=lora_config.merge_weights,
                 lora_A_weight_attr=paddle.ParamAttr(
                     initializer=nn.initializer.KaimingUniform(
@@ -352,6 +356,8 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
                 r=lora_config.r,
                 lora_alpha=lora_config.lora_alpha,
                 lora_dropout=lora_config.lora_dropout,
+                rslora=lora_config.rslora,
+                lora_plus_scale=lora_config.lora_plus_scale,
                 merge_weights=lora_config.merge_weights,
             )
             # Lora column parallel will spilt lora A matrix
diff --git a/tests/fixtures/llm/lora.yaml b/tests/fixtures/llm/lora.yaml
index 6a2cbfa732c7..bf5db5efd979 100644
--- a/tests/fixtures/llm/lora.yaml
+++ b/tests/fixtures/llm/lora.yaml
@@ -41,6 +41,51 @@ lora:
     baichuan:
       model_name_or_path: __internal_testing__/tiny-fused-baichuan

+rslora_plus:
+  base:
+    dataset_name_or_path: "./data"
+    per_device_train_batch_size: 4
+    gradient_accumulation_steps: 4
+    per_device_eval_batch_size: 8
+    eval_accumulation_steps: 16
+    num_train_epochs: 3
+    learning_rate: 3e-04
+    warmup_steps: 30
+    logging_steps: 1
+    evaluation_strategy: "epoch"
+    save_strategy: "epoch"
+    src_length: 1024
+    max_length: 2048
+    fp16: true
+    fp16_opt_level: "O2"
+    do_train: true
+    do_eval: true
+    disable_tqdm: true
+    load_best_model_at_end: true
+    eval_with_do_generation: false
+    metric_for_best_model: "accuracy"
+    recompute: true
+    save_total_limit: 1
+    tensor_parallel_degree: 1
+    pipeline_parallel_degree: 1
+    lora: true
+    lora_plus_scale: 4
+    rslora: true
+
+  default:
+    llama:
+      model_name_or_path: __internal_testing__/tiny-random-llama
+    chatglm:
+      model_name_or_path: __internal_testing__/tiny-fused-chatglm
+    chatglm2:
+      model_name_or_path: __internal_testing__/tiny-fused-chatglm2
+    bloom:
+      model_name_or_path: __internal_testing__/tiny-fused-bloom
+    qwen:
+      model_name_or_path: __internal_testing__/tiny-fused-qwen
+    baichuan:
+      model_name_or_path: __internal_testing__/tiny-fused-baichuan
+
 inference-predict:
   default:
     mode: dynamic
diff --git a/tests/llm/test_lora.py b/tests/llm/test_lora.py
index 138c2ccf699a..d4bec137e8c6 100644
--- a/tests/llm/test_lora.py
+++ b/tests/llm/test_lora.py
@@ -79,6 +79,35 @@ def test_lora(self):

         self.run_predictor({"inference_model": False})

+    def test_rslora_plus(self):
+        self.disable_static()
+        paddle.set_default_dtype("float32")
+
+        lora_config = load_test_config(self.config_path, "rslora_plus", self.model_dir)
+        lora_config["output_dir"] = self.output_dir
+        lora_config["dataset_name_or_path"] = self.data_dir
+
+        with argv_context_guard(lora_config):
+            from finetune_generation import main
+
+            main()
+
+        # merge weights
+        merge_lora_weights_config = {
+            "lora_path": lora_config["output_dir"],
+            "merge_lora_model_path": lora_config["output_dir"],
+        }
+        with argv_context_guard(merge_lora_weights_config):
+            from merge_lora_params import merge
+
+            merge()
+
+        # TODO(wj-Mcat): disable chatglm2 test temporarily
+        if self.model_dir not in ["qwen", "baichuan", "chatglm2"]:
+            self.run_predictor({"inference_model": True})
+
+        self.run_predictor({"inference_model": False})
+

 # @parameterized_class(
 #     ["model_dir"],
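Note (not part of the patch): a minimal sketch of the two knobs this diff adds. The helper name lora_scaling below is hypothetical and purely illustrative; only the alpha / r vs. alpha / sqrt(r) scaling rule and the "larger effective learning rate for lora_B" idea are taken from the changes above.

    import math

    def lora_scaling(lora_alpha: float, r: int, rslora: bool = False) -> float:
        # Standard LoRA scales the low-rank update B @ A by alpha / r.
        # rsLoRA (rank-stabilized LoRA) scales by alpha / sqrt(r), so the
        # update magnitude does not shrink as the rank r grows.
        return lora_alpha / math.sqrt(r) if rslora else lora_alpha / r

    # With r=8: plain LoRA here uses alpha = 2 * r = 16 -> scaling 2.0;
    # with rslora the patch fixes alpha = 4 -> scaling 4 / sqrt(8) ~= 1.41.
    print(lora_scaling(2 * 8, 8, rslora=False))  # 2.0
    print(lora_scaling(4, 8, rslora=True))       # ~1.414

    # lora_plus_scale is passed as paddle.ParamAttr(learning_rate=...) on lora_B
    # only; in Paddle this value multiplies the optimizer's global learning rate
    # for that parameter, reproducing the LoRA+ recipe of updating B with a
    # larger step size than A (lora_plus_scale=1.0 keeps standard LoRA behavior).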