From 8ea793e9c74f8fe02cf3175a9e73d8dc6c6a5329 Mon Sep 17 00:00:00 2001 From: lugimzzz Date: Wed, 13 Mar 2024 15:23:05 +0800 Subject: [PATCH 1/9] add rslora & lora+ --- llm/argument.py | 3 + llm/finetune_generation.py | 3 + paddlenlp/peft/lora/lora_config.py | 3 + paddlenlp/peft/lora/lora_layers.py | 97 +++++++++++++++++++++++------- paddlenlp/peft/lora/lora_model.py | 9 +++ 5 files changed, 94 insertions(+), 21 deletions(-) diff --git a/llm/argument.py b/llm/argument.py index 50add643675e..e2e6d194d9db 100644 --- a/llm/argument.py +++ b/llm/argument.py @@ -126,6 +126,9 @@ class ModelArgument: lora: bool = field(default=False, metadata={"help": "Whether to use LoRA technique"}) lora_path: str = field(default=None, metadata={"help": "Initialize lora state dict."}) lora_rank: int = field(default=8, metadata={"help": "Lora attention dimension"}) + rslora: bool = field(default=False, metadata={"help": "Whether to use RsLoRA"}) + lora_plus: bool = field(default=False, metadata={"help": "Whether to use LoRA+ technique"}) + lora_B_scale: int = field(default=16, metadata={"help": "Lora B scale in LoRA+ technique"}) # prefix tuning related parameters prefix_tuning: bool = field(default=False, metadata={"help": "Whether to use Prefix technique"}) diff --git a/llm/finetune_generation.py b/llm/finetune_generation.py index be9b8f3cb4d1..d1550e315592 100644 --- a/llm/finetune_generation.py +++ b/llm/finetune_generation.py @@ -419,6 +419,9 @@ def neft_post_hook(module, input, output): target_modules=target_modules, r=model_args.lora_rank, lora_alpha=2 * model_args.lora_rank, + rslora=model_args.rslora, + lora_plus=model_args.lora_plus, + lora_B_scale=model_args.lora_B_scale, merge_weights=False, tensor_parallel_degree=training_args.tensor_parallel_degree, dtype=dtype, diff --git a/paddlenlp/peft/lora/lora_config.py b/paddlenlp/peft/lora/lora_config.py index d9952a5f02d9..7f6814295627 100644 --- a/paddlenlp/peft/lora/lora_config.py +++ b/paddlenlp/peft/lora/lora_config.py @@ -72,6 +72,9 @@ class LoRAConfig: }, ) do_qat: bool = field(default=False, metadata={"help": "Whether the lora model would do quant-aware training"}) + rslora: bool = field(default=False, metadata={"help": "Whether to use RsLoRA"}) + lora_plus: bool = field(default=False, metadata={"help": "Whether to use LoRA+"}) + lora_B_scale: int = field(default=16, metadata={"help": "Lora B scale in LoRA+"}) base_model_name_or_path: Optional[str] = field( default=None, metadata={"help": "The name of the base model to use."} ) diff --git a/paddlenlp/peft/lora/lora_layers.py b/paddlenlp/peft/lora/lora_layers.py index 5ef19eacf817..d0334f61c0c4 100644 --- a/paddlenlp/peft/lora/lora_layers.py +++ b/paddlenlp/peft/lora/lora_layers.py @@ -35,6 +35,9 @@ def __init__( lora_alpha: int = 1, lora_dropout: float = 0.0, merge_weights: bool = True, + rslora: bool = False, + lora_plus: bool = False, + lora_B_scale: int = 16, **kwargs ): nn.Linear.__init__(self, in_features, out_features, **kwargs) @@ -58,13 +61,29 @@ def __init__( is_bias=False, default_initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu"), ) - self.lora_B = self.create_parameter( - shape=[r, out_features], - dtype=self._dtype, - is_bias=False, - default_initializer=nn.initializer.Constant(value=0.0), - ) - self.scaling = self.lora_alpha / self.r + if lora_plus: + self.lora_B = self.create_parameter( + shape=[r, out_features], + dtype=self._dtype, + is_bias=False, + attr = paddle.ParamAttr( + initializer=paddle.nn.initializer.Constant(value=0.0), + 
learning_rate=lora_B_scale, + ), + ) + print("Using LORA+") + else: + self.lora_B = self.create_parameter( + shape=[r, out_features], + dtype=self._dtype, + is_bias=False, + default_initializer=nn.initializer.Constant(value=0.0), + ) + if not rslora: + self.scaling = self.lora_alpha / self.r + else: + self.scaling = 4.0 / math.sqrt(self.r) + print(f"Using RSLORA scaling {self.scaling}") # Freezing the pre-trained weight matrix self.weight.stop_gradient = True @@ -104,6 +123,9 @@ def __init__( r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.0, + rslora: bool = False, + lora_plus: bool = False, + lora_B_scale: int = 16, merge_weights: bool = True, **kwargs ): @@ -133,16 +155,31 @@ def __init__( initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu") ), ) - self.lora_B = self.create_parameter( - shape=[r, self.out_features], - dtype=self._dtype, - is_bias=False, - default_initializer=nn.initializer.Constant(value=0.0), - ) + if lora_plus: + self.lora_B = self.create_parameter( + shape=[r, out_features], + dtype=self._dtype, + is_bias=False, + attr = paddle.ParamAttr( + initializer=paddle.nn.initializer.Constant(value=0.0), + learning_rate=lora_B_scale, + ), + ) + else: + self.lora_B = self.create_parameter( + shape=[r, self.out_features], + dtype=self._dtype, + is_bias=False, + default_initializer=nn.initializer.Constant(value=0.0), + ) self.lora_A.is_distributed = True self.lora_A.split_axis = 0 self.lora_B.is_distributed = False - self.scaling = self.lora_alpha / self.r + if not rslora: + self.scaling = self.lora_alpha / self.r + else: + self.scaling = 4.0 / math.sqrt(self.r) + print(f"Using RSLORA scaling {self.scaling}") # Freezing the pre-trained weight matrix self.weight.stop_gradient = True @@ -208,6 +245,9 @@ def __init__( r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.0, + rslora: bool = False, + lora_plus: bool = False, + lora_B_scale: int = 16, merge_weights: bool = True, lora_A_weight_attr: Optional[paddle.ParamAttr] = None, **kwargs @@ -237,15 +277,30 @@ def __init__( attr=lora_A_weight_attr, ) self.lora_A.is_distributed = False - self.lora_B = self.create_parameter( - shape=[r, self.output_size_per_partition], - dtype=self._dtype, - is_bias=False, - default_initializer=nn.initializer.Constant(value=0.0), - ) + if lora_plus: + self.lora_B = self.create_parameter( + shape=[r, out_features], + dtype=self._dtype, + is_bias=False, + attr = paddle.ParamAttr( + initializer=paddle.nn.initializer.Constant(value=0.0), + learning_rate=lora_B_scale, + ), + ) + else: + self.lora_B = self.create_parameter( + shape=[r, self.output_size_per_partition], + dtype=self._dtype, + is_bias=False, + default_initializer=nn.initializer.Constant(value=0.0), + ) self.lora_B.is_distributed = True self.lora_B.split_axis = 1 - self.scaling = self.lora_alpha / self.r + if not rslora: + self.scaling = self.lora_alpha / self.r + else: + self.scaling = 4.0 / math.sqrt(self.r) + print(f"Using RSLORA scaling {self.scaling}") # Freezing the pre-trained weight matrix self.weight.stop_gradient = True diff --git a/paddlenlp/peft/lora/lora_model.py b/paddlenlp/peft/lora/lora_model.py index 2bad88e01771..c9545caa4b67 100644 --- a/paddlenlp/peft/lora/lora_model.py +++ b/paddlenlp/peft/lora/lora_model.py @@ -382,6 +382,9 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora) lora_alpha=lora_config.lora_alpha, lora_dropout=lora_config.lora_dropout, merge_weights=lora_config.merge_weights, + rslora=lora_config.rslora, + 
lora_plus=lora_config.lora_plus, + lora_B_scale=lora_config.lora_B_scale, bias_attr=False if module.bias is None else None, ) if isinstance(module, nn.Conv2D): @@ -412,6 +415,9 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora) r=lora_config.r, lora_alpha=lora_config.lora_alpha, lora_dropout=lora_config.lora_dropout, + rslora=lora_config.rslora, + lora_plus=lora_config.lora_plus, + lora_B_scale=lora_config.lora_B_scale, merge_weights=lora_config.merge_weights, lora_A_weight_attr=paddle.ParamAttr( initializer=nn.initializer.KaimingUniform( @@ -437,6 +443,9 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora) r=lora_config.r, lora_alpha=lora_config.lora_alpha, lora_dropout=lora_config.lora_dropout, + rslora=lora_config.rslora, + lora_plus=lora_config.lora_plus, + lora_B_scale=lora_config.lora_B_scale, merge_weights=lora_config.merge_weights, ) # Lora column parallel will spilt lora A matrix From 5384e5b3b6b022e42b0b5fe84162b45b361e2324 Mon Sep 17 00:00:00 2001 From: wtmlon Date: Wed, 13 Mar 2024 15:40:07 +0800 Subject: [PATCH 2/9] remove print --- paddlenlp/peft/lora/lora_layers.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/paddlenlp/peft/lora/lora_layers.py b/paddlenlp/peft/lora/lora_layers.py index d0334f61c0c4..3ac7ffa6b168 100644 --- a/paddlenlp/peft/lora/lora_layers.py +++ b/paddlenlp/peft/lora/lora_layers.py @@ -71,7 +71,6 @@ def __init__( learning_rate=lora_B_scale, ), ) - print("Using LORA+") else: self.lora_B = self.create_parameter( shape=[r, out_features], @@ -83,7 +82,6 @@ def __init__( self.scaling = self.lora_alpha / self.r else: self.scaling = 4.0 / math.sqrt(self.r) - print(f"Using RSLORA scaling {self.scaling}") # Freezing the pre-trained weight matrix self.weight.stop_gradient = True @@ -179,7 +177,6 @@ def __init__( self.scaling = self.lora_alpha / self.r else: self.scaling = 4.0 / math.sqrt(self.r) - print(f"Using RSLORA scaling {self.scaling}") # Freezing the pre-trained weight matrix self.weight.stop_gradient = True @@ -300,7 +297,6 @@ def __init__( self.scaling = self.lora_alpha / self.r else: self.scaling = 4.0 / math.sqrt(self.r) - print(f"Using RSLORA scaling {self.scaling}") # Freezing the pre-trained weight matrix self.weight.stop_gradient = True From 38efb5b59960cc579956d36d4ae5d9720d3c7538 Mon Sep 17 00:00:00 2001 From: wtmlon Date: Wed, 13 Mar 2024 15:52:06 +0800 Subject: [PATCH 3/9] reformat --- llm/argument.py | 3 +- llm/finetune_generation.py | 3 +- paddlenlp/peft/lora/lora_layers.py | 90 +++++++++++------------------- paddlenlp/peft/lora/lora_model.py | 9 +-- 4 files changed, 38 insertions(+), 67 deletions(-) diff --git a/llm/argument.py b/llm/argument.py index e2e6d194d9db..4a23db0b8222 100644 --- a/llm/argument.py +++ b/llm/argument.py @@ -127,8 +127,7 @@ class ModelArgument: lora_path: str = field(default=None, metadata={"help": "Initialize lora state dict."}) lora_rank: int = field(default=8, metadata={"help": "Lora attention dimension"}) rslora: bool = field(default=False, metadata={"help": "Whether to use RsLoRA"}) - lora_plus: bool = field(default=False, metadata={"help": "Whether to use LoRA+ technique"}) - lora_B_scale: int = field(default=16, metadata={"help": "Lora B scale in LoRA+ technique"}) + lora_plus_scale: float = field(default=1.0, metadata={"help": "Lora B scale in LoRA+ technique"}) # prefix tuning related parameters prefix_tuning: bool = field(default=False, metadata={"help": "Whether to use Prefix technique"}) diff --git a/llm/finetune_generation.py 
b/llm/finetune_generation.py index d1550e315592..6436af0b321e 100644 --- a/llm/finetune_generation.py +++ b/llm/finetune_generation.py @@ -420,8 +420,7 @@ def neft_post_hook(module, input, output): r=model_args.lora_rank, lora_alpha=2 * model_args.lora_rank, rslora=model_args.rslora, - lora_plus=model_args.lora_plus, - lora_B_scale=model_args.lora_B_scale, + lora_plus_scale=model_args.lora_plus_scale, merge_weights=False, tensor_parallel_degree=training_args.tensor_parallel_degree, dtype=dtype, diff --git a/paddlenlp/peft/lora/lora_layers.py b/paddlenlp/peft/lora/lora_layers.py index 3ac7ffa6b168..de2ac3c58302 100644 --- a/paddlenlp/peft/lora/lora_layers.py +++ b/paddlenlp/peft/lora/lora_layers.py @@ -36,8 +36,7 @@ def __init__( lora_dropout: float = 0.0, merge_weights: bool = True, rslora: bool = False, - lora_plus: bool = False, - lora_B_scale: int = 16, + lora_plus_scale: float = 1.0, **kwargs ): nn.Linear.__init__(self, in_features, out_features, **kwargs) @@ -61,23 +60,16 @@ def __init__( is_bias=False, default_initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu"), ) - if lora_plus: - self.lora_B = self.create_parameter( - shape=[r, out_features], - dtype=self._dtype, - is_bias=False, - attr = paddle.ParamAttr( - initializer=paddle.nn.initializer.Constant(value=0.0), - learning_rate=lora_B_scale, - ), - ) - else: - self.lora_B = self.create_parameter( - shape=[r, out_features], - dtype=self._dtype, - is_bias=False, - default_initializer=nn.initializer.Constant(value=0.0), - ) + self.lora_B = self.create_parameter( + shape=[r, out_features], + dtype=self._dtype, + is_bias=False, + attr=paddle.ParamAttr( + initializer=paddle.nn.initializer.Constant(value=0.0), + learning_rate=lora_plus_scale, + ), + ) + if not rslora: self.scaling = self.lora_alpha / self.r else: @@ -122,8 +114,7 @@ def __init__( lora_alpha: int = 1, lora_dropout: float = 0.0, rslora: bool = False, - lora_plus: bool = False, - lora_B_scale: int = 16, + lora_plus_scale: float = 1.0, merge_weights: bool = True, **kwargs ): @@ -153,23 +144,16 @@ def __init__( initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu") ), ) - if lora_plus: - self.lora_B = self.create_parameter( - shape=[r, out_features], - dtype=self._dtype, - is_bias=False, - attr = paddle.ParamAttr( - initializer=paddle.nn.initializer.Constant(value=0.0), - learning_rate=lora_B_scale, - ), - ) - else: - self.lora_B = self.create_parameter( - shape=[r, self.out_features], - dtype=self._dtype, - is_bias=False, - default_initializer=nn.initializer.Constant(value=0.0), - ) + self.lora_B = self.create_parameter( + shape=[r, out_features], + dtype=self._dtype, + is_bias=False, + attr=paddle.ParamAttr( + initializer=paddle.nn.initializer.Constant(value=0.0), + learning_rate=lora_plus_scale, + ), + ) + self.lora_A.is_distributed = True self.lora_A.split_axis = 0 self.lora_B.is_distributed = False @@ -243,8 +227,7 @@ def __init__( lora_alpha: int = 1, lora_dropout: float = 0.0, rslora: bool = False, - lora_plus: bool = False, - lora_B_scale: int = 16, + lora_plus_scale: float = 1.0, merge_weights: bool = True, lora_A_weight_attr: Optional[paddle.ParamAttr] = None, **kwargs @@ -274,23 +257,16 @@ def __init__( attr=lora_A_weight_attr, ) self.lora_A.is_distributed = False - if lora_plus: - self.lora_B = self.create_parameter( - shape=[r, out_features], - dtype=self._dtype, - is_bias=False, - attr = paddle.ParamAttr( - initializer=paddle.nn.initializer.Constant(value=0.0), - 
learning_rate=lora_B_scale, - ), - ) - else: - self.lora_B = self.create_parameter( - shape=[r, self.output_size_per_partition], - dtype=self._dtype, - is_bias=False, - default_initializer=nn.initializer.Constant(value=0.0), - ) + self.lora_B = self.create_parameter( + shape=[r, out_features], + dtype=self._dtype, + is_bias=False, + attr=paddle.ParamAttr( + initializer=paddle.nn.initializer.Constant(value=0.0), + learning_rate=lora_plus_scale, + ), + ) + self.lora_B.is_distributed = True self.lora_B.split_axis = 1 if not rslora: diff --git a/paddlenlp/peft/lora/lora_model.py b/paddlenlp/peft/lora/lora_model.py index c9545caa4b67..b76ae3306a4f 100644 --- a/paddlenlp/peft/lora/lora_model.py +++ b/paddlenlp/peft/lora/lora_model.py @@ -383,8 +383,7 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora) lora_dropout=lora_config.lora_dropout, merge_weights=lora_config.merge_weights, rslora=lora_config.rslora, - lora_plus=lora_config.lora_plus, - lora_B_scale=lora_config.lora_B_scale, + lora_plus_scale=lora_config.lora_plus_scale, bias_attr=False if module.bias is None else None, ) if isinstance(module, nn.Conv2D): @@ -416,8 +415,7 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora) lora_alpha=lora_config.lora_alpha, lora_dropout=lora_config.lora_dropout, rslora=lora_config.rslora, - lora_plus=lora_config.lora_plus, - lora_B_scale=lora_config.lora_B_scale, + lora_plus_scale=lora_config.lora_plus_scale, merge_weights=lora_config.merge_weights, lora_A_weight_attr=paddle.ParamAttr( initializer=nn.initializer.KaimingUniform( @@ -444,8 +442,7 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora) lora_alpha=lora_config.lora_alpha, lora_dropout=lora_config.lora_dropout, rslora=lora_config.rslora, - lora_plus=lora_config.lora_plus, - lora_B_scale=lora_config.lora_B_scale, + lora_plus_scale=lora_config.lora_plus_scale, merge_weights=lora_config.merge_weights, ) # Lora column parallel will spilt lora A matrix From 617e8de3adf8bc4c0337162b99246c55beb7d868 Mon Sep 17 00:00:00 2001 From: wtmlon Date: Wed, 13 Mar 2024 15:54:12 +0800 Subject: [PATCH 4/9] update --- paddlenlp/peft/lora/lora_config.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/paddlenlp/peft/lora/lora_config.py b/paddlenlp/peft/lora/lora_config.py index 7f6814295627..2434d369da5e 100644 --- a/paddlenlp/peft/lora/lora_config.py +++ b/paddlenlp/peft/lora/lora_config.py @@ -73,8 +73,7 @@ class LoRAConfig: ) do_qat: bool = field(default=False, metadata={"help": "Whether the lora model would do quant-aware training"}) rslora: bool = field(default=False, metadata={"help": "Whether to use RsLoRA"}) - lora_plus: bool = field(default=False, metadata={"help": "Whether to use LoRA+"}) - lora_B_scale: int = field(default=16, metadata={"help": "Lora B scale in LoRA+"}) + lora_plus_scale: float = field(default=1.0, metadata={"help": "Lora B scale in LoRA+"}) base_model_name_or_path: Optional[str] = field( default=None, metadata={"help": "The name of the base model to use."} ) From 960b165453f877ed3edec1dfe922508f0c54cec6 Mon Sep 17 00:00:00 2001 From: wtmlon Date: Wed, 13 Mar 2024 16:46:27 +0800 Subject: [PATCH 5/9] fix bug --- paddlenlp/peft/lora/lora_layers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddlenlp/peft/lora/lora_layers.py b/paddlenlp/peft/lora/lora_layers.py index de2ac3c58302..9508d9d039b9 100644 --- a/paddlenlp/peft/lora/lora_layers.py +++ b/paddlenlp/peft/lora/lora_layers.py @@ -145,7 +145,7 @@ def 
__init__( ), ) self.lora_B = self.create_parameter( - shape=[r, out_features], + shape=[r, self.out_features], dtype=self._dtype, is_bias=False, attr=paddle.ParamAttr( @@ -258,7 +258,7 @@ def __init__( ) self.lora_A.is_distributed = False self.lora_B = self.create_parameter( - shape=[r, out_features], + shape=[r, self.output_size_per_partition], dtype=self._dtype, is_bias=False, attr=paddle.ParamAttr( From c1f346dc615f5fddcb6d926080473caf216d6fc8 Mon Sep 17 00:00:00 2001 From: wtmlon Date: Fri, 15 Mar 2024 11:34:41 +0800 Subject: [PATCH 6/9] add rslora+ ci --- tests/fixtures/llm/lora.yaml | 45 ++++++++++++++++++++++++++++++++++++ tests/llm/test_lora.py | 29 +++++++++++++++++++++++ 2 files changed, 74 insertions(+) diff --git a/tests/fixtures/llm/lora.yaml b/tests/fixtures/llm/lora.yaml index 6a2cbfa732c7..bf5db5efd979 100644 --- a/tests/fixtures/llm/lora.yaml +++ b/tests/fixtures/llm/lora.yaml @@ -41,6 +41,51 @@ lora: baichuan: model_name_or_path: __internal_testing__/tiny-fused-baichuan +rslora_plus: + base: + dataset_name_or_path: "./data" + per_device_train_batch_size: 4 + gradient_accumulation_steps: 4 + per_device_eval_batch_size: 8 + eval_accumulation_steps: 16 + num_train_epochs: 3 + learning_rate: 3e-04 + warmup_steps: 30 + logging_steps: 1 + evaluation_strategy: "epoch" + save_strategy: "epoch" + src_length: 1024 + max_length: 2048 + fp16: true + fp16_opt_level: "O2" + do_train: true + do_eval: true + disable_tqdm: true + load_best_model_at_end: true + eval_with_do_generation: false + metric_for_best_model: "accuracy" + recompute: true + save_total_limit: 1 + tensor_parallel_degree: 1 + pipeline_parallel_degree: 1 + lora: true + lora_plus_scale: 4 + rslora: true + + default: + llama: + model_name_or_path: __internal_testing__/tiny-random-llama + chatglm: + model_name_or_path: __internal_testing__/tiny-fused-chatglm + chatglm2: + model_name_or_path: __internal_testing__/tiny-fused-chatglm2 + bloom: + model_name_or_path: __internal_testing__/tiny-fused-bloom + qwen: + model_name_or_path: __internal_testing__/tiny-fused-qwen + baichuan: + model_name_or_path: __internal_testing__/tiny-fused-baichuan + inference-predict: default: mode: dynamic diff --git a/tests/llm/test_lora.py b/tests/llm/test_lora.py index 138c2ccf699a..d4bec137e8c6 100644 --- a/tests/llm/test_lora.py +++ b/tests/llm/test_lora.py @@ -79,6 +79,35 @@ def test_lora(self): self.run_predictor({"inference_model": False}) + def test_rslora_plus(self): + self.disable_static() + paddle.set_default_dtype("float32") + + lora_config = load_test_config(self.config_path, "rslora_plus", self.model_dir) + lora_config["output_dir"] = self.output_dir + lora_config["dataset_name_or_path"] = self.data_dir + + with argv_context_guard(lora_config): + from finetune_generation import main + + main() + + # merge weights + merge_lora_weights_config = { + "lora_path": lora_config["output_dir"], + "merge_lora_model_path": lora_config["output_dir"], + } + with argv_context_guard(merge_lora_weights_config): + from merge_lora_params import merge + + merge() + + # TODO(wj-Mcat): disable chatglm2 test temporarily + if self.model_dir not in ["qwen", "baichuan", "chatglm2"]: + self.run_predictor({"inference_model": True}) + + self.run_predictor({"inference_model": False}) + # @parameterized_class( # ["model_dir"], From 568c437142ae82b54fb12ed5e4ea07facd5e7be9 Mon Sep 17 00:00:00 2001 From: wtmlon Date: Mon, 18 Mar 2024 18:56:15 +0800 Subject: [PATCH 7/9] remove magic number --- llm/finetune_generation.py | 2 +- paddlenlp/peft/lora/lora_layers.py | 2 
+- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/llm/finetune_generation.py b/llm/finetune_generation.py index 6436af0b321e..25e6a8310a78 100644 --- a/llm/finetune_generation.py +++ b/llm/finetune_generation.py @@ -418,7 +418,7 @@ def neft_post_hook(module, input, output): lora_config = LoRAConfig( target_modules=target_modules, r=model_args.lora_rank, - lora_alpha=2 * model_args.lora_rank, + lora_alpha=2 * model_args.lora_rank if not model_args.rslora else 4, rslora=model_args.rslora, lora_plus_scale=model_args.lora_plus_scale, merge_weights=False, diff --git a/paddlenlp/peft/lora/lora_layers.py b/paddlenlp/peft/lora/lora_layers.py index 9508d9d039b9..bd831105244d 100644 --- a/paddlenlp/peft/lora/lora_layers.py +++ b/paddlenlp/peft/lora/lora_layers.py @@ -73,7 +73,7 @@ def __init__( if not rslora: self.scaling = self.lora_alpha / self.r else: - self.scaling = 4.0 / math.sqrt(self.r) + self.scaling = self.lora_alpha / math.sqrt(self.r) # Freezing the pre-trained weight matrix self.weight.stop_gradient = True From 5a7308242fe14ec1d18fbc1554fc799055cb4c17 Mon Sep 17 00:00:00 2001 From: wtmlon Date: Mon, 18 Mar 2024 19:00:46 +0800 Subject: [PATCH 8/9] update --- paddlenlp/peft/lora/lora_layers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddlenlp/peft/lora/lora_layers.py b/paddlenlp/peft/lora/lora_layers.py index bd831105244d..ae38f47825e4 100644 --- a/paddlenlp/peft/lora/lora_layers.py +++ b/paddlenlp/peft/lora/lora_layers.py @@ -160,7 +160,7 @@ def __init__( if not rslora: self.scaling = self.lora_alpha / self.r else: - self.scaling = 4.0 / math.sqrt(self.r) + self.scaling = self.lora_alpha / math.sqrt(self.r) # Freezing the pre-trained weight matrix self.weight.stop_gradient = True @@ -272,7 +272,7 @@ def __init__( if not rslora: self.scaling = self.lora_alpha / self.r else: - self.scaling = 4.0 / math.sqrt(self.r) + self.scaling = self.lora_alpha / math.sqrt(self.r) # Freezing the pre-trained weight matrix self.weight.stop_gradient = True From 4f64001e2899ce76ac606a04a4d717d45daa704b Mon Sep 17 00:00:00 2001 From: wtmlon Date: Tue, 19 Mar 2024 11:19:37 +0800 Subject: [PATCH 9/9] empty
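
For reference, a minimal usage sketch of how the options introduced in this series fit together, assuming the LoRAConfig/LoRAModel API shown in the diffs: rslora switches the scaling to lora_alpha / sqrt(r), and lora_plus_scale raises the learning rate of lora_B relative to lora_A through ParamAttr. The checkpoint name is the tiny testing model from the lora.yaml fixture above, and the target_modules patterns and dtype are illustrative placeholders, not a recommended configuration.

    from paddlenlp.peft import LoRAConfig, LoRAModel
    from paddlenlp.transformers import AutoModelForCausalLM

    # Placeholder checkpoint taken from the test fixture above; substitute a real base model.
    model = AutoModelForCausalLM.from_pretrained("__internal_testing__/tiny-random-llama")

    lora_config = LoRAConfig(
        target_modules=[".*q_proj.*", ".*v_proj.*"],  # illustrative patterns; use the modules of your model
        r=8,
        lora_alpha=4,          # with rslora=True the finetune script sets lora_alpha=4, so scaling = 4 / sqrt(r)
        rslora=True,           # rank-stabilized scaling: lora_alpha / sqrt(r) instead of lora_alpha / r
        lora_plus_scale=4.0,   # LoRA+: lora_B's learning rate is multiplied by this factor via ParamAttr
        merge_weights=False,
        dtype="float16",       # illustrative; match the training dtype
    )

    model = LoRAModel(model=model, lora_config=lora_config)
    model.mark_only_lora_as_trainable()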