diff --git a/paddlenlp/peft/lora/lora_layers.py b/paddlenlp/peft/lora/lora_layers.py index e0c79c47a87a..66a0d0c0f520 100644 --- a/paddlenlp/peft/lora/lora_layers.py +++ b/paddlenlp/peft/lora/lora_layers.py @@ -27,11 +27,44 @@ from .lora_quick_layers import quick_lora -if "npu" in paddle.device.get_all_custom_device_type(): + +def is_mc2_valid(): + return "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")) + + +if is_mc2_valid(): + from paddle.distributed.fleet.utils.sequence_parallel_utils import ( + MC2ColumnSeqParallelLinear, + MC2RowSeqParallelLinear, + ) + from .mc2_lora_npu import MC2LoRaColumnParallelLinear, MC2LoRaRowParallelLinear else: MC2LoRaRowParallelLinear = None MC2LoRaColumnParallelLinear = None + MC2ColumnSeqParallelLinear = None + MC2RowSeqParallelLinear = None + + +try: + from paddle.distributed.fleet.utils.sequence_parallel_utils import ( + AllGatherOp, + ColumnSequenceParallelLinear, + ReduceScatterOp, + RowSequenceParallelLinear, + mark_as_sequence_parallel_parameter, + ) +except: + + class ColumnSequenceParallelLinear: + pass + + class RowSequenceParallelLinear: + pass + + AllGatherOp = None + ReduceScatterOp = None + mark_as_sequence_parallel_parameter = None class LoRALinear(nn.Linear): @@ -298,6 +331,123 @@ def extra_repr(self): return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}" +class RowSequenceParallelLoRALinear(RowSequenceParallelLinear): + def __init__( + self, + in_features: int, + out_features: int, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + rslora: bool = False, + lora_plus_scale: float = 1.0, + merge_weights: bool = True, + use_quick_lora: bool = False, + pissa: bool = False, + **kwargs + ): + RowSequenceParallelLinear.__init__(self, in_features, out_features, **kwargs) + if not isinstance(r, int) or r <= 0: + raise ValueError("Lora rank r should be a positive integer") + if pissa: + raise ValueError("Pissa is not supported in model parallel by now") + self.r = r + self.lora_alpha = lora_alpha + # Optional dropout + if lora_dropout > 0.0: + self.lora_dropout = nn.Dropout(p=lora_dropout) + else: + self.lora_dropout = lambda x: x + # Mark the weight as unmerged + self.merged = False + self.merge_weights = merge_weights + + # compatible + self.name = self._name + + # Actual trainable parameters + self.lora_A = self.create_parameter( + shape=[self.input_size_per_partition, r], + dtype=self._dtype, + is_bias=False, + attr=paddle.ParamAttr( + initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu") + ), + ) + self.lora_B = self.create_parameter( + shape=[r, self.out_features], + dtype=self._dtype, + is_bias=False, + attr=paddle.ParamAttr( + initializer=paddle.nn.initializer.Constant(value=0.0), + learning_rate=lora_plus_scale, + ), + ) + + self.lora_A.is_distributed = True + self.lora_A.split_axis = 0 + self.lora_B.is_distributed = False + mark_as_sequence_parallel_parameter(self.lora_B) + if not rslora: + self.scaling = self.lora_alpha / self.r + else: + self.scaling = self.lora_alpha / math.sqrt(self.r) + + # Freezing the pre-trained weight matrix + self.weight.stop_gradient = True + self._use_quick_lora = use_quick_lora and lora_dropout == 0.0 + + @property + def use_quick_lora(self): + # TODO(@gexiao): support qlora + return False # self._use_quick_lora and self.training and not self.merged + + def train(self): + super().train() + if self.merge_weights and self.merged: + # Make sure that the weights are not 
merged + new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling + self.weight.set_value(new_weight) + self.merged = False + + def eval(self): + super().eval() + if self.merge_weights and not self.merged: + # Merge the weights and mark it + new_weight = self.weight + self.lora_A @ self.lora_B * self.scaling + self.weight.set_value(new_weight) + self.merged = True + + def forward(self, x: paddle.Tensor): + if not self.input_is_parallel: + input_mp = mp_ops._c_split(x, group=self.model_parallel_group) + else: + input_mp = x + + if not is_mc2_valid(): + output_parallel = self.linear(input_mp, self.weight, name=self._name) + output_ = ReduceScatterOp.apply(output_parallel) + result_mp = output_ + self.bias if self.bias is not None else output_ + else: + output_ = MC2RowSeqParallelLinear.apply(input_mp, self.weight, self.model_parallel_group) + result_mp = output_ + self.bias if self.bias is not None else output_ + + if not self.merged: + input_mp = self.lora_dropout(input_mp) + if not is_mc2_valid(): + input_mp = input_mp @ self.lora_A + input_mp = ReduceScatterOp.apply(input_mp) + else: + input_mp = MC2RowSeqParallelLinear.apply(input_mp, self.lora_A, self.model_parallel_group) + delta_mp = (input_mp @ self.lora_B) * self.scaling + result_mp += delta_mp + return result_mp + + def extra_repr(self): + name = f", name={self.name}" if self.name else "" + return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}" + + class ColumnParallelLoRALinear(ColumnParallelLinear): def __init__( self, @@ -428,6 +578,126 @@ def extra_repr(self): return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}" +class ColumnSequenceParallelLoRALinear(ColumnSequenceParallelLinear): + def __init__( + self, + in_features: int, + out_features: int, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + rslora: bool = False, + lora_plus_scale: float = 1.0, + merge_weights: bool = True, + lora_A_weight_attr: Optional[paddle.ParamAttr] = None, + use_quick_lora: bool = False, + pissa: bool = False, + **kwargs + ): + ColumnSequenceParallelLinear.__init__(self, in_features, out_features, **kwargs) + if not isinstance(r, int) or r <= 0: + raise ValueError("Lora rank r should be a positive integer") + if pissa: + raise ValueError("Pissa is not supported in model parallel by now") + self.r = r + self.lora_alpha = lora_alpha + # Optional dropout + if lora_dropout > 0.0: + self.lora_dropout = nn.Dropout(p=lora_dropout) + else: + self.lora_dropout = lambda x: x + # Mark the weight as unmerged + self.merged = False + self.merge_weights = merge_weights + + # compatible + self.name = self._name + + # Actual trainable parameters + self.lora_A = self.create_parameter( + shape=[in_features, r], + dtype=self._dtype, + is_bias=False, + attr=lora_A_weight_attr, + ) + self.lora_A.is_distributed = False + mark_as_sequence_parallel_parameter(self.lora_A) + + self.lora_B = self.create_parameter( + shape=[r, self.output_size_per_partition], + dtype=self._dtype, + is_bias=False, + attr=paddle.ParamAttr( + initializer=paddle.nn.initializer.Constant(value=0.0), + learning_rate=lora_plus_scale, + ), + ) + + self.lora_B.is_distributed = True + self.lora_B.split_axis = 1 + if not rslora: + self.scaling = self.lora_alpha / self.r + else: + self.scaling = self.lora_alpha / math.sqrt(self.r) + + # Freezing the pre-trained weight matrix + self.weight.stop_gradient = True + self._use_quick_lora = use_quick_lora and lora_dropout == 0.0 + + 
@property + def use_quick_lora(self): + # TODO(@gexiao): support qlora + return False # self._use_quick_lora and self.training and not self.merged + + def train(self): + super().train() + if self.merge_weights and self.merged: + # Make sure that the weights are not merged + new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling + self.weight.set_value(new_weight) + self.merged = False + + def eval(self): + super().eval() + if self.merge_weights and not self.merged: + # Merge the weights and mark it + new_weight = self.weight + self.lora_A @ self.lora_B * self.scaling + self.weight.set_value(new_weight) + self.merged = True + + def forward(self, x: paddle.Tensor): + if not is_mc2_valid(): + if self.is_mp: + input_parallel = AllGatherOp.apply(x) + else: + input_parallel = x + result_mp = self.linear(input_parallel, self.weight, self.bias, name=self._name) + else: + result_mp = MC2ColumnSeqParallelLinear.apply(x, self.weight, self.model_parallel_group) + if self.bias is not None: + result_mp += self.bias + + if not self.merged: + input_a = self.lora_dropout(x) @ self.lora_A + if not is_mc2_valid(): + input_a = AllGatherOp.apply(input_a) + delta_mp = (input_a @ self.lora_B) * self.scaling + else: + input_a = MC2ColumnSeqParallelLinear.apply(input_a, self.lora_B, self.model_parallel_group) + delta_mp = input_a * self.scaling + result_mp += delta_mp + + if self.gather_output and self.is_mp: + result = mp_ops._c_concat(result_mp, group=self.model_parallel_group) + else: + result = result_mp + return result + + def extra_repr(self): + name = f", name={self.name}" if self.name else "" + return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}" + + class LoRAMergedLinear(nn.Linear): # LoRA implemented in a dense layer with merged linear weights for q, k, v def __init__( diff --git a/paddlenlp/peft/lora/lora_model.py b/paddlenlp/peft/lora/lora_model.py index 1bbd0284823c..57d3bb3f2205 100644 --- a/paddlenlp/peft/lora/lora_model.py +++ b/paddlenlp/peft/lora/lora_model.py @@ -48,10 +48,12 @@ from .lora_layers import ( ColumnParallelLoRALinear, ColumnParallelLoRAMergedLinear, + ColumnSequenceParallelLoRALinear, LoRAConv2D, LoRALinear, LoRAMergedLinear, RowParallelLoRALinear, + RowSequenceParallelLoRALinear, ) try: @@ -73,6 +75,19 @@ ColumnParallelQuantizationLoRALinear = None RowParallelQuantizationLoRALinear = None +try: + from paddle.distributed.fleet.utils.sequence_parallel_utils import ( + ColumnSequenceParallelLinear, + RowSequenceParallelLinear, + ) +except: + + class ColumnSequenceParallelLinear: + pass + + class RowSequenceParallelLinear: + pass + class LoRAModel(nn.Layer): # TODO:lugimzzz support restore in following PR @@ -454,6 +469,60 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora) # Lora column parallel will spilt lora A matrix self.add_lora_split_mapping(module_name + ".lora_A", is_column=False) + # for lora qat + if self.lora_config.do_qat: + self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=False) + self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False) + self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False) + elif isinstance(module, ColumnSequenceParallelLinear): + # recover the original output_features + output_features = module.weight.shape[1] * module.world_size + lora_module = ColumnSequenceParallelLoRALinear( + in_features=module.weight.shape[0], + out_features=output_features, + 
gather_output=module.gather_output,
+                has_bias=module.bias is not None,
+                r=lora_config.r,
+                lora_alpha=lora_config.lora_alpha,
+                lora_dropout=lora_config.lora_dropout,
+                rslora=lora_config.rslora,
+                lora_plus_scale=lora_config.lora_plus_scale,
+                pissa=lora_config.pissa,
+                merge_weights=lora_config.merge_weights,
+                lora_A_weight_attr=paddle.ParamAttr(
+                    initializer=nn.initializer.KaimingUniform(
+                        negative_slope=math.sqrt(5), nonlinearity="leaky_relu"
+                    )
+                ),
+                use_quick_lora=lora_config.use_quick_lora,
+            )
+            # Lora column parallel will split lora B matrix
+            self.add_lora_split_mapping(module_name + ".lora_B", is_column=True)
+
+            # for lora qat
+            if self.lora_config.do_qat:
+                self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=True)
+                self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False)
+                self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False)
+        elif isinstance(module, RowSequenceParallelLinear):
+            # recover the original in_features
+            lora_module = RowSequenceParallelLoRALinear(
+                in_features=module.weight.shape[0] * module.world_size,
+                out_features=module.weight.shape[1],
+                has_bias=module.bias is not None,
+                input_is_parallel=module.input_is_parallel,
+                r=lora_config.r,
+                lora_alpha=lora_config.lora_alpha,
+                lora_dropout=lora_config.lora_dropout,
+                rslora=lora_config.rslora,
+                lora_plus_scale=lora_config.lora_plus_scale,
+                pissa=lora_config.pissa,
+                merge_weights=lora_config.merge_weights,
+                use_quick_lora=lora_config.use_quick_lora,
+            )
+            # Lora row parallel will split lora A matrix
+            self.add_lora_split_mapping(module_name + ".lora_A", is_column=False)
+
             # for lora qat
             if self.lora_config.do_qat:
                 self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=False)
@@ -597,6 +666,8 @@ def mark_only_lora_as_trainable(self) -> None:
                 or isinstance(layer, LoRAConv2D)
                 or isinstance(layer, ColumnParallelLoRALinear)
                 or isinstance(layer, RowParallelLoRALinear)
+                or isinstance(layer, ColumnSequenceParallelLoRALinear)
+                or isinstance(layer, RowSequenceParallelLoRALinear)
                 or isinstance(layer, LoRAMergedLinear)
                 or isinstance(layer, ColumnParallelLoRAMergedLinear)
                 or (QuantizationLoRALinear is not None and isinstance(layer, QuantizationLoRALinear))
@@ -684,9 +755,11 @@ def restore_original_model(self):
                    self._find_and_restore_module(layer_name)
                elif (
                    isinstance(layer, ColumnParallelLoRALinear)
+                    or isinstance(layer, ColumnSequenceParallelLoRALinear)
                    or isinstance(layer, LoRAConv2D)
                    or isinstance(layer, ColumnParallelLoRAMergedLinear)
                    or isinstance(layer, RowParallelLoRALinear)
+                    or isinstance(layer, RowSequenceParallelLoRALinear)
                    or (QuantizationLoRALinear is not None and isinstance(layer, QuantizationLoRALinear))
                    or (
                        ColumnParallelQuantizationLoRALinear is not None
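
Note on the `train()`/`eval()` merge logic in both new layers: `eval()` folds the low-rank update into the frozen weight and `train()` subtracts it back out, so the merged and unmerged forward passes are numerically equivalent up to floating-point error. A minimal NumPy sketch of that identity (shapes, rank, and the fake "training" step are made up for illustration; it is independent of Paddle and of any parallelism):

```python
import numpy as np

rng = np.random.default_rng(0)
in_features, out_features, r, lora_alpha = 16, 8, 4, 8

weight = rng.standard_normal((in_features, out_features))  # frozen base weight, laid out [in, out] as in the PR
lora_A = rng.standard_normal((in_features, r))              # trainable, Kaiming-initialized in the PR
lora_B = rng.standard_normal((r, out_features)) * 0.01      # zero-initialized in the PR; pretend some training happened
scaling = lora_alpha / r                                     # or lora_alpha / sqrt(r) when rslora=True

x = rng.standard_normal((2, in_features))

# Unmerged path (training): base matmul plus the low-rank delta.
unmerged = x @ weight + (x @ lora_A @ lora_B) * scaling

# eval(): fold the delta into the weight; the forward is then a single matmul.
merged_weight = weight + lora_A @ lora_B * scaling
merged = x @ merged_weight

# train() after eval(): subtract the delta to recover the original weight.
restored_weight = merged_weight - lora_A @ lora_B * scaling

assert np.allclose(unmerged, merged)
assert np.allclose(restored_weight, weight)
```

This is the same identity that lets `merge_weights=True` serve a single merged matmul at inference without changing outputs.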
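On the new `_find_and_replace_module` branches: the full layer dimensions are recovered from the locally sharded weight (column-parallel weights are sharded along the output axis, row-parallel weights along the input axis), and the LoRA matrices inherit the same sharding, which is what the `add_lora_split_mapping` calls record. A small bookkeeping sketch, assuming `world_size` is the model-parallel group size; the 4096/11008 sizes and variable names are made up for illustration and are not part of the PR:

```python
world_size = 4  # hypothetical tensor-parallel degree (module.world_size in the PR)
r = 8

# ColumnSequenceParallelLinear: local weight is [in_features, out_features // world_size]
col_shard_shape = (4096, 11008 // world_size)
col_in_features = col_shard_shape[0]
col_out_features = col_shard_shape[1] * world_size   # "recover the original output_features"

# RowSequenceParallelLinear: local weight is [in_features // world_size, out_features]
row_shard_shape = (11008 // world_size, 4096)
row_in_features = row_shard_shape[0] * world_size    # "recover the original in_features"
row_out_features = row_shard_shape[1]

# The LoRA factors follow the base weight's sharding:
#   column branch: lora_B is sharded along axis 1 -> add_lora_split_mapping(".lora_B", is_column=True)
#   row branch:    lora_A is sharded along axis 0 -> add_lora_split_mapping(".lora_A", is_column=False)
col_lora_A_shape = (col_in_features, r)                  # replicated; marked sequence-parallel in the PR
col_lora_B_shape = (r, col_out_features // world_size)   # sharded like the base output axis
row_lora_A_shape = (row_in_features // world_size, r)    # sharded like the base input axis
row_lora_B_shape = (r, row_out_features)                 # replicated; marked sequence-parallel in the PR

assert col_lora_B_shape[1] * world_size == col_out_features
assert row_lora_A_shape[0] * world_size == row_in_features
```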