
[PEFT] Add LoRA-GA #9592

Merged: 15 commits, Dec 18, 2024
Changes from 8 commits

1 change: 1 addition & 0 deletions paddlenlp/peft/lora/lora_config.py
@@ -76,6 +76,7 @@ class LoRAConfig:
    do_qat: bool = field(default=False, metadata={"help": "Whether the lora model would do quant-aware training"})
    rslora: bool = field(default=False, metadata={"help": "Whether to use RsLoRA"})
    pissa: bool = field(default=False, metadata={"help": "Whether to use Pissa: https://arxiv.org/pdf/2404.02948.pdf"})
    loraga: bool = field(default=False, metadata={"help": "Whether to use LoRA-GA"})
    lora_plus_scale: float = field(default=1.0, metadata={"help": "Lora B scale in LoRA+"})
    base_model_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "The name of the base model to use."}
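For context, LoRA-GA (arXiv:2407.05000) initializes the adapter matrices from gradient statistics of the base weights rather than the usual zero/Gaussian init, so a checkpoint has to carry both the trained and the initial A/B. A minimal, hypothetical sketch of enabling the new flag (target_modules and ranks are illustrative, not from this PR):

```python
from paddlenlp.peft import LoRAConfig

# Illustrative values; `loraga` is the only switch added in this PR.
lora_config = LoRAConfig(
    target_modules=[".*q_proj.*", ".*v_proj.*"],
    r=8,
    lora_alpha=16,
    loraga=True,  # enable LoRA-GA-style init/save/load handling
)
```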
54 changes: 49 additions & 5 deletions paddlenlp/peft/lora/lora_model.py
@@ -163,6 +163,8 @@
        )

        self.forward = self.model.forward
        if lora_config.loraga:
            self.loraga_init_dict = {}

        logger.info("Mark only lora and trainable_module as trainable.")
        self.mark_only_lora_as_trainable()
@@ -269,7 +271,7 @@
                tp_actions if pre_tensor_parallel_split else None,
                expected_keys,
            )
-           error_msgs += _load_state_dict_into_model(lora_model.model, state_dict, "")
+           error_msgs += _load_state_dict_into_model(lora_model, state_dict, "")

Contributor: Why was lora_model.model changed to lora_model here? Was the original form incorrect?

Contributor Author: The current form should be fine. As I read the code, the two forms are equivalent when LoRA-GA is off, so the change does not affect existing behavior. Once LoRA-GA is added, I want to unify its loading logic inside LoRAModel.set_state_dict(); if we kept the old form, from_pretrained would never reach LoRAModel.set_state_dict().
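For intuition, a toy sketch of why the receiver matters (the classes and helper below are simplified stand-ins, not the real paddlenlp implementations):

```python
class BaseModel:
    def set_state_dict(self, state_dict):
        print("BaseModel.set_state_dict")

class LoRAModel:
    def __init__(self):
        self.model = BaseModel()

    def set_state_dict(self, state_dict):
        # the LoRA-GA split/offset logic would run here before delegating
        print("LoRAModel.set_state_dict")
        self.model.set_state_dict(state_dict)

def _load_state_dict_into_model(module, state_dict, prefix):
    # simplified: dispatches on whatever object it is handed
    module.set_state_dict(state_dict)
    return []

lora = LoRAModel()
_load_state_dict_into_model(lora.model, {}, "")  # old form: skips the LoRA-GA hook
_load_state_dict_into_model(lora, {}, "")        # new form: runs LoRAModel.set_state_dict
```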

            del state_dict
            gc.collect()

@@ -319,6 +321,41 @@
        warnings.filterwarnings(
            action="ignore", message=".*Skip loading for.*", category=Warning, lineno=0, append=False
        )

        model_state_dict = self.model.state_dict()
        if self.lora_config.loraga:

            def process_split_and_assign(name, concat_tensor, axis, init_dict, state_dict):
                if isinstance(concat_tensor, np.ndarray):
                    final_lora, init_lora = np.split(concat_tensor, 2, axis=axis)
                    init_lora = paddle.to_tensor(init_lora)
                else:
                    final_lora, init_lora = paddle.split(concat_tensor, 2, axis=axis)
                init_dict[name] = init_lora
                state_dict[name] = final_lora
                return final_lora, init_lora

            for name in state_dict.keys():
                if "lora_A" in name:
                    concat_lora_A = state_dict[name]
                    final_loraA, init_loraA = process_split_and_assign(
                        name, concat_lora_A, axis=1, init_dict=self.loraga_init_dict, state_dict=state_dict
                    )

                    loraB_name = name.replace("lora_A", "lora_B")
                    concat_lora_B = state_dict[loraB_name]
                    final_loraB, init_loraB = process_split_and_assign(
                        loraB_name, concat_lora_B, axis=0, init_dict=self.loraga_init_dict, state_dict=state_dict
                    )

Collaborator: Unused variables like these can simply be replaced with _.

Contributor Author: OK, I will change this here.

                    base_name = name.replace("lora_A", "weight")

                    # Reinit base model
                    offset = init_loraA.cuda() @ init_loraB.cuda()
Contributor: Why does this need .cuda() everywhere? As I read it, the returned tensors should already be on the device. Calling .cuda() directly seems likely to cause problems on other hardware.

Contributor Author: The state_dict lives on CPU, and after the split the tensors stay on CPU unless paddle.to_tensor is used. Could we uniformly replace .cuda() here with to_tensor or to(target_device)?
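A possible device-agnostic variant of the offset computation, assuming Paddle >= 2.5's Tensor.to() and paddle.device.get_device() (a sketch of the suggestion above, not the final fix):

```python
import paddle

target_device = paddle.device.get_device()  # e.g. "gpu:0", "npu:0", or "cpu"

# stand-ins for the split-out initial adapters (illustrative shapes)
init_loraA = paddle.randn([16, 4])
init_loraB = paddle.randn([4, 8])

# move to whatever device the run actually uses, then compute the offset
offset = init_loraA.to(target_device) @ init_loraB.to(target_device)
```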

                    ori_weight = model_state_dict[base_name]
                    model_state_dict[base_name].set_value(ori_weight - self.lora_config.scaling * offset)

        del model_state_dict
        gc.collect()
        self.model.set_state_dict(state_dict)
        logger.info("Load lora weight successfully")

@@ -400,8 +437,9 @@

        lora_config_to_save = LoRAConfig(**self.lora_config.to_dict())

+       trainable_state_dict = self.get_trainable_state_dict(concat_init_lora=lora_config_to_save.loraga)

        if merge_tensor_parallel and lora_config_to_save.tensor_parallel_degree > 1:
-           trainable_state_dict = self.get_trainable_state_dict()
            trainable_state_dict = self._merge_trainable_tensor_parallel(trainable_state_dict)
            if not is_main_process:
                logger.info("Saving with merge_tensor_parallel, tensor_parallel_rank > 0 don't need save")
@@ -410,7 +448,6 @@
                variant = "_".join([x for x in variant.split("_") if "tp" not in x])
            lora_config_to_save.tensor_parallel_degree = -1
        else:
-           trainable_state_dict = self.get_trainable_state_dict()
            if lora_config_to_save.tensor_parallel_degree > 1:
                if variant is None:
                    variant = weight_name_suffix()
@@ -641,12 +678,19 @@
            original_module.bias = module.bias
        setattr(parent_module, attribute_chain[-1], original_module)

-   def get_trainable_state_dict(self):
+   def get_trainable_state_dict(self, concat_init_lora=False):
        trainable_state_dict = OrderedDict()
        for name, weight in self.model.state_dict().items():
            # get lora parameter & QAT scale parameter
            if not weight.stop_gradient or "activation_quanter" in name or "weight_quanter" in name:
-               trainable_state_dict[name] = weight
+               if concat_init_lora:
+                   if "lora_A" in name:
+                       trainable_state_dict[name] = paddle.concat([weight, self.loraga_init_dict[name]], axis=1)
+                   else:
+                       trainable_state_dict[name] = paddle.concat([weight, self.loraga_init_dict[name]], axis=0)
+               else:
+                   trainable_state_dict[name] = weight

        return trainable_state_dict
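A self-contained emulation of the concat convention above, to make the axis choice concrete (names and shapes are illustrative: lora_A is [in_features, r], lora_B is [r, out_features]):

```python
from collections import OrderedDict

import paddle

loraga_init_dict = {
    "layer.lora_A": paddle.zeros([16, 4]),
    "layer.lora_B": paddle.zeros([4, 8]),
}
weights = {
    "layer.lora_A": paddle.randn([16, 4]),
    "layer.lora_B": paddle.randn([4, 8]),
}

saved = OrderedDict()
for name, w in weights.items():
    # lora_A grows along columns, lora_B along rows, so each stays splittable
    axis = 1 if "lora_A" in name else 0
    saved[name] = paddle.concat([w, loraga_init_dict[name]], axis=axis)

assert saved["layer.lora_A"].shape == [16, 8]  # [in_features, 2r]
assert saved["layer.lora_B"].shape == [8, 8]   # [2r, out_features]
```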

    def print_trainable_parameters(self) -> None: