[Unified Checkpoint] Support sharding_comm_overlap #9392
Conversation
Thanks for your contribution!
@@ -2053,6 +2053,9 @@ def get_expected_keys(inputs, keys):
             self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
         self.optimizer = fleet.distributed_optimizer(self.optimizer)

+        if self.args.enable_sharding_comm_overlap:
+            model.register_sharding_comm_overlap_hook(self.optimizer)
@ZHUI Please take a look: should we add a dedicated switch to turn this on specifically for unified checkpoint (uc)?
Yes, we should, to minimize the impact on other strategies.
Added a check condition: this path is only enabled when split_param is turned on.
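The guarded registration discussed above might look like the following sketch. Only `register_sharding_comm_overlap_hook` and `enable_sharding_comm_overlap` come from the diff; the `split_param` attribute name and the helper `maybe_register_hook` are assumptions for illustration.

```python
# Hypothetical sketch: register the sharding-comm-overlap hook only when
# both the overlap flag and the (assumed) split_param option are enabled,
# so other training strategies are unaffected.

class Args:
    enable_sharding_comm_overlap = True
    split_param = True  # assumed flag name, based on the review discussion


class Model:
    def __init__(self):
        self.hooked_with = None

    def register_sharding_comm_overlap_hook(self, optimizer):
        # Stand-in for the real hook registration in the diff.
        self.hooked_with = optimizer


def maybe_register_hook(args, model, optimizer):
    """Register the hook only when both conditions hold; return whether it ran."""
    if getattr(args, "enable_sharding_comm_overlap", False) and getattr(
        args, "split_param", False
    ):
        model.register_sharding_comm_overlap_hook(optimizer)
        return True
    return False
```

Using `getattr` with a default keeps the check safe on argument objects that predate either flag.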
Codecov Report
Attention: Patch coverage is
Additional details and impacted files:

@@            Coverage Diff             @@
##           develop    #9392     +/-   ##
===========================================
- Coverage    52.96%   52.94%   -0.02%
===========================================
  Files          676      676
  Lines       107827   107836       +9
===========================================
- Hits         57109    57099      -10
- Misses       50718    50737      +19

View full report in Codecov by Sentry.
Branch updated from 3310514 to aecf9f1 (Compare)
  opt_state_dict = self.unified_checkpoint_handler.load_unified_optimizer(
-     model=self.model,
+     model=model,
Hmm, could the state_dict names end up wrapped in yet another prefix, e.g. model.model.embedding?
That case has been handled.
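The concern above is about state_dict keys picking up an extra wrapper prefix (e.g. `model.model.embedding` instead of `model.embedding`). A minimal sketch of how such keys could be normalized; the function name and the `"model."` prefix value are assumptions, not code from this PR:

```python
# Hypothetical sketch: strip one extra wrapper prefix from state_dict keys
# so "model.model.embedding" matches the inner model's "model.embedding".

def strip_wrapper_prefix(state_dict, prefix="model."):
    """Return a copy of state_dict with `prefix` removed from matching keys."""
    out = {}
    for name, value in state_dict.items():
        if name.startswith(prefix):
            out[name[len(prefix):]] = value
        else:
            out[name] = value
    return out
```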
paddlenlp/trainer/trainer.py (outdated diff)
@@ -2840,8 +2843,11 @@ def _load_optimizer_and_scheduler(self, checkpoint):
          else:
              opt_state_dict = None
      else:
+         model = self.model
+         if hasattr(self.args, "enable_sharding_comm_overlap") and self.args.enable_sharding_comm_overlap:
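The `hasattr` guard in the diff above keeps the check backward compatible with argument objects created before the flag existed. A small standalone sketch of the pattern (the class names here are illustrative, not from the PR):

```python
# Sketch of the backward-compatible flag check: hasattr() guards against
# older argument objects that predate enable_sharding_comm_overlap.

class OldArgs:
    # Simulates TrainingArguments from before the flag was introduced.
    pass


class NewArgs:
    enable_sharding_comm_overlap = True


def overlap_enabled(args):
    """True only if the flag exists on `args` and is truthy."""
    return (
        hasattr(args, "enable_sharding_comm_overlap")
        and args.enable_sharding_comm_overlap
    )
```

An equivalent, slightly more compact spelling is `getattr(args, "enable_sharding_comm_overlap", False)`.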
Hmm, I think you could instead add a config such as uc_with_pp_sharding_comm_overlap and use it internally on its own, rather than tying this to enable_sharding_comm_overlap.
That would make things more complex, and if the check logic ever conflicted with or drifted out of sync with training_args.py, it would cause real problems.
LGTM
* [Unified Checkpoint] Support sharding_comm_overlap (#9392)
PR types
Function optimization
PR changes
Others
Description
Support sharding_comm_overlap.