Release/2.7 (PaddlePaddle#8145)
* add rslora & lora+

* remove print

* reformat

* update

* fix bug

* add rslora+ ci

* remove magic number

* update

* empty

---------

Co-authored-by: lugimzzz <[email protected]>
wtmlon and lugimzzz authored Mar 19, 2024
1 parent 0a8c0f9 commit 18072a2
Showing 7 changed files with 120 additions and 7 deletions.
2 changes: 2 additions & 0 deletions llm/argument.py
@@ -126,6 +126,8 @@ class ModelArgument:
lora: bool = field(default=False, metadata={"help": "Whether to use LoRA technique"})
lora_path: str = field(default=None, metadata={"help": "Initialize lora state dict."})
lora_rank: int = field(default=8, metadata={"help": "Lora attention dimension"})
rslora: bool = field(default=False, metadata={"help": "Whether to use RsLoRA"})
lora_plus_scale: float = field(default=1.0, metadata={"help": "Lora B scale in LoRA+ technique"})

# prefix tuning related parameters
prefix_tuning: bool = field(default=False, metadata={"help": "Whether to use Prefix technique"})
4 changes: 3 additions & 1 deletion llm/finetune_generation.py
@@ -412,7 +412,9 @@ def neft_post_hook(module, input, output):
lora_config = LoRAConfig(
target_modules=target_modules,
r=model_args.lora_rank,
lora_alpha=2 * model_args.lora_rank,
lora_alpha=2 * model_args.lora_rank if not model_args.rslora else 4,
rslora=model_args.rslora,
lora_plus_scale=model_args.lora_plus_scale,
merge_weights=False,
tensor_parallel_degree=training_args.tensor_parallel_degree,
dtype=dtype,
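The hunk above also changes how `lora_alpha` is picked: without rsLoRA it stays at `2 * lora_rank` (so the `alpha / r` scaling in the LoRA layers is a constant 2), while with rsLoRA it is pinned to 4 and combined with the `sqrt(r)` denominator added below in `lora_layers.py`. A minimal sketch of the resulting effective scaling, using only the formulas visible in this diff:

```python
import math

def effective_scaling(lora_rank: int, rslora: bool) -> float:
    # Mirrors the alpha choice above and the scaling branch added in lora_layers.py.
    lora_alpha = 2 * lora_rank if not rslora else 4
    if not rslora:
        return lora_alpha / lora_rank          # 2 * r / r == 2, independent of rank
    return lora_alpha / math.sqrt(lora_rank)   # 4 / sqrt(r), shrinks slowly with rank

for r in (4, 8, 64):
    print(r, effective_scaling(r, rslora=False), effective_scaling(r, rslora=True))
# 4 2.0 2.0
# 8 2.0 1.4142135623730951
# 64 2.0 0.5
```

Note that pinning `lora_alpha` to 4 under rsLoRA is a choice made in this training script; the rsLoRA idea itself only requires the `alpha / sqrt(r)` scaling.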
2 changes: 2 additions & 0 deletions paddlenlp/peft/lora/lora_config.py
@@ -72,6 +72,8 @@ class LoRAConfig:
},
)
do_qat: bool = field(default=False, metadata={"help": "Whether the lora model would do quant-aware training"})
rslora: bool = field(default=False, metadata={"help": "Whether to use RsLoRA"})
lora_plus_scale: float = field(default=1.0, metadata={"help": "Lora B scale in LoRA+"})
base_model_name_or_path: Optional[str] = field(
default=None, metadata={"help": "The name of the base model to use."}
)
39 changes: 33 additions & 6 deletions paddlenlp/peft/lora/lora_layers.py
@@ -35,6 +35,8 @@ def __init__(
lora_alpha: int = 1,
lora_dropout: float = 0.0,
merge_weights: bool = True,
rslora: bool = False,
lora_plus_scale: float = 1.0,
**kwargs
):
nn.Linear.__init__(self, in_features, out_features, **kwargs)
@@ -62,9 +64,16 @@ def __init__(
shape=[r, out_features],
dtype=self._dtype,
is_bias=False,
default_initializer=nn.initializer.Constant(value=0.0),
attr=paddle.ParamAttr(
initializer=paddle.nn.initializer.Constant(value=0.0),
learning_rate=lora_plus_scale,
),
)
self.scaling = self.lora_alpha / self.r

if not rslora:
self.scaling = self.lora_alpha / self.r
else:
self.scaling = self.lora_alpha / math.sqrt(self.r)

# Freezing the pre-trained weight matrix
self.weight.stop_gradient = True
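Two techniques meet in this constructor. LoRA+ is implemented through `paddle.ParamAttr(learning_rate=lora_plus_scale)`: in Paddle, a parameter's `learning_rate` attribute multiplies the optimizer's global learning rate, so `lora_B` is updated `lora_plus_scale` times faster than `lora_A`, which matches the LoRA+ recipe of giving the B matrix a larger step size. rsLoRA swaps the classic `alpha / r` scaling for `alpha / sqrt(r)` so the adapter's contribution does not shrink as quickly as the rank grows. A small comparison sketch (not part of this PR) for a fixed `alpha`:

```python
import math

def lora_scaling(lora_alpha: float, r: int, rslora: bool) -> float:
    # Same branch as added to the LoRA layers above.
    return lora_alpha / r if not rslora else lora_alpha / math.sqrt(r)

alpha = 16
for r in (8, 16, 64):
    classic, stabilized = lora_scaling(alpha, r, False), lora_scaling(alpha, r, True)
    print(f"r={r:<3d} classic={classic:.3f} rslora={stabilized:.3f}")
# r=8   classic=2.000 rslora=5.657
# r=16  classic=1.000 rslora=4.000
# r=64  classic=0.250 rslora=2.000
```

The same `rslora` / `lora_plus_scale` wiring is repeated below for the row- and column-parallel LoRA layers.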
@@ -104,6 +113,8 @@ def __init__(
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
rslora: bool = False,
lora_plus_scale: float = 1.0,
merge_weights: bool = True,
**kwargs
):
@@ -137,12 +148,19 @@ def __init__(
shape=[r, self.out_features],
dtype=self._dtype,
is_bias=False,
default_initializer=nn.initializer.Constant(value=0.0),
attr=paddle.ParamAttr(
initializer=paddle.nn.initializer.Constant(value=0.0),
learning_rate=lora_plus_scale,
),
)

self.lora_A.is_distributed = True
self.lora_A.split_axis = 0
self.lora_B.is_distributed = False
self.scaling = self.lora_alpha / self.r
if not rslora:
self.scaling = self.lora_alpha / self.r
else:
self.scaling = self.lora_alpha / math.sqrt(self.r)

# Freezing the pre-trained weight matrix
self.weight.stop_gradient = True
@@ -208,6 +226,8 @@ def __init__(
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
rslora: bool = False,
lora_plus_scale: float = 1.0,
merge_weights: bool = True,
lora_A_weight_attr: Optional[paddle.ParamAttr] = None,
**kwargs
@@ -241,11 +261,18 @@ def __init__(
shape=[r, self.output_size_per_partition],
dtype=self._dtype,
is_bias=False,
default_initializer=nn.initializer.Constant(value=0.0),
attr=paddle.ParamAttr(
initializer=paddle.nn.initializer.Constant(value=0.0),
learning_rate=lora_plus_scale,
),
)

self.lora_B.is_distributed = True
self.lora_B.split_axis = 1
self.scaling = self.lora_alpha / self.r
if not rslora:
self.scaling = self.lora_alpha / self.r
else:
self.scaling = self.lora_alpha / math.sqrt(self.r)

# Freezing the pre-trained weight matrix
self.weight.stop_gradient = True
6 changes: 6 additions & 0 deletions paddlenlp/peft/lora/lora_model.py
@@ -297,6 +297,8 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
lora_alpha=lora_config.lora_alpha,
lora_dropout=lora_config.lora_dropout,
merge_weights=lora_config.merge_weights,
rslora=lora_config.rslora,
lora_plus_scale=lora_config.lora_plus_scale,
bias_attr=False if module.bias is None else None,
)
if isinstance(module, nn.Conv2D):
@@ -327,6 +329,8 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
r=lora_config.r,
lora_alpha=lora_config.lora_alpha,
lora_dropout=lora_config.lora_dropout,
rslora=lora_config.rslora,
lora_plus_scale=lora_config.lora_plus_scale,
merge_weights=lora_config.merge_weights,
lora_A_weight_attr=paddle.ParamAttr(
initializer=nn.initializer.KaimingUniform(
@@ -352,6 +356,8 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
r=lora_config.r,
lora_alpha=lora_config.lora_alpha,
lora_dropout=lora_config.lora_dropout,
rslora=lora_config.rslora,
lora_plus_scale=lora_config.lora_plus_scale,
merge_weights=lora_config.merge_weights,
)
# Lora column parallel will split lora A matrix
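With `lora_model.py` forwarding the two new fields into every layer variant, the options can be enabled end to end from a `LoRAConfig`. A rough usage sketch, assuming the public `paddlenlp.peft` API; the model name is the tiny test checkpoint from the fixture below and the `target_modules` patterns are illustrative, not taken from this PR:

```python
from paddlenlp.peft import LoRAConfig, LoRAModel
from paddlenlp.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("__internal_testing__/tiny-random-llama")

lora_config = LoRAConfig(
    target_modules=[".*q_proj.*", ".*v_proj.*"],  # illustrative regex patterns
    r=8,
    lora_alpha=4,         # with rslora=True the effective scaling is 4 / sqrt(8)
    rslora=True,          # new in this PR: rank-stabilized scaling
    lora_plus_scale=4.0,  # new in this PR: lora_B trains at 4x the base learning rate
    merge_weights=False,
    dtype="float32",
)

model = LoRAModel(model=model, lora_config=lora_config)
model.mark_only_lora_as_trainable()  # freeze everything except the LoRA parameters
model.print_trainable_parameters()
```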
45 changes: 45 additions & 0 deletions tests/fixtures/llm/lora.yaml
@@ -41,6 +41,51 @@ lora:
baichuan:
model_name_or_path: __internal_testing__/tiny-fused-baichuan

rslora_plus:
base:
dataset_name_or_path: "./data"
per_device_train_batch_size: 4
gradient_accumulation_steps: 4
per_device_eval_batch_size: 8
eval_accumulation_steps: 16
num_train_epochs: 3
learning_rate: 3e-04
warmup_steps: 30
logging_steps: 1
evaluation_strategy: "epoch"
save_strategy: "epoch"
src_length: 1024
max_length: 2048
fp16: true
fp16_opt_level: "O2"
do_train: true
do_eval: true
disable_tqdm: true
load_best_model_at_end: true
eval_with_do_generation: false
metric_for_best_model: "accuracy"
recompute: true
save_total_limit: 1
tensor_parallel_degree: 1
pipeline_parallel_degree: 1
lora: true
lora_plus_scale: 4
rslora: true

default:
llama:
model_name_or_path: __internal_testing__/tiny-random-llama
chatglm:
model_name_or_path: __internal_testing__/tiny-fused-chatglm
chatglm2:
model_name_or_path: __internal_testing__/tiny-fused-chatglm2
bloom:
model_name_or_path: __internal_testing__/tiny-fused-bloom
qwen:
model_name_or_path: __internal_testing__/tiny-fused-qwen
baichuan:
model_name_or_path: __internal_testing__/tiny-fused-baichuan

inference-predict:
default:
mode: dynamic
29 changes: 29 additions & 0 deletions tests/llm/test_lora.py
@@ -79,6 +79,35 @@ def test_lora(self):

self.run_predictor({"inference_model": False})

def test_rslora_plus(self):
self.disable_static()
paddle.set_default_dtype("float32")

lora_config = load_test_config(self.config_path, "rslora_plus", self.model_dir)
lora_config["output_dir"] = self.output_dir
lora_config["dataset_name_or_path"] = self.data_dir

with argv_context_guard(lora_config):
from finetune_generation import main

main()

# merge weights
merge_lora_weights_config = {
"lora_path": lora_config["output_dir"],
"merge_lora_model_path": lora_config["output_dir"],
}
with argv_context_guard(merge_lora_weights_config):
from merge_lora_params import merge

merge()

# TODO(wj-Mcat): disable chatglm2 test temporarily
if self.model_dir not in ["qwen", "baichuan", "chatglm2"]:
self.run_predictor({"inference_model": True})

self.run_predictor({"inference_model": False})


# @parameterized_class(
# ["model_dir"],
Expand Down
