Commit 7314063
fix import
JunnYu committed Apr 23, 2024
1 parent dc5a6af commit 7314063
Showing 2 changed files with 344 additions and 1 deletion.
272 changes: 271 additions & 1 deletion paddlenlp/peft/lora/lora_layers.py
@@ -27,11 +27,44 @@

from .lora_quick_layers import quick_lora

if "npu" in paddle.device.get_all_custom_device_type():

def is_mc2_valid():
    return "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0"))


if is_mc2_valid():
    from paddle.distributed.fleet.utils.sequence_parallel_utils import (
        MC2ColumnSeqParallelLinear,
        MC2RowSeqParallelLinear,
    )

    from .mc2_lora_npu import MC2LoRaColumnParallelLinear, MC2LoRaRowParallelLinear
else:
    MC2LoRaRowParallelLinear = None
    MC2LoRaColumnParallelLinear = None
    MC2ColumnSeqParallelLinear = None
    MC2RowSeqParallelLinear = None


try:
    from paddle.distributed.fleet.utils.sequence_parallel_utils import (
        AllGatherOp,
        ColumnSequenceParallelLinear,
        ReduceScatterOp,
        RowSequenceParallelLinear,
        mark_as_sequence_parallel_parameter,
    )
except:

    class ColumnSequenceParallelLinear:
        pass

    class RowSequenceParallelLinear:
        pass

    AllGatherOp = None
    ReduceScatterOp = None
    mark_as_sequence_parallel_parameter = None

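For orientation: the MC2 fused kernels are only selected when the process sees an NPU custom device and the MC2 environment variable is set to a non-zero value; otherwise the MC2 names are bound to None (and the sequence-parallel base classes fall back to empty stubs) so that the later "if not is_mc2_valid()" branches remain importable. A minimal sketch of the environment-variable side of that gate, illustration only and not part of the diff:

import os

# With MC2 unset or "0", int(os.getenv("MC2", "0")) is falsy, so is_mc2_valid()
# returns False even on an NPU build and the generic AllGatherOp/ReduceScatterOp
# code paths are taken.
os.environ["MC2"] = "1"  # opt in to the fused MC2 kernels (NPU builds only)
print(bool(int(os.getenv("MC2", "0"))))  # True -> MC2 path becomes eligible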

class LoRALinear(nn.Linear):
@@ -298,6 +331,123 @@ def extra_repr(self):
        return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"

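Both sequence-parallel LoRA variants introduced below implement the usual low-rank update: the frozen base output is augmented by dropout(x) @ lora_A @ lora_B scaled by lora_alpha / r (or lora_alpha / sqrt(r) when rslora=True), and the same quantity is added to or subtracted from the frozen weight in eval() and train() when merge_weights is enabled. A single-device sketch of that arithmetic, using NumPy stand-ins instead of the distributed Paddle layers (illustration only, not part of the diff):

import numpy as np

rng = np.random.default_rng(0)
in_features, out_features, r, lora_alpha = 8, 6, 4, 8
scaling = lora_alpha / r  # rslora=True would use lora_alpha / np.sqrt(r)

W = rng.normal(size=(in_features, out_features))  # frozen base weight
A = rng.normal(size=(in_features, r)) * 0.01      # stand-in for lora_A (trainable)
B = rng.normal(size=(r, out_features)) * 0.01     # stand-in for lora_B (zero-initialized in the layer)
x = rng.normal(size=(2, in_features))

unmerged = x @ W + (x @ A @ B) * scaling  # training-time forward
merged = x @ (W + (A @ B) * scaling)      # weight after the eval()-time merge
assert np.allclose(unmerged, merged)      # both paths agree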

class RowSequenceParallelLoRALinear(RowSequenceParallelLinear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        rslora: bool = False,
        lora_plus_scale: float = 1.0,
        merge_weights: bool = True,
        use_quick_lora: bool = False,
        pissa: bool = False,
        **kwargs
    ):
        RowSequenceParallelLinear.__init__(self, in_features, out_features, **kwargs)
        if not isinstance(r, int) or r <= 0:
            raise ValueError("Lora rank r should be a positive integer")
        if pissa:
            raise ValueError("Pissa is not supported in model parallel by now")
        self.r = r
        self.lora_alpha = lora_alpha
        # Optional dropout
        if lora_dropout > 0.0:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        # Mark the weight as unmerged
        self.merged = False
        self.merge_weights = merge_weights

        # compatible
        self.name = self._name

        # Actual trainable parameters
        self.lora_A = self.create_parameter(
            shape=[self.input_size_per_partition, r],
            dtype=self._dtype,
            is_bias=False,
            attr=paddle.ParamAttr(
                initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu")
            ),
        )
        self.lora_B = self.create_parameter(
            shape=[r, self.out_features],
            dtype=self._dtype,
            is_bias=False,
            attr=paddle.ParamAttr(
                initializer=paddle.nn.initializer.Constant(value=0.0),
                learning_rate=lora_plus_scale,
            ),
        )

        self.lora_A.is_distributed = True
        self.lora_A.split_axis = 0
        self.lora_B.is_distributed = False
        mark_as_sequence_parallel_parameter(self.lora_B)
        if not rslora:
            self.scaling = self.lora_alpha / self.r
        else:
            self.scaling = self.lora_alpha / math.sqrt(self.r)

        # Freezing the pre-trained weight matrix
        self.weight.stop_gradient = True
        self._use_quick_lora = use_quick_lora and lora_dropout == 0.0

    @property
    def use_quick_lora(self):
        # TODO(@gexiao): support qlora
        return False  # self._use_quick_lora and self.training and not self.merged

    def train(self):
        super().train()
        if self.merge_weights and self.merged:
            # Make sure that the weights are not merged
            new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling
            self.weight.set_value(new_weight)
            self.merged = False

    def eval(self):
        super().eval()
        if self.merge_weights and not self.merged:
            # Merge the weights and mark it
            new_weight = self.weight + self.lora_A @ self.lora_B * self.scaling
            self.weight.set_value(new_weight)
            self.merged = True

    def forward(self, x: paddle.Tensor):
        if not self.input_is_parallel:
            input_mp = mp_ops._c_split(x, group=self.model_parallel_group)
        else:
            input_mp = x

        if not is_mc2_valid():
            output_parallel = self.linear(input_mp, self.weight, name=self._name)
            output_ = ReduceScatterOp.apply(output_parallel)
            result_mp = output_ + self.bias if self.bias is not None else output_
        else:
            output_ = MC2RowSeqParallelLinear.apply(input_mp, self.weight, self.model_parallel_group)
            result_mp = output_ + self.bias if self.bias is not None else output_

        if not self.merged:
            input_mp = self.lora_dropout(input_mp)
            if not is_mc2_valid():
                input_mp = input_mp @ self.lora_A
                input_mp = ReduceScatterOp.apply(input_mp)
            else:
                input_mp = MC2RowSeqParallelLinear.apply(input_mp, self.lora_A, self.model_parallel_group)
            delta_mp = (input_mp @ self.lora_B) * self.scaling
            result_mp += delta_mp
        return result_mp

    def extra_repr(self):
        name = f", name={self.name}" if self.name else ""
        return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"

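The row-parallel forward above leans on a block-matrix identity: when the input features and lora_A are sharded the same way across ranks, the per-rank partial products input_mp @ lora_A sum to the full x @ lora_A, which is exactly what ReduceScatterOp (or the fused MC2 kernel) accumulates before the replicated lora_B is applied locally. A single-process sketch of that identity, with the tensor-parallel degree simulated by NumPy splits (names are illustrative, not part of the diff):

import numpy as np

rng = np.random.default_rng(1)
tp_degree, seq_len, in_features, r = 2, 4, 8, 4

x = rng.normal(size=(seq_len, in_features))  # full activations
lora_A = rng.normal(size=(in_features, r))   # row-sharded in the real layer (split_axis=0)

# Each "rank" holds a slice of the input features and the matching rows of lora_A.
x_shards = np.split(x, tp_degree, axis=1)
A_shards = np.split(lora_A, tp_degree, axis=0)

# Partial products per rank; their sum is what the reduce-scatter accumulates.
partial = sum(x_mp @ A_mp for x_mp, A_mp in zip(x_shards, A_shards))
assert np.allclose(partial, x @ lora_A)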

class ColumnParallelLoRALinear(ColumnParallelLinear):
    def __init__(
        self,
@@ -428,6 +578,126 @@ def extra_repr(self):
        return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"


class ColumnSequenceParallelLoRALinear(ColumnSequenceParallelLinear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        rslora: bool = False,
        lora_plus_scale: float = 1.0,
        merge_weights: bool = True,
        lora_A_weight_attr: Optional[paddle.ParamAttr] = None,
        use_quick_lora: bool = False,
        pissa: bool = False,
        **kwargs
    ):
        ColumnSequenceParallelLinear.__init__(self, in_features, out_features, **kwargs)
        if not isinstance(r, int) or r <= 0:
            raise ValueError("Lora rank r should be a positive integer")
        if pissa:
            raise ValueError("Pissa is not supported in model parallel by now")
        self.r = r
        self.lora_alpha = lora_alpha
        # Optional dropout
        if lora_dropout > 0.0:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        # Mark the weight as unmerged
        self.merged = False
        self.merge_weights = merge_weights

        # compatible
        self.name = self._name

        # Actual trainable parameters
        self.lora_A = self.create_parameter(
            shape=[in_features, r],
            dtype=self._dtype,
            is_bias=False,
            attr=lora_A_weight_attr,
        )
        self.lora_A.is_distributed = False
        mark_as_sequence_parallel_parameter(self.lora_A)

        self.lora_B = self.create_parameter(
            shape=[r, self.output_size_per_partition],
            dtype=self._dtype,
            is_bias=False,
            attr=paddle.ParamAttr(
                initializer=paddle.nn.initializer.Constant(value=0.0),
                learning_rate=lora_plus_scale,
            ),
        )

        self.lora_B.is_distributed = True
        self.lora_B.split_axis = 1
        if not rslora:
            self.scaling = self.lora_alpha / self.r
        else:
            self.scaling = self.lora_alpha / math.sqrt(self.r)

        # Freezing the pre-trained weight matrix
        self.weight.stop_gradient = True
        self._use_quick_lora = use_quick_lora and lora_dropout == 0.0

    @property
    def use_quick_lora(self):
        # TODO(@gexiao): support qlora
        return False  # self._use_quick_lora and self.training and not self.merged

    def train(self):
        super().train()
        if self.merge_weights and self.merged:
            # Make sure that the weights are not merged
            new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling
            self.weight.set_value(new_weight)
            self.merged = False

    def eval(self):
        super().eval()
        if self.merge_weights and not self.merged:
            # Merge the weights and mark it
            new_weight = self.weight + self.lora_A @ self.lora_B * self.scaling
            self.weight.set_value(new_weight)
            self.merged = True

    def forward(self, x: paddle.Tensor):
        if not is_mc2_valid():
            if self.is_mp:
                input_parallel = AllGatherOp.apply(x)
            else:
                input_parallel = x
            result_mp = self.linear(input_parallel, self.weight, self.bias, name=self._name)
        else:
            result_mp = MC2ColumnSeqParallelLinear.apply(x, self.weight, self.model_parallel_group)
            if self.bias is not None:
                result_mp += self.bias

        if not self.merged:
            input_a = self.lora_dropout(x) @ self.lora_A
            if not is_mc2_valid():
                input_a = AllGatherOp.apply(input_a)
                delta_mp = (input_a @ self.lora_B) * self.scaling
            else:
                input_a = MC2ColumnSeqParallelLinear.apply(input_a, self.lora_B, self.model_parallel_group)
                delta_mp = input_a * self.scaling
            result_mp += delta_mp

        if self.gather_output and self.is_mp:
            result = mp_ops._c_concat(result_mp, group=self.model_parallel_group)
        else:
            result = result_mp
        return result

    def extra_repr(self):
        name = f", name={self.name}" if self.name else ""
        return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"

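In the column-parallel variant the sharding is reversed: lora_A is replicated and marked sequence-parallel while lora_B is split along its output columns (split_axis=1), so each rank computes its sequence shard of dropout(x) @ lora_A, all-gathers those low-rank activations, and then applies its own slice of lora_B to produce the matching slice of the LoRA delta. A NumPy sketch of that data flow on one process (the rank loop and names are illustrative, not part of the diff):

import numpy as np

rng = np.random.default_rng(2)
tp_degree, seq_len, in_features, out_features, r = 2, 4, 8, 6, 4
scaling = 1.0  # lora_alpha / r in the real layer

x = rng.normal(size=(seq_len, in_features))  # full sequence
lora_A = rng.normal(size=(in_features, r))   # replicated across ranks
lora_B = rng.normal(size=(r, out_features))  # column-sharded in the real layer

x_seq_shards = np.split(x, tp_degree, axis=0)       # sequence-parallel input shards
B_col_shards = np.split(lora_B, tp_degree, axis=1)  # each rank's output columns

# Per-rank low-rank activations, then the "all-gather" along the sequence axis.
gathered = np.concatenate([x_s @ lora_A for x_s in x_seq_shards], axis=0)

# Each rank multiplies by its lora_B slice; concatenating the slices (gather_output)
# recovers the dense delta.
delta = np.concatenate([(gathered @ B_c) * scaling for B_c in B_col_shards], axis=1)
assert np.allclose(delta, (x @ lora_A @ lora_B) * scaling)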

class LoRAMergedLinear(nn.Linear):
    # LoRA implemented in a dense layer with merged linear weights for q, k, v
    def __init__(
