-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PEFT]Add LoRA-GA #9592
Merged
Merged
[PEFT]Add LoRA-GA #9592
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
d8671df
Support LoRA-GA initialization
greycooker b953519
Merge branch 'LoRA-GA' of github.com:greycooker/PaddleNLP into LoRA-GA
greycooker 6c2da73
modify loraga_reinit
greycooker d7a3865
Merge branch 'LoRA-GA' of github.com:greycooker/PaddleNLP into LoRA-GA
greycooker f6ee622
Support multi GPU initialization
greycooker 00c76f5
Merge branch 'LoRA-GA' of github.com:greycooker/PaddleNLP into LoRA-GA
greycooker 6a8d175
support resume training and gradient offlaod hook
greycooker 0421446
remove trl/llm_utils.py
greycooker 7509939
Merge branch 'PaddlePaddle:develop' into LoRA-GA
greycooker 468d549
use loraga trainer
greycooker b435ed6
fix comment
greycooker 4d7603f
Merge remote-tracking branch 'origin1/develop' into LoRA-GA
greycooker 86c5b33
fix loraga
greycooker 44f633e
Merge branch 'PaddlePaddle:develop' into LoRA-GA
greycooker 275c623
change variable name
greycooker File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -163,6 +163,9 @@ | |
) | ||
|
||
self.forward = self.model.forward | ||
if lora_config.loraga: | ||
self.loraga_init_dict = {} | ||
self.reinit_base_model = False | ||
|
||
logger.info("Mark only lora and trainable_module as trainable.") | ||
self.mark_only_lora_as_trainable() | ||
|
@@ -269,7 +272,7 @@ | |
tp_actions if pre_tensor_parallel_split else None, | ||
expected_keys, | ||
) | ||
error_msgs += _load_state_dict_into_model(lora_model.model, state_dict, "") | ||
error_msgs += _load_state_dict_into_model(lora_model, state_dict, "") | ||
del state_dict | ||
gc.collect() | ||
|
||
|
@@ -319,6 +322,41 @@ | |
warnings.filterwarnings( | ||
action="ignore", message=".*Skip loading for.*", category=Warning, lineno=0, append=False | ||
) | ||
|
||
model_state_dict = self.model.state_dict() | ||
if self.lora_config.loraga: | ||
|
||
def process_split_and_assign(name, concat_tensor, axis, init_dict, state_dict): | ||
if isinstance(concat_tensor, np.ndarray): | ||
final_lora, init_lora = np.split(concat_tensor, 2, axis=axis) | ||
init_lora = paddle.to_tensor(init_lora) | ||
else: | ||
final_lora, init_lora = paddle.split(concat_tensor, 2, axis=axis) | ||
init_dict[name] = init_lora | ||
state_dict[name] = final_lora | ||
return init_lora | ||
|
||
for name in state_dict.keys(): | ||
if "lora_A" in name: | ||
concat_lora_A = state_dict[name] | ||
init_loraA = process_split_and_assign( | ||
name, concat_lora_A, axis=1, init_dict=self.loraga_init_dict, state_dict=state_dict | ||
) | ||
|
||
loraB_name = name.replace("lora_A", "lora_B") | ||
concat_lora_B = state_dict[loraB_name] | ||
init_loraB = process_split_and_assign( | ||
loraB_name, concat_lora_B, axis=0, init_dict=self.loraga_init_dict, state_dict=state_dict | ||
) | ||
|
||
base_name = name.replace("lora_A", "weight") | ||
if not self.reinit_base_model: | ||
# Reinit base model | ||
offset = init_loraA.cuda() @ init_loraB.cuda() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个看起来也不用主动cuda,使用 paddle.matmul 会根据运行device来切换到cuda显存上 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 有点奇怪,我试了一下matmul好像是不行的 |
||
ori_weight = model_state_dict[base_name] | ||
model_state_dict[base_name].set_value(ori_weight - self.lora_config.scaling * offset) | ||
del model_state_dict | ||
gc.collect() | ||
self.model.set_state_dict(state_dict) | ||
logger.info("Load lora weight successfully") | ||
|
||
|
@@ -400,8 +438,9 @@ | |
|
||
lora_config_to_save = LoRAConfig(**self.lora_config.to_dict()) | ||
|
||
trainable_state_dict = self.get_trainable_state_dict(concat_init_lora=lora_config_to_save.loraga) | ||
|
||
if merge_tensor_parallel and lora_config_to_save.tensor_parallel_degree > 1: | ||
trainable_state_dict = self.get_trainable_state_dict() | ||
trainable_state_dict = self._merge_trainable_tensor_parallel(trainable_state_dict) | ||
if not is_main_process: | ||
logger.info("Saving with merge_tensor_parallel, tensor_parallel_rank > 0 don't need save") | ||
|
@@ -410,7 +449,6 @@ | |
variant = "_".join([x for x in variant.split("_") if "tp" not in x]) | ||
lora_config_to_save.tensor_parallel_degree = -1 | ||
else: | ||
trainable_state_dict = self.get_trainable_state_dict() | ||
if lora_config_to_save.tensor_parallel_degree > 1: | ||
if variant is None: | ||
variant = weight_name_suffix() | ||
|
@@ -641,12 +679,19 @@ | |
original_module.bias = module.bias | ||
setattr(parent_module, attribute_chain[-1], original_module) | ||
|
||
def get_trainable_state_dict(self): | ||
def get_trainable_state_dict(self, concat_init_lora=False): | ||
trainable_state_dict = OrderedDict() | ||
for name, weight in self.model.state_dict().items(): | ||
# get lora parameter & QAT scale parameter | ||
if not weight.stop_gradient or "activation_quanter" in name or "weight_quanter" in name: | ||
trainable_state_dict[name] = weight | ||
if concat_init_lora: | ||
if "lora_A" in name: | ||
trainable_state_dict[name] = paddle.concat([weight, self.loraga_init_dict[name]], axis=1) | ||
else: | ||
trainable_state_dict[name] = paddle.concat([weight, self.loraga_init_dict[name]], axis=0) | ||
else: | ||
trainable_state_dict[name] = weight | ||
|
||
return trainable_state_dict | ||
|
||
def print_trainable_parameters(self) -> None: | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这块为啥修改了lora_model.model为lora_model?是原来的写法有误吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
现在的写法没问题吧,我理解按现在的代码,不加LoRA-GA的情况下,这两种写法是等价的,所以改了也不会影响现在的功能。加了LoRA-GA以后我想把它的加载逻辑统一在LoRAModel.set_state_dict()里改掉,如果还是使用现在的写法的话from_pretrained的时候就走不到LoRAModel.set_state_dict()了