[Unified Checkpoint] Add split param and refactor code #9240
Conversation
…nto add_split_param
Thanks for your contribution!
@@ -909,7 +983,160 @@ def unified_checkpoint_into_shards(
    return state_dict, shard_file, sharded_index


def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint):
Doesn't this function share most of its logic with load_unified_optimizer_locally?
It's still in early development; this will be revised later.
…nto add_split_param
Force-pushed from 3abfe71 to 9bce15b
Force-pushed from 9bce15b to 19071ef
…nto add_split_param
…nto add_split_param
Force-pushed from ec6a76a to ae9ddce
…nto add_split_param
Force-pushed from ff0ebc2 to 0d10c4c
Force-pushed from 0d10c4c to cbbc074
@@ -0,0 +1,493 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
There are quite a few files now; let's just put them in a dedicated directory: paddlenlp/trainer/unified_checkpoint/
done
    load_single_card_optimizer,
    save_single_card_checkpoint,
    save_single_card_optimizer,
)
The local-load functions are also in this file, right? Should they be split out as well?
@@ -406,30 +402,21 @@ def load_unified_checkpoint(self, model, optimizer, resume_from_checkpoint: str)
        None
    """
    if paddle.distributed.get_world_size() <= 1:
        load_single_card_checkpoint(self.args, model, resume_from_checkpoint)
Why is args no longer needed?
The single-card load doesn't need to read args; it was redundant.
        return

    if self.args.dataset_rank == 0 or self.args.use_expert_parallel:
        load_unified_checkpoint_locally(self.args, model, resume_from_checkpoint, safe_serialization=True)

-   def save_non_merge_optimizer(self, model, optimizer, output_dir, signal_dir):
+   def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, output_dir, signal_dir):
Did the place where master_weights is extracted change?
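The new signature above receives optim_state_dict and master_weights separately, implying the caller now extracts master weights before saving. A minimal sketch of what such an extraction step might look like; the key suffix and helper name are assumptions for illustration, not the PR's actual code:

```python
# Hypothetical helper: split a flat optimizer state dict into master-weight
# entries and the remaining optimizer state before saving them separately.
MASTER_SUFFIX = "_fp32_master_0"  # assumed suffix marking master-weight keys

def split_master_weights(state_dict):
    master_weights, optim_state = {}, {}
    for name, value in state_dict.items():
        if name.endswith(MASTER_SUFFIX):
            master_weights[name] = value
        else:
            optim_state[name] = value
    return optim_state, master_weights
```

With this split done up front, both pieces can be passed straight into a save function like the one shown in the diff.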
    args.sharding_parallel_degree > 1
    and ShardingOption.SHARD_OP in args.sharding
    and "split_param" in args.sharding_parallel_config
):
Define this check as a function; it appears many times.
done
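The extracted predicate might look like the sketch below. ShardingOption and the args fields are stand-ins modeled on the snippet in the diff; the helper's actual name and location in the PR may differ:

```python
# Hedged sketch: one predicate for the repeated "stage1 sharding with
# split_param enabled" check, so call sites stay readable.
from dataclasses import dataclass, field
from enum import Enum

class ShardingOption(Enum):  # stand-in for the real enum
    SHARD_OP = "stage1"

@dataclass
class Args:  # stand-in for the training arguments object
    sharding_parallel_degree: int = 1
    sharding: list = field(default_factory=list)
    sharding_parallel_config: str = ""

def is_sharding_split_param(args):
    # True only when stage1 sharding is on and split_param is requested.
    return (
        args.sharding_parallel_degree > 1
        and ShardingOption.SHARD_OP in args.sharding
        and "split_param" in args.sharding_parallel_config
    )
```

Call sites then reduce to `if is_sharding_split_param(args): ...`, which is what the review is asking for.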

def distributed_send_recv_splited_param(
    state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, is_master_weights=False
):
Is this function merging the split parameters back together?
Yes. I'll rename it to merge_splited_param.
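The idea behind the rename: each sharding rank holds a contiguous slice of a flattened parameter, and merging means gathering those slices in rank order and restoring the original shape. A dependency-free sketch of that merge step (the name matches the proposed rename, but the communication tables and real tensor types from the PR are omitted):

```python
# Hedged sketch of merging split parameter slices back into one tensor.
# `slices` is the list of 1-D chunks in rank order; `shape` is (rows, cols).
def merge_splited_param(slices, shape):
    flat = [x for chunk in slices for x in chunk]
    rows, cols = shape
    assert len(flat) == rows * cols, "slices must cover the full parameter"
    # Rebuild the 2-D parameter row by row from the flattened values.
    return [flat[r * cols:(r + 1) * cols] for r in range(rows)]
```

In the actual PR the slices would arrive via send/recv tables across ranks; here they are simply concatenated to show the reconstruction logic.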
    return optim_state_dict, master_weights


def load_unified_optimizer_split_param(model, optimizer, resume_from_checkpoint):
The code handling TP merging is still in the original unified_checkpoint main entry point, right?
Yes.
    get_optimizer_shard_files,
    mapping_optimizer_tp_actions,
)
Add `__all__`.
)


def save_file_sync(state_dict, path):
Could similar helper functions like these go somewhere shared?
Async save isn't supported for single-card, so I wrote a small dedicated helper for the single-card path. We can look at consolidating later; leaving it as-is for this PR.
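The point of save_file_sync is that the single-card path writes the checkpoint in a blocking call instead of handing it to an async worker. A minimal sketch under that assumption; the real helper presumably writes safetensors, while a pickle stand-in keeps this snippet dependency-free:

```python
import os
import pickle

def save_file_sync(state_dict, path):
    # Blocking save for the single-card path: no async signal files,
    # no background worker; the function returns only after the write.
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(state_dict, f)
```

Because the call blocks until the file is on disk, the single-card path needs none of the signal_dir bookkeeping used by the async save.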
    # save generation config
    if model_to_save.can_generate():
        model_to_save.generation_config.save_pretrained(output_dir)
Consider wrapping these config saves into a shared function too, so a change doesn't get missed; there are a lot of branches now.
done
Force-pushed from dc4c75a to 862f86b
Force-pushed from 862f86b to 7678fad
…nto add_split_param
# See the License for the specific language governing permissions and
# limitations under the License.

from .unified_checkpoint import UnifiedCheckpointHandler
Could you add `__all__`?
done
Force-pushed from e78fe82 to b219ba6
Force-pushed from 30fe038 to dbd13df
LGTM
…9240)

* [Unified checkpoint] update optimizer async save signal
* update paddlepaddle
* split param
* add save for split param
* fix save split_param
* add load uc split_param
* update uc files
* update uc files
* update split_param loading
* mkdir unified_checkpoint directory
* rename file
* update async handler
* update files

Co-authored-by: gongenlei <gongenlei@baidu.com>
PR types
New features
PR changes
Others
Description