
[Unified Checkpoint] Add split param and refactor code #9240

Merged
merged 21 commits into PaddlePaddle:develop from add_split_param
Oct 28, 2024

Conversation

DesmonDay
Contributor

@DesmonDay DesmonDay commented Oct 10, 2024

PR types

New features

PR changes

Others

Description

  1. Support sharding stage1 v2 (split_param) for unified checkpoint; a hedged configuration sketch follows below.
  2. Refactor unified checkpoint (uc) code.
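
Judging from the condition this PR checks elsewhere (ShardingOption.SHARD_OP in args.sharding plus "split_param" in args.sharding_parallel_config), a hedged sketch of the training arguments that exercise the new path might look like this; field names other than the ones quoted from the diff are assumptions:

# Hedged sketch: enable sharding stage1 v2 ("split_param") with unified
# checkpoint. Exact field spellings beyond those quoted in the diff are
# assumptions, not taken from this PR.
from paddlenlp.trainer import TrainingArguments

args = TrainingArguments(
    output_dir="./checkpoints",
    sharding="stage1",                        # maps to ShardingOption.SHARD_OP
    sharding_parallel_degree=8,               # > 1, so optimizer state is split
    sharding_parallel_config="split_param",   # the new split-param mode
    unified_checkpoint=True,                  # save/load the unified format
)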


paddle-bot bot commented Oct 10, 2024

Thanks for your contribution!

@CLAassistant

CLAassistant commented Oct 10, 2024

CLA assistant check
All committers have signed the CLA.


codecov bot commented Oct 10, 2024

Codecov Report

Attention: Patch coverage is 11.13549% with 1620 lines in your changes missing coverage. Please review.

Project coverage is 52.84%. Comparing base (81ffc78) to head (dbd13df).
Report is 3 commits behind head on develop.

Files with missing lines                               | Patch % | Lines missing
paddlenlp/trainer/unified_checkpoint/utils.py          | 12.07%  | 364 ⚠️
...p/trainer/unified_checkpoint/unified_checkpoint.py  | 11.34%  | 297 ⚠️
...ddlenlp/trainer/unified_checkpoint/load_dynamic.py  |  9.44%  | 259 ⚠️
...r/unified_checkpoint/sharding_split_param_utils.py  |  7.97%  | 173 ⚠️
...nlp/trainer/unified_checkpoint/check_completion.py  |  9.37%  | 145 ⚠️
...dlenlp/trainer/unified_checkpoint/async_handler.py  | 11.32%  | 141 ⚠️
paddlenlp/trainer/unified_checkpoint/load_local.py     | 12.12%  | 116 ⚠️
...rainer/unified_checkpoint/load_save_single_card.py  | 15.32%  | 116 ⚠️
paddlenlp/utils/nested.py                              | 14.28%  |   6 ⚠️
paddlenlp/trainer/training_args.py                     |  0.00%  |   3 ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #9240      +/-   ##
===========================================
+ Coverage    52.78%   52.84%   +0.06%     
===========================================
  Files          661      669       +8     
  Lines       106945   107240     +295     
===========================================
+ Hits         56450    56671     +221     
- Misses       50495    50569      +74     


@ZHUI ZHUI requested review from ZHUI and DrownFish19 October 11, 2024 06:42
@@ -909,7 +983,160 @@ def unified_checkpoint_into_shards(
return state_dict, shard_file, sharded_index


def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint):
Collaborator

Isn't most of this function's logic similar to load_unified_optimizer_locally?

Contributor Author

This is still in early development; it will be revised later.

@DesmonDay DesmonDay changed the title [Unified Checkpoint] Add split param [WIP][Unified Checkpoint] Add split param Oct 12, 2024
@DesmonDay DesmonDay force-pushed the add_split_param branch 2 times, most recently from 3abfe71 to 9bce15b on October 14, 2024 09:13
@DesmonDay DesmonDay changed the title [WIP][Unified Checkpoint] Add split param [WIP][Unified Checkpoint] Add split param, refactor code Oct 14, 2024
@DesmonDay DesmonDay changed the title [WIP][Unified Checkpoint] Add split param, refactor code [WIP][Unified Checkpoint] Add split param and refactor code Oct 14, 2024
@@ -0,0 +1,493 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
Collaborator

There are quite a few files now; let's just create a dedicated directory for them: paddlenlp/trainer/unified_checkpoint/

Contributor Author

done
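
For reference, the resulting package layout (inferred from the Codecov file list above, plus an __init__.py implied by the import discussed later in this review):

paddlenlp/trainer/unified_checkpoint/
    __init__.py
    unified_checkpoint.py
    utils.py
    load_dynamic.py
    load_local.py
    load_save_single_card.py
    sharding_split_param_utils.py
    check_completion.py
    async_handler.py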

load_single_card_optimizer,
save_single_card_checkpoint,
save_single_card_optimizer,
)
Collaborator

The local-load functions are also in this file, right? Should they be split out as well?

@@ -406,30 +402,21 @@ def load_unified_checkpoint(self, model, optimizer, resume_from_checkpoint: str)
None
"""
if paddle.distributed.get_world_size() <= 1:
load_single_card_checkpoint(self.args, model, resume_from_checkpoint)
Collaborator

Why is args no longer needed?

Contributor Author

Single-card loading doesn't need to read args; the parameter was redundant.

return

if self.args.dataset_rank == 0 or self.args.use_expert_parallel:
load_unified_checkpoint_locally(self.args, model, resume_from_checkpoint, safe_serialization=True)
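
Per the thread above, the single-card branch presumably simplifies to the following (the arg-free signature is inferred from the discussion, not quoted from the diff):

# Hedged sketch: single-card load no longer takes args.
if paddle.distributed.get_world_size() <= 1:
    load_single_card_checkpoint(model, resume_from_checkpoint)
    return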

def save_non_merge_optimizer(self, model, optimizer, output_dir, signal_dir):
def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, output_dir, signal_dir):
Collaborator

Has the place where master_weights is extracted changed?

args.sharding_parallel_degree > 1
and ShardingOption.SHARD_OP in args.sharding
and "split_param" in args.sharding_parallel_config
):
Collaborator

Let's define this check as a helper function; it appears many times.

Contributor Author

done
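
A minimal sketch of the helper this likely became (the function name and import path are assumptions; the condition is verbatim from the diff):

# Hedged sketch: wrap the repeated sharding stage1 v2 check in one place.
from paddlenlp.trainer.trainer_utils import ShardingOption  # import path assumed

def is_sharding_split_param_mode(args):
    return (
        args.sharding_parallel_degree > 1
        and ShardingOption.SHARD_OP in args.sharding
        and "split_param" in args.sharding_parallel_config
    )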


def distributed_send_recv_splited_param(
state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, is_master_weights=False
):
Collaborator

Does this function merge the split parameters back together?

Contributor Author

Yes. I'll rename the function to merge_splited_param.
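
Conceptually, the function reassembles the parameter slices that sharding stage1 v2 scattered across ranks. The real implementation drives point-to-point send/recv using the send/recv tables in its signature; a much-simplified sketch of just the merge idea, assuming equal-length (padded) slices, is:

# Hedged sketch of the merge idea only, not the actual send/recv-table
# implementation. Assumes every rank holds an equal-length 1-D slice of a
# flattened tensor, as all_gather requires.
import paddle
import paddle.distributed as dist

def merge_splited_param_sketch(local_slice, full_shape, group=None):
    pieces = []
    dist.all_gather(pieces, local_slice, group=group)  # collect every rank's slice
    merged = paddle.concat(pieces, axis=0)             # stitch back in rank order
    return merged.reshape(full_shape)                  # restore the original shape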

return optim_state_dict, master_weights


def load_unified_optimizer_split_param(model, optimizer, resume_from_checkpoint):
Collaborator

The code that handles TP merging is still over in the original unified_checkpoint main entry, right?

Contributor Author

Yes.

get_optimizer_shard_files,
mapping_optimizer_tp_actions,
)

Collaborator

Add __all__?

)


def save_file_sync(state_dict, path):
Collaborator

Can functions like these be placed somewhere shared?

Contributor Author
@DesmonDay DesmonDay Oct 25, 2024

Since async saving isn't supported for single card, I wrote a small dedicated function just for the single-card case. We can look at merging them later; I won't handle it in this PR.
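
A plausible shape for that single-card helper (a sketch under assumptions, not the actual code; it writes synchronously since single card has no async path):

# Hedged sketch: synchronous single-card safetensors save, no signal files.
from safetensors.numpy import save_file

def save_file_sync(state_dict, path):
    # assumes tensors are numpy-convertible on this device
    state_dict = {k: v.numpy() for k, v in state_dict.items()}
    save_file(state_dict, path, metadata={"format": "np"})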

# save generation config
if model_to_save.can_generate():
model_to_save.generation_config.save_pretrained(output_dir)

Collaborator

Consider wrapping these config saves into a shared helper function too, so that a change is less likely to miss one of them. There are quite a few branches now.

Contributor Author

done
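
A sketch of the shared helper this suggests (the name and exact body are assumptions; the generation-config lines are quoted from the diff above):

# Hedged sketch: one helper so no save branch forgets a config file.
def save_model_configs(model_to_save, output_dir):
    model_to_save.config.save_pretrained(output_dir)
    if model_to_save.can_generate():
        model_to_save.generation_config.save_pretrained(output_dir)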

# See the License for the specific language governing permissions and
# limitations under the License.

from .unified_checkpoint import UnifiedCheckpointHandler
Collaborator

Add __all__ here?

Contributor Author

done
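
So the package __init__.py presumably ends up along these lines (a sketch; only the import line is quoted from the diff):

# paddlenlp/trainer/unified_checkpoint/__init__.py
from .unified_checkpoint import UnifiedCheckpointHandler

__all__ = ["UnifiedCheckpointHandler"]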

@DesmonDay DesmonDay changed the title [WIP][Unified Checkpoint] Add split param and refactor code [Unified Checkpoint] Add split param and refactor code Oct 25, 2024
Collaborator
@wawltor wawltor left a comment

LGTM

@wawltor wawltor merged commit c9d5673 into PaddlePaddle:develop Oct 28, 2024
8 of 12 checks passed
winter-wang pushed a commit to winter-wang/PaddleNLP that referenced this pull request Oct 28, 2024
…9240)

* [Unified checkpoint] update optimizer async save signal

* update paddlepaddle

* split param

* add save for split param

* fix save split_param

* add load uc split_param

* update uc files

* update uc files

* update split_param loading

* mkdir unified_checkpoint directory

* rename file

* update async handler

* update files

---------

Co-authored-by: gongenlei <gongenlei@baidu.com>
DesmonDay added a commit to DesmonDay/PaddleNLP that referenced this pull request Oct 28, 2024