diff --git a/paddlenlp/peft/lora/lora_layers.py b/paddlenlp/peft/lora/lora_layers.py
index a40f3fff64a8..637ef086fcbe 100644
--- a/paddlenlp/peft/lora/lora_layers.py
+++ b/paddlenlp/peft/lora/lora_layers.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import math
+from contextlib import nullcontext
 from typing import Optional
 
 import paddle
@@ -22,6 +23,7 @@
 from paddle.distributed.fleet.meta_parallel import (
     ColumnParallelLinear,
     RowParallelLinear,
+    get_rng_state_tracker,
 )
 
 from ...transformers import linear_utils
@@ -50,6 +52,10 @@
 from .lora_quick_layers import quick_lora
 
 
+def rng_ctx(is_mp: bool, in_dynamic_mode: bool):
+    return get_rng_state_tracker().rng_state() if (is_mp and in_dynamic_mode) else nullcontext()
+
+
 class LoRALinear(nn.Linear):
     # LoRA implemented in a dense layer
     def __init__(
@@ -198,14 +204,15 @@ def __init__(
         self.name = self._name
 
         # Actual trainable parameters
-        self.lora_A = self.create_parameter(
-            shape=[self.input_size_per_partition, r],
-            dtype=self._dtype,
-            is_bias=False,
-            attr=paddle.ParamAttr(
-                initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu")
-            ),
-        )
+        with rng_ctx(self.is_mp, paddle.in_dynamic_mode()):
+            self.lora_A = self.create_parameter(
+                shape=[self.input_size_per_partition, r],
+                dtype=self._dtype,
+                is_bias=False,
+                attr=paddle.ParamAttr(
+                    initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu")
+                ),
+            )
         self.lora_B = self.create_parameter(
             shape=[r, self.out_features],
             dtype=self._dtype,
@@ -345,14 +352,15 @@ def __init__(
         self.name = self._name
 
         # Actual trainable parameters
-        self.lora_A = self.create_parameter(
-            shape=[self.input_size_per_partition, r],
-            dtype=self._dtype,
-            is_bias=False,
-            attr=paddle.ParamAttr(
-                initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu")
-            ),
-        )
+        with rng_ctx(self.is_mp, paddle.in_dynamic_mode()):
+            self.lora_A = self.create_parameter(
+                shape=[self.input_size_per_partition, r],
+                dtype=self._dtype,
+                is_bias=False,
+                attr=paddle.ParamAttr(
+                    initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu")
+                ),
+            )
         self.lora_B = self.create_parameter(
             shape=[r, self.out_features],
             dtype=self._dtype,
@@ -468,15 +476,16 @@ def __init__(
             attr=lora_A_weight_attr,
         )
         self.lora_A.is_distributed = False
-        self.lora_B = self.create_parameter(
-            shape=[r, self.output_size_per_partition],
-            dtype=self._dtype,
-            is_bias=False,
-            attr=paddle.ParamAttr(
-                initializer=paddle.nn.initializer.Constant(value=0.0),
-                learning_rate=lora_plus_scale,
-            ),
-        )
+        with rng_ctx(self.is_mp, paddle.in_dynamic_mode()):
+            self.lora_B = self.create_parameter(
+                shape=[r, self.output_size_per_partition],
+                dtype=self._dtype,
+                is_bias=False,
+                attr=paddle.ParamAttr(
+                    initializer=paddle.nn.initializer.Constant(value=0.0),
+                    learning_rate=lora_plus_scale,
+                ),
+            )
 
         self.lora_B.is_distributed = True
         self.lora_B.split_axis = 1
@@ -599,15 +608,16 @@ def __init__(
         self.lora_A.is_distributed = False
         mark_as_sequence_parallel_parameter(self.lora_A)
 
-        self.lora_B = self.create_parameter(
-            shape=[r, self.output_size_per_partition],
-            dtype=self._dtype,
-            is_bias=False,
-            attr=paddle.ParamAttr(
-                initializer=paddle.nn.initializer.Constant(value=0.0),
-                learning_rate=lora_plus_scale,
-            ),
-        )
+        with rng_ctx(self.is_mp, paddle.in_dynamic_mode()):
+            self.lora_B = self.create_parameter(
+                shape=[r, self.output_size_per_partition],
+                dtype=self._dtype,
+                is_bias=False,
+                attr=paddle.ParamAttr(
+                    initializer=paddle.nn.initializer.Constant(value=0.0),
+                    learning_rate=lora_plus_scale,
+                ),
+            )
 
         self.lora_B.is_distributed = True
         self.lora_B.split_axis = 1
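
For reviewers, a minimal standalone sketch of the pattern the diff introduces (the TinyLoRA class and its constructor arguments below are illustrative, not part of PaddleNLP). rng_ctx enters the fleet RNG tracker's rng_state() context only when model parallelism is active in dynamic mode, so the randomly initialized LoRA weight is drawn from the tracker-managed model-parallel generator instead of the global one; in every other case nullcontext() makes the wrapper a no-op and single-card behavior is unchanged.

import math
from contextlib import nullcontext

import paddle
import paddle.nn as nn
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker


def rng_ctx(is_mp: bool, in_dynamic_mode: bool):
    # Same helper as in the diff: use the tracked model-parallel RNG state only
    # when running with model parallelism in dynamic mode, otherwise do nothing.
    return get_rng_state_tracker().rng_state() if (is_mp and in_dynamic_mode) else nullcontext()


class TinyLoRA(nn.Layer):
    # Hypothetical layer (illustration only): wraps create_parameter in rng_ctx,
    # mirroring how the patched LoRA layers guard their lora_A / lora_B init.
    def __init__(self, in_features: int, r: int, is_mp: bool = False):
        super().__init__()
        with rng_ctx(is_mp, paddle.in_dynamic_mode()):
            self.lora_A = self.create_parameter(
                shape=[in_features, r],
                dtype=self._dtype,
                is_bias=False,
                attr=paddle.ParamAttr(
                    initializer=nn.initializer.KaimingUniform(
                        negative_slope=math.sqrt(5), nonlinearity="leaky_relu"
                    )
                ),
            )


# Single-card usage: is_mp=False, so rng_ctx degrades to nullcontext().
layer = TinyLoRA(in_features=64, r=8)
print(layer.lora_A.shape)  # [64, 8]

Assuming the usual fleet hybrid-parallel setup, the tracker's model-parallel RNG state is seeded per tensor-parallel rank, which keeps the initialization of the partitioned LoRA weights consistent with how the base ColumnParallelLinear / RowParallelLinear weights are initialized.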