diff --git a/llm/argument.py b/llm/argument.py
index 4a23db0b8222..64e736873ca2 100644
--- a/llm/argument.py
+++ b/llm/argument.py
@@ -126,6 +126,12 @@ class ModelArgument:
     lora: bool = field(default=False, metadata={"help": "Whether to use LoRA technique"})
     lora_path: str = field(default=None, metadata={"help": "Initialize lora state dict."})
     lora_rank: int = field(default=8, metadata={"help": "Lora attention dimension"})
+    use_quick_lora: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to use Quick LoRA. Quick LoRA only takes effect when lora_dropout is set to 0."
+        },
+    )
     rslora: bool = field(default=False, metadata={"help": "Whether to use RsLoRA"})
     lora_plus_scale: float = field(default=1.0, metadata={"help": "Lora B scale in LoRA+ technique"})
diff --git a/llm/finetune_generation.py b/llm/finetune_generation.py
index 25e6a8310a78..c3396230c0a0 100644
--- a/llm/finetune_generation.py
+++ b/llm/finetune_generation.py
@@ -112,6 +112,7 @@ def main():
         weight_double_quant=model_args.weight_double_quant,
         weight_double_quant_block_size=model_args.weight_double_quant_block_size,
     )
+
     if training_args.pipeline_parallel_degree > 1:
         if data_args.eval_with_do_generation and training_args.do_eval:
             raise ValueError("Plese set eval_with_do_generation to false in pipeline parallel mode.")
@@ -426,10 +427,12 @@ def neft_post_hook(module, input, output):
             dtype=dtype,
             do_qat=quant_args.do_qat,
             base_model_name_or_path=model_args.model_name_or_path,
+            use_quick_lora=model_args.use_quick_lora,
         )
         model = LoRAModel(model, lora_config)
     else:
         model = LoRAModel.from_pretrained(model=model, lora_path=model_args.lora_path)
+    model.print_trainable_parameters()

     def compute_metrics_do_generation(eval_preds):
diff --git a/paddlenlp/peft/lora/lora_config.py b/paddlenlp/peft/lora/lora_config.py
index 2434d369da5e..8598c33f8622 100644
--- a/paddlenlp/peft/lora/lora_config.py
+++ b/paddlenlp/peft/lora/lora_config.py
@@ -18,6 +18,7 @@
 from typing import List, Optional, Union

 from ...utils.env import LORA_CONFIG_NAME
+from ...utils.log import logger


 @dataclass
@@ -77,6 +78,20 @@ class LoRAConfig:
     base_model_name_or_path: Optional[str] = field(
         default=None, metadata={"help": "The name of the base model to use."}
     )
+    use_quick_lora: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to use Quick LoRA. Quick LoRA only takes effect when lora_dropout is set to 0."
+        },
+    )
+
+    def __post_init__(self):
+        if self.use_quick_lora and self.lora_dropout > 0:
+            logger.warning(
+                "Quick LoRA is enabled, but lora_dropout is set to a non-zero value. "
+                "We will automatically set `use_quick_lora` to `False` to avoid potential inconsistencies."
+            )
+            self.use_quick_lora = False

     @property
     def __dict__(self):
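Note: the `__post_init__` guard above means Quick LoRA silently falls back to the standard LoRA path whenever `lora_dropout` is non-zero. A minimal sketch of the intended behaviour (assuming the pre-existing `LoRAConfig` constructor arguments `target_modules`, `r`, `lora_alpha`, and `lora_dropout`; the target-module pattern is illustrative):

```python
from paddlenlp.peft import LoRAConfig

# lora_dropout > 0: __post_init__ logs a warning and resets the flag
cfg = LoRAConfig(target_modules=[".*q_proj.*"], r=8, lora_alpha=16, lora_dropout=0.1, use_quick_lora=True)
assert cfg.use_quick_lora is False

# lora_dropout == 0: the flag is kept and later propagated to every LoRA layer
cfg = LoRAConfig(target_modules=[".*q_proj.*"], r=8, lora_alpha=16, lora_dropout=0.0, use_quick_lora=True)
assert cfg.use_quick_lora is True
```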
diff --git a/paddlenlp/peft/lora/lora_layers.py b/paddlenlp/peft/lora/lora_layers.py
index c90e96aab4e2..dc64248a078b 100644
--- a/paddlenlp/peft/lora/lora_layers.py
+++ b/paddlenlp/peft/lora/lora_layers.py
@@ -25,6 +25,8 @@
     RowParallelLinear,
 )

+from .lora_quick_layers import quick_lora
+
 if "npu" in paddle.device.get_all_custom_device_type():
     from .mc2_lora_npu import MC2LoRaColumnParallelLinear, MC2LoRaRowParallelLinear
 else:
@@ -42,6 +44,7 @@ def __init__(
         lora_alpha: int = 1,
         lora_dropout: float = 0.0,
         merge_weights: bool = True,
+        use_quick_lora: bool = False,
         rslora: bool = False,
         lora_plus_scale: float = 1.0,
         **kwargs
@@ -84,6 +87,11 @@ def __init__(
         # Freezing the pre-trained weight matrix
         self.weight.stop_gradient = True

+        self._use_quick_lora = use_quick_lora and lora_dropout == 0.0
+
+    @property
+    def use_quick_lora(self):
+        return self._use_quick_lora and self.training and not self.merged

     def train(self):
         super().train()
@@ -102,9 +110,13 @@ def eval(self):
         self.merged = True

     def forward(self, input: paddle.Tensor, *args, **kwargs):
-        result = F.linear(x=input, weight=self.weight, bias=self.bias, name=self.name)
-        if not self.merged:
-            result += (self.lora_dropout(input) @ self.lora_A @ self.lora_B) * self.scaling
+        if self.use_quick_lora:
+            # Use the quick lora implementation
+            result = quick_lora(input, self.lora_A, self.lora_B, self.weight, self.bias, self.scaling)
+        else:
+            result = F.linear(x=input, weight=self.weight, bias=self.bias, name=self.name)
+            if not self.merged:
+                result += (self.lora_dropout(input) @ self.lora_A @ self.lora_B) * self.scaling
         return result

     def extra_repr(self):
@@ -123,6 +135,7 @@ def __init__(
         rslora: bool = False,
         lora_plus_scale: float = 1.0,
         merge_weights: bool = True,
+        use_quick_lora: bool = False,
         **kwargs
     ):
         RowParallelLinear.__init__(self, in_features, out_features, **kwargs)
@@ -171,6 +184,11 @@ def __init__(
         # Freezing the pre-trained weight matrix
         self.weight.stop_gradient = True

+        self._use_quick_lora = use_quick_lora and lora_dropout == 0.0
+
+    @property
+    def use_quick_lora(self):
+        return self._use_quick_lora and self.training and not self.merged

     def train(self):
         super().train()
@@ -194,33 +212,52 @@ def forward(self, x: paddle.Tensor):
         else:
             input_mp = x

-        # x @ W : [bz, in_f / ws] ===> [bz, out_f]
-        if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
-            output = MC2LoRaRowParallelLinear.apply(input_mp, self.weight, self.model_parallel_group)
-        else:
-            result_mp = F.linear(x=input_mp, weight=self.weight, name=self.name)
-
+        if self.use_quick_lora:
+            # Use the quick lora implementation
+            result_mp = quick_lora(
+                input_mp,
+                self.lora_A,
+                self.lora_B,
+                self.weight,
+                self.bias,
+                self.scaling,
+                is_row=True,
+                group=self.model_parallel_group,
+                world_size=self.world_size,
+            )
             output = mp_ops._mp_allreduce(
                 result_mp,
                 group=self.model_parallel_group,
                 use_calc_stream=True,
                 use_model_parallel=True,
             )
+        else:
+            # x @ W : [bz, in_f / ws] ===> [bz, out_f]
+            if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
+                output = MC2LoRaRowParallelLinear.apply(input_mp, self.weight, self.model_parallel_group)
+            else:
+                result_mp = F.linear(x=input_mp, weight=self.weight, name=self.name)
+                output = mp_ops._mp_allreduce(
+                    result_mp,
+                    group=self.model_parallel_group,
+                    use_calc_stream=True,
+                    use_model_parallel=True,
+                )

-        if not self.merged:
-            # x @ A: [bz, in_f/ ws] ===> [bz, r]
-            input_mp = self.lora_dropout(input_mp) @ self.lora_A
-            # all reduce to keep Lora B's gradient on different gpu consistent
-            input_dup = mp_ops._mp_allreduce(
-                input_mp,
-                group=self.model_parallel_group,
-                use_calc_stream=True,
-                use_model_parallel=True,
-            )
-            # @ B: [bz, r] ===> [bz, out_f]
-            delta_mp = (input_dup @ self.lora_B) * self.scaling
-            output += delta_mp
-        output = output + self.bias if self.bias is not None else output
+            if not self.merged:
+                # x @ A: [bz, in_f/ ws] ===> [bz, r]
+                input_mp = self.lora_dropout(input_mp) @ self.lora_A
+                # all reduce to keep Lora B's gradient on different gpu consistent
+                input_dup = mp_ops._mp_allreduce(
+                    input_mp,
+                    group=self.model_parallel_group,
+                    use_calc_stream=True,
+                    use_model_parallel=True,
+                )
+                # @ B: [bz, r] ===> [bz, out_f]
+                delta_mp = (input_dup @ self.lora_B) * self.scaling
+                output += delta_mp
+            output = output + self.bias if self.bias is not None else output
         return output

     def extra_repr(self):
@@ -240,6 +277,7 @@ def __init__(
         lora_plus_scale: float = 1.0,
         merge_weights: bool = True,
         lora_A_weight_attr: Optional[paddle.ParamAttr] = None,
+        use_quick_lora: bool = False,
         **kwargs
     ):
         ColumnParallelLinear.__init__(self, in_features, out_features, **kwargs)
@@ -286,6 +324,11 @@ def __init__(
         # Freezing the pre-trained weight matrix
         self.weight.stop_gradient = True

+        self._use_quick_lora = use_quick_lora and lora_dropout == 0.0
+
+    @property
+    def use_quick_lora(self):
+        return self._use_quick_lora and self.training and not self.merged

     def train(self):
         super().train()
@@ -304,22 +347,37 @@ def eval(self):
         self.merged = True

     def forward(self, input: paddle.Tensor):
-        if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
-            res_mp = MC2LoRaColumnParallelLinear.apply(input, self.weight, self.model_parallel_group)
-            result_mp = res_mp + self.bias
+        if self.use_quick_lora:
+            # Use the quick lora implementation
+            input_mp = mp_ops._c_identity(input, group=self.model_parallel_group) if self.is_mp else input
+            result_mp = quick_lora(
+                input_mp,
+                self.lora_A,
+                self.lora_B,
+                self.weight,
+                self.bias,
+                self.scaling,
+                is_column=True,
+                group=self.model_parallel_group,
+                world_size=self.world_size,
+            )
         else:
-            input_mp = mp_ops._c_identity(input, group=self.model_parallel_group)
-            result_mp = F.linear(x=input_mp, weight=self.weight, bias=self.bias, name=self.name)
-
-        if not self.merged:
-            input_a = self.lora_dropout(input) @ self.lora_A
             if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
-                tmp = MC2LoRaColumnParallelLinear.apply(input_a, self.lora_B, self.model_parallel_group)
-                delta_mp = tmp * self.scaling
+                res_mp = MC2LoRaColumnParallelLinear.apply(input, self.weight, self.model_parallel_group)
+                result_mp = res_mp + self.bias
             else:
-                input_a_mp = mp_ops._c_identity(input_a, group=self.model_parallel_group)
-                delta_mp = (input_a_mp @ self.lora_B) * self.scaling
-            result_mp += delta_mp
+                input_mp = mp_ops._c_identity(input, group=self.model_parallel_group)
+                result_mp = F.linear(x=input_mp, weight=self.weight, bias=self.bias, name=self.name)
+
+            if not self.merged:
+                input_a = self.lora_dropout(input) @ self.lora_A
+                if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
+                    tmp = MC2LoRaColumnParallelLinear.apply(input_a, self.lora_B, self.model_parallel_group)
+                    delta_mp = tmp * self.scaling
+                else:
+                    input_a_mp = mp_ops._c_identity(input_a, group=self.model_parallel_group)
+                    delta_mp = (input_a_mp @ self.lora_B) * self.scaling
+                result_mp += delta_mp

         if self.gather_output and self.is_mp:
             result = mp_ops._c_concat(result_mp, group=self.model_parallel_group)
diff --git a/paddlenlp/peft/lora/lora_model.py b/paddlenlp/peft/lora/lora_model.py
index b76ae3306a4f..30198dc710be 100644
--- a/paddlenlp/peft/lora/lora_model.py
+++ b/paddlenlp/peft/lora/lora_model.py
@@ -385,6 +385,7 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
                 rslora=lora_config.rslora,
                 lora_plus_scale=lora_config.lora_plus_scale,
                 bias_attr=False if module.bias is None else None,
+                use_quick_lora=lora_config.use_quick_lora,
             )
         if isinstance(module, nn.Conv2D):
             lora_module = LoRAConv2D(
@@ -422,6 +423,7 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
                     negative_slope=math.sqrt(5), nonlinearity="leaky_relu"
                 )
             ),
+            use_quick_lora=lora_config.use_quick_lora,
         )
         # Lora column parallel will spilt lora B matrix
         self.add_lora_split_mapping(module_name + ".lora_B", is_column=True)
@@ -444,6 +446,7 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
                 rslora=lora_config.rslora,
                 lora_plus_scale=lora_config.lora_plus_scale,
                 merge_weights=lora_config.merge_weights,
+                use_quick_lora=lora_config.use_quick_lora,
             )
         # Lora column parallel will spilt lora A matrix
         self.add_lora_split_mapping(module_name + ".lora_A", is_column=False)
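With the wiring above, the flag travels from `LoRAConfig` through `LoRAModel._find_and_replace_module` into each LoRA layer, so enabling it from user code is a one-line change. A hedged end-to-end sketch mirroring what `finetune_generation.py` does (the checkpoint name and target modules are illustrative, not taken from the patch):

```python
from paddlenlp.peft import LoRAConfig, LoRAModel
from paddlenlp.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/llama-7b")  # illustrative checkpoint
lora_config = LoRAConfig(
    target_modules=[".*q_proj.*", ".*v_proj.*"],  # illustrative target modules
    r=8,
    lora_alpha=16,
    lora_dropout=0.0,       # must stay 0.0, otherwise use_quick_lora is reset to False
    use_quick_lora=True,
)
model = LoRAModel(model, lora_config)
model.print_trainable_parameters()
```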
diff --git a/paddlenlp/peft/lora/lora_quick_layers.py b/paddlenlp/peft/lora/lora_quick_layers.py
new file mode 100644
index 000000000000..ab48069b7e65
--- /dev/null
+++ b/paddlenlp/peft/lora/lora_quick_layers.py
@@ -0,0 +1,223 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from paddle.autograd import PyLayer
+from paddle.distributed.communication.reduce import ReduceOp, _get_reduce_op
+from paddle.distributed.fleet.layers.mpu import mp_ops
+from paddle.framework import core
+
+__all__ = ["quick_lora"]
+
+
+def is_fused_matmul_bias_supported():
+    if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm() or paddle.is_compiled_with_xpu():
+        return hasattr(core.eager.ops.legacy, "fused_gemm_epilogue")
+    return False
+
+
+if is_fused_matmul_bias_supported():
+    linear_func = paddle.incubate.nn.functional.fused_linear
+else:
+    linear_func = paddle.nn.functional.linear
+
+
+def quick_lora(
+    input: paddle.Tensor,
+    lora_A: paddle.Tensor,
+    lora_B: paddle.Tensor,
+    weight: paddle.Tensor,
+    bias: paddle.Tensor = None,
+    scaling: float = 1.0,
+    is_column: bool = False,
+    is_row: bool = False,
+    group=None,
+    world_size: int = 1,
+):
+    r"""
+    Definition of the quick_lora function for efficient low-rank adaptation (LORA) operations
+
+    Parameters:
+        input: The input data for the LORA operation
+        lora_A: The LORA matrix A
+        lora_B: The LORA matrix B
+        weight: The weight matrix
+        bias: The bias vector (optional, defaults to None)
+        scaling: The scaling factor (optional, defaults to 1.0)
+        is_column: Flag indicating whether to perform LORA operation by column (optional, defaults to False)
+        is_row: Flag indicating whether to perform LORA operation by row (optional, defaults to False)
+        group: Group information (optional, defaults to None)
+        world_size: World size for distributed operations (optional, defaults to 1)
+
+    Returns:
+        The result of the LORA operation based on the specified parameters
+
+    """
+    assert weight.stop_gradient, "When using Quick LoRA, it is necessary that weight.stop_gradient is set to True."
+    if bias is not None:
+        assert bias.stop_gradient, "When using Quick LoRA, it is necessary that bias.stop_gradient is set to True."
+
+    input_stop_gradient = input.stop_gradient
+    if is_column:
+        # If is_column is True, apply the LORA operation by column using the ColumnQuickLora class
+        return ColumnQuickLora.apply(
+            input, lora_A, lora_B, weight, bias, scaling, group, input_stop_gradient=input_stop_gradient
+        )
+    elif is_row:
+        # If is_row is True, apply the LORA operation by row using the RowQuickLora class
+        return RowQuickLora.apply(
+            input, lora_A, lora_B, weight, bias, scaling, group, world_size, input_stop_gradient=input_stop_gradient
+        )
+    else:
+        # If neither is_column nor is_row is True, apply the regular LORA operation using the QuickLora class
+        return QuickLora.apply(input, lora_A, lora_B, weight, bias, scaling, input_stop_gradient=input_stop_gradient)
+
+
+class QuickLora(PyLayer):
+    @staticmethod
+    def forward(
+        ctx,
+        input,
+        lora_A,
+        lora_B,
+        weight,
+        bias: paddle.Tensor = None,
+        scaling: float = 1.0,
+        input_stop_gradient: bool = False,
+    ):
+        merged_weight = paddle.addmm(weight, lora_A, lora_B, beta=1.0, alpha=scaling)
+        ctx.input_stop_gradient = input_stop_gradient
+        ctx.scaling = scaling
+        ctx.save_for_backward(input, weight, lora_A, lora_B)
+        result = linear_func(input, merged_weight, bias)
+        return result
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, weight, lora_A, lora_B = ctx.saved_tensor()
+        grad_output = grad_output.flatten(0, 1)
+        input_fused = input.flatten(0, 1)
+        lora_B_input_grad = paddle.matmul(grad_output, lora_B, transpose_y=True)
+        input_grad = None
+
+        if not ctx.input_stop_gradient:
+            input_grad = paddle.addmm(
+                paddle.matmul(grad_output, weight, transpose_y=True),
+                lora_B_input_grad,
+                lora_A.T,
+                beta=1.0,
+                alpha=ctx.scaling,
+            ).reshape(input.shape)
+
+        lora_A_grad = paddle.matmul(input_fused, lora_B_input_grad, transpose_x=True) * ctx.scaling
+
+        lora_B_grad = paddle.matmul(paddle.matmul(input_fused, lora_A), grad_output, transpose_x=True) * ctx.scaling
+
+        return input_grad, lora_A_grad, lora_B_grad
+
+
+class ColumnQuickLora(PyLayer):
+    @staticmethod
+    def forward(
+        ctx, input, lora_A, lora_B, weight, bias=None, scaling=1.0, group=None, input_stop_gradient: bool = False
+    ):
+        merged_weight = paddle.addmm(weight, lora_A, lora_B, beta=1.0, alpha=scaling)
+        ctx.group = group
+        ctx.op_type = _get_reduce_op(ReduceOp.SUM, "_c_identity")
+        ctx.input_stop_gradient = input_stop_gradient
+        ctx.scaling = scaling
+        ctx.save_for_backward(input, weight, lora_A, lora_B)
+        result = linear_func(input, merged_weight, bias)
+        return result
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, weight, lora_A, lora_B = ctx.saved_tensor()
+        grad_output = grad_output.flatten(0, 1)
+        input_fused = input.flatten(0, 1)
+        lora_B_input_grad = paddle.matmul(grad_output, lora_B, transpose_y=True)
+        input_grad = None
+        if not ctx.input_stop_gradient:
+            input_grad = paddle.addmm(
+                paddle.matmul(grad_output, weight, transpose_y=True),
+                lora_B_input_grad,
+                lora_A.T,
+                beta=1.0,
+                alpha=ctx.scaling,
+            ).reshape(input.shape)
+
+        if ctx.group is not None:
+            ctx.group.process_group.all_reduce_on_calc_stream(lora_B_input_grad, ctx.op_type)
+        lora_A_grad = paddle.matmul(input_fused, lora_B_input_grad, transpose_x=True) * ctx.scaling
+
+        lora_B_grad = paddle.matmul(paddle.matmul(input_fused, lora_A), grad_output, transpose_x=True) * ctx.scaling
+
+        return input_grad, lora_A_grad, lora_B_grad
+
+
+class RowQuickLora(PyLayer):
+    @staticmethod
+    def forward(
+        ctx,
+        input,
+        lora_A,
+        lora_B,
+        weight,
+        bias=None,
+        scaling: float = 1.0,
+        group=None,
+        world_size: int = 1,
+        input_stop_gradient: bool = False,
+    ):
+        if world_size > 1 and bias is not None:
+            bias = paddle.scale(bias, 1.0 / world_size)
+        merged_weight = paddle.addmm(weight, lora_A, lora_B, beta=1.0, alpha=scaling)
+        ctx.input_stop_gradient = input_stop_gradient
+        ctx.group = group
+        ctx.scaling = scaling
+        ctx.save_for_backward(input, weight, lora_A, lora_B)
+        result = linear_func(input, merged_weight, bias)
+        return result
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, weight, lora_A, lora_B = ctx.saved_tensor()
+
+        grad_output = grad_output.flatten(0, 1)
+        input_fused = input.flatten(0, 1)
+
+        lora_B_input_grad = paddle.matmul(grad_output, lora_B, transpose_y=True)
+
+        input_grad = None
+        if not ctx.input_stop_gradient:
+            input_grad = paddle.addmm(
+                paddle.matmul(grad_output, weight, transpose_y=True),
+                lora_B_input_grad,
+                lora_A.T,
+                beta=1.0,
+                alpha=ctx.scaling,
+            ).reshape(input.shape)
+
+        lora_A_grad = paddle.matmul(input_fused, lora_B_input_grad, transpose_x=True) * ctx.scaling
+
+        x_lora_A = paddle.matmul(input_fused, lora_A)
+        if ctx.group is not None:
+            x_lora_A = mp_ops._mp_allreduce(
+                x_lora_A,
+                group=ctx.group,
+                use_calc_stream=True,
+                use_model_parallel=True,
+            )
+        lora_B_grad = paddle.matmul(x_lora_A, grad_output, transpose_x=True) * ctx.scaling
+        return input_grad, lora_A_grad, lora_B_grad
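All three `PyLayer` backward passes above follow from differentiating the single fused matmul. A short derivation (notation mine, not part of the patch): with the input flattened to 2-D as $x$, $s$ the scaling factor, $g = \partial L / \partial y$, and $y = x(W + s\,A B) + b$,

```latex
\frac{\partial L}{\partial x} = g\,W^{\top} + s\,(g B^{\top})\,A^{\top}, \qquad
\frac{\partial L}{\partial A} = s\,x^{\top}(g B^{\top}), \qquad
\frac{\partial L}{\partial B} = s\,(x A)^{\top} g
```

which is exactly what `input_grad` (the `paddle.addmm` call with `alpha=ctx.scaling`), `lora_A_grad`, and `lora_B_grad` compute; the frozen `weight` receives no gradient at all, which is why `quick_lora` asserts `weight.stop_gradient` on entry.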
diff --git a/tests/llm/test_lora.py b/tests/llm/test_lora.py
index d4bec137e8c6..bed84c39d96b 100644
--- a/tests/llm/test_lora.py
+++ b/tests/llm/test_lora.py
@@ -57,6 +57,8 @@ def test_lora(self):
         lora_config = load_test_config(self.config_path, "lora", self.model_dir)
         lora_config["output_dir"] = self.output_dir
         lora_config["dataset_name_or_path"] = self.data_dir
+        # use_quick_lora
+        lora_config["use_quick_lora"] = True
         with argv_context_guard(lora_config):
             from finetune_generation import main
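As a quick sanity check of the fused path, the output of `quick_lora` should match the reference two-matmul LoRA forward used by the non-quick branch. A minimal sketch (shapes and values are illustrative), assuming a single-card setup so neither `is_column` nor `is_row` is needed:

```python
import paddle
import paddle.nn.functional as F

from paddlenlp.peft.lora.lora_quick_layers import quick_lora

bsz, seq_len, in_f, out_f, r, scaling = 2, 4, 16, 32, 8, 2.0
x = paddle.randn([bsz, seq_len, in_f])
weight = paddle.randn([in_f, out_f])
weight.stop_gradient = True           # required by the assert inside quick_lora
lora_A = paddle.randn([in_f, r]) * 0.02
lora_B = paddle.randn([r, out_f]) * 0.02
lora_A.stop_gradient = False          # only the adapter matrices are trainable
lora_B.stop_gradient = False

# Reference LoRA forward: base linear plus the scaled low-rank update
ref = F.linear(x, weight) + (x @ lora_A @ lora_B) * scaling
# Quick LoRA forward: a single matmul against W + scaling * A @ B
out = quick_lora(x, lora_A, lora_B, weight, None, scaling)
print(paddle.allclose(ref, out, atol=1e-5))  # expect True up to floating-point error
```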