[PEFT]Add LoRA-GA #9592
@@ -163,6 +163,8 @@
         )
 
         self.forward = self.model.forward
+        if lora_config.loraga:
+            self.loraga_init_dict = {}
 
         logger.info("Mark only lora and trainable_module as trainable.")
         self.mark_only_lora_as_trainable()
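As a reading aid, here is a sketch of what this cache is assumed to hold after a LoRA-GA checkpoint has been loaded (the layer name and shapes are made up for illustration): one entry per LoRA parameter, mapping the parameter name to the initial half that set_state_dict splits off below, so that get_trainable_state_dict can concatenate it back when saving.

```python
# Hypothetical contents of self.loraga_init_dict after loading a LoRA-GA
# checkpoint; the layer name and shapes are illustrative, not from the PR.
import paddle

loraga_init_dict = {
    "llama.layers.0.self_attn.q_proj.lora_A": paddle.zeros([4096, 8]),  # init half of A, [in_features, r]
    "llama.layers.0.self_attn.q_proj.lora_B": paddle.zeros([8, 4096]),  # init half of B, [r, out_features]
}
```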
@@ -269,7 +271,7 @@
                     tp_actions if pre_tensor_parallel_split else None,
                     expected_keys,
                 )
-                error_msgs += _load_state_dict_into_model(lora_model.model, state_dict, "")
+                error_msgs += _load_state_dict_into_model(lora_model, state_dict, "")
                 del state_dict
                 gc.collect()
 
@@ -319,6 +321,41 @@
         warnings.filterwarnings(
             action="ignore", message=".*Skip loading for.*", category=Warning, lineno=0, append=False
         )
 
+        model_state_dict = self.model.state_dict()
+        if self.lora_config.loraga:
+
+            def process_split_and_assign(name, concat_tensor, axis, init_dict, state_dict):
+                if isinstance(concat_tensor, np.ndarray):
+                    final_lora, init_lora = np.split(concat_tensor, 2, axis=axis)
+                    init_lora = paddle.to_tensor(init_lora)
+                else:
+                    final_lora, init_lora = paddle.split(concat_tensor, 2, axis=axis)
+                init_dict[name] = init_lora
+                state_dict[name] = final_lora
+                return final_lora, init_lora
+
+            for name in state_dict.keys():
+                if "lora_A" in name:
+                    concat_lora_A = state_dict[name]
+                    final_loraA, init_loraA = process_split_and_assign(
+                        name, concat_lora_A, axis=1, init_dict=self.loraga_init_dict, state_dict=state_dict
+                    )
+
+                    loraB_name = name.replace("lora_A", "lora_B")
+                    concat_lora_B = state_dict[loraB_name]
+                    final_loraB, init_loraB = process_split_and_assign(
+                        loraB_name, concat_lora_B, axis=0, init_dict=self.loraga_init_dict, state_dict=state_dict
+                    )

Review thread on the lora_B call above:
  Reviewer: Variables that are never used, like these, can simply be replaced with `_`.
  Reply: OK, I will change this here.

+
+                    base_name = name.replace("lora_A", "weight")
+
+                    # Reinit base model
+                    offset = init_loraA.cuda() @ init_loraB.cuda()

Review thread on the .cuda() calls above:
  Reviewer: Why does all of this need .cuda()? I thought the returned tensors were already on the device. Hard-coding .cuda() could be a problem on other hardware backends.
  Reply: The state_dict is on CPU, and after the split the halves stay on CPU unless paddle.to_tensor is used. Could we replace the .cuda() calls here with to_tensor or to(target_device) across the board?

+
+                    ori_weight = model_state_dict[base_name]
+                    model_state_dict[base_name].set_value(ori_weight - self.lora_config.scaling * offset)
+        del model_state_dict
+        gc.collect()
         self.model.set_state_dict(state_dict)
         logger.info("Load lora weight successfully")
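On the device question raised in the review thread, one device-agnostic variant would be to place the split halves wherever the base weight already lives instead of hard-coding CUDA. This is only a sketch of that suggestion: reinit_base_weight is a hypothetical helper, the variable names mirror the diff, and which placement API the codebase should standardize on is exactly what the thread leaves open.

```python
# Sketch of a device-agnostic alternative to .cuda(); ori_weight, init_loraA,
# init_loraB and scaling mirror the names used in the diff above.
import paddle

def reinit_base_weight(ori_weight, init_loraA, init_loraB, scaling):
    place = ori_weight.place                                 # device the base weight is already on
    init_loraA = paddle.to_tensor(init_loraA, place=place)   # move the CPU half to that device
    init_loraB = paddle.to_tensor(init_loraB, place=place)
    offset = init_loraA @ init_loraB
    ori_weight.set_value(ori_weight - scaling * offset)
```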
@@ -400,8 +437,9 @@
 
         lora_config_to_save = LoRAConfig(**self.lora_config.to_dict())
 
+        trainable_state_dict = self.get_trainable_state_dict(concat_init_lora=lora_config_to_save.loraga)
+
         if merge_tensor_parallel and lora_config_to_save.tensor_parallel_degree > 1:
-            trainable_state_dict = self.get_trainable_state_dict()
             trainable_state_dict = self._merge_trainable_tensor_parallel(trainable_state_dict)
             if not is_main_process:
                 logger.info("Saving with merge_tensor_parallel, tensor_parallel_rank > 0 don't need save")
@@ -410,7 +448,6 @@
                 variant = "_".join([x for x in variant.split("_") if "tp" not in x])
             lora_config_to_save.tensor_parallel_degree = -1
         else:
-            trainable_state_dict = self.get_trainable_state_dict()
             if lora_config_to_save.tensor_parallel_degree > 1:
                 if variant is None:
                     variant = weight_name_suffix()
@@ -641,12 +678,19 @@
                 original_module.bias = module.bias
         setattr(parent_module, attribute_chain[-1], original_module)
 
-    def get_trainable_state_dict(self):
+    def get_trainable_state_dict(self, concat_init_lora=False):
         trainable_state_dict = OrderedDict()
         for name, weight in self.model.state_dict().items():
             # get lora parameter & QAT scale parameter
             if not weight.stop_gradient or "activation_quanter" in name or "weight_quanter" in name:
-                trainable_state_dict[name] = weight
+                if concat_init_lora:
+                    if "lora_A" in name:
+                        trainable_state_dict[name] = paddle.concat([weight, self.loraga_init_dict[name]], axis=1)
+                    else:
+                        trainable_state_dict[name] = paddle.concat([weight, self.loraga_init_dict[name]], axis=0)
+                else:
+                    trainable_state_dict[name] = weight
 
         return trainable_state_dict
 
     def print_trainable_parameters(self) -> None:
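A quick way to see why lora_A is concatenated along axis=1 and lora_B along axis=0: with the assumed Paddle layout (lora_A is [in_features, r], lora_B is [r, out_features]), concatenating along the rank dimension produces valid [in_features, 2r] and [2r, out_features] tensors, and the two-way split in set_state_dict recovers exactly the halves saved here. An illustrative round trip, with made-up shapes:

```python
# Round trip between get_trainable_state_dict (concat at save) and
# set_state_dict (split at load); shapes are assumptions, not from the PR.
import paddle

in_f, out_f, r = 8, 6, 2
lora_A_trained = paddle.randn([in_f, r])   # rank dim is axis 1
lora_B_trained = paddle.randn([r, out_f])  # rank dim is axis 0
lora_A_init = paddle.randn([in_f, r])
lora_B_init = paddle.randn([r, out_f])

# Save side: [trained | init] along the rank dimension.
saved_A = paddle.concat([lora_A_trained, lora_A_init], axis=1)  # [in_f, 2r]
saved_B = paddle.concat([lora_B_trained, lora_B_init], axis=0)  # [2r, out_f]

# Load side: split back into the two halves.
restored_A, restored_A_init = paddle.split(saved_A, 2, axis=1)
restored_B, restored_B_init = paddle.split(saved_B, 2, axis=0)

assert bool(paddle.allclose(restored_A, lora_A_trained))
assert bool(paddle.allclose(restored_B_init, lora_B_init))
```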
Review thread on changing `_load_state_dict_into_model(lora_model.model, ...)` to `_load_state_dict_into_model(lora_model, ...)`:
  Reviewer: Why was lora_model.model changed to lora_model here? Was the original code wrong?
  Reply: The current code is not wrong. As I understand it, without LoRA-GA the two forms are equivalent, so the change does not affect existing behavior. With LoRA-GA I want to put its loading logic in one place, inside LoRAModel.set_state_dict(); if we kept the old form, from_pretrained would never reach LoRAModel.set_state_dict().
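To illustrate the point in the reply, here is a simplified stand-in (these classes and the loader function are not PaddleNLP's actual LoRAModel or _load_state_dict_into_model): loading into the wrapper lets its set_state_dict hook run the LoRA-GA post-processing before delegating to the inner model, while loading into the inner model directly bypasses that hook.

```python
# Simplified illustration of why the target object matters; these classes are
# stand-ins, not the real LoRAModel / _load_state_dict_into_model.
class InnerModel:
    def set_state_dict(self, state_dict):
        print("inner model receives:", sorted(state_dict))

class LoRAWrapper:
    def __init__(self, model):
        self.model = model

    def set_state_dict(self, state_dict):
        # LoRA-GA hook: split off init halves, adjust base weights, etc.
        print("LoRA-GA post-processing runs here")
        self.model.set_state_dict(state_dict)

def load_state_dict_into(obj, state_dict):
    # Stand-in loader: it calls set_state_dict on whatever object it is handed.
    obj.set_state_dict(state_dict)

wrapper = LoRAWrapper(InnerModel())
load_state_dict_into(wrapper, {"lora_A": 1})        # hook runs, then inner model
load_state_dict_into(wrapper.model, {"lora_A": 1})  # hook is skipped
```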