From 440c093867a6113ae5cc60659333808a51f3b776 Mon Sep 17 00:00:00 2001
From: Ferrebo
Date: Mon, 3 Jun 2024 09:55:03 +0800
Subject: [PATCH] [fix] Broadcast optimizer state using broadcast_dp without
 shard-reshard. (#8522)

---
 paddlenlp/trainer/trainer.py      | 8 +++++---
 paddlenlp/trainer/utils/helper.py | 7 +++++--
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index c3e44a4e838c..14a0c6d6983c 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -582,7 +582,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
                     weights_index_file,
                 ]
             ):
-                raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
+                raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint} -- {weights_file}")
 
             logger.info(f"Loading model from {resume_from_checkpoint} .")
 
@@ -2237,7 +2237,7 @@ def _save_checkpoint(self, model, metrics=None):
                         safe_serialization=True,
                     )
                 else:
-                    if self.dp_group.rank > 0:
+                    if self.args.data_parallel_rank > 0 and self.args.use_expert_parallel:
                         self._save_ckpt_func(
                             self._filter_moe_no_sync_optimizer_params(), os.path.join(output_dir, OPTIMIZER_NAME)
                         )
@@ -2525,7 +2525,9 @@ def _load_optimizer_and_scheduler(self, checkpoint):
         if self.args.local_rank != -1:
             dist.barrier()
         if self.args.use_expert_parallel:
-            opt_state_dict = broadcast_moe_optimizer(opt_state_dict)
+            opt_state_dict = broadcast_moe_optimizer(
+                opt_state_dict, broadcast_dp=not self.args.should_load_sharding_stage1_model
+            )
         else:
             if not self.args.should_load_sharding_stage1_model:
                 opt_state_dict = broadcast_dp_optimizer(opt_state_dict)
diff --git a/paddlenlp/trainer/utils/helper.py b/paddlenlp/trainer/utils/helper.py
index 12aec88bc41b..8e4c22e908f5 100644
--- a/paddlenlp/trainer/utils/helper.py
+++ b/paddlenlp/trainer/utils/helper.py
@@ -228,7 +228,7 @@ def broadcast_dp_optimizer(state_dict):
     return state_dict
 
 
-def broadcast_moe_optimizer(state_dict):
+def broadcast_moe_optimizer(state_dict, broadcast_dp=True):
     try:
         hcg = fleet.get_hybrid_communicate_group()
 
@@ -270,7 +270,10 @@ def _broadcast_moe_optimizer_state(state_dict):
         base_state_dict.update(buf[2])
         return base_state_dict
 
-    base_state_dict = _broadcast_moe_optimizer_state(state_dict)
+    if broadcast_dp:
+        base_state_dict = broadcast_dp_optimizer(state_dict)
+    else:
+        base_state_dict = _broadcast_moe_optimizer_state(state_dict)
     if data_parallel_rank > 0:
         master_weight = state_dict.pop("master_weights", {})
         base_state_dict.update(state_dict)
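
Note for reviewers: below is a minimal sketch of how the two broadcast paths are selected once this patch is applied. The wrapper function itself is hypothetical (it does not exist in the codebase); `args` stands in for the trainer's TrainingArguments, and the two helpers are the ones touched in paddlenlp/trainer/utils/helper.py. It is not part of the patch.

    # Illustrative only: mirrors the post-patch dispatch in
    # Trainer._load_optimizer_and_scheduler; the wrapper name is made up.
    from paddlenlp.trainer.utils.helper import (
        broadcast_dp_optimizer,
        broadcast_moe_optimizer,
    )


    def broadcast_loaded_optimizer_state(opt_state_dict, args):
        if args.use_expert_parallel:
            # broadcast_dp=True reuses the plain data-parallel broadcast;
            # broadcast_dp=False keeps the MoE-aware broadcast that merges
            # the per-rank expert states.
            return broadcast_moe_optimizer(
                opt_state_dict,
                broadcast_dp=not args.should_load_sharding_stage1_model,
            )
        if not args.should_load_sharding_stage1_model:
            return broadcast_dp_optimizer(opt_state_dict)
        # Sharding stage-1 reload already yields the correct per-rank state,
        # so no extra broadcast is needed here.
        return opt_state_dict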