[fix] Broadcast optimizer state using broadcast_dp without shard-reshard. (#8522)
bo-ke authored Jun 3, 2024
1 parent e71540b commit 440c093
Showing 2 changed files with 10 additions and 5 deletions.
paddlenlp/trainer/trainer.py (5 additions & 3 deletions)
@@ -582,7 +582,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
                 weights_index_file,
             ]
         ):
-            raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
+            raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint} -- {weights_file}")
 
         logger.info(f"Loading model from {resume_from_checkpoint} .")

@@ -2237,7 +2237,7 @@ def _save_checkpoint(self, model, metrics=None):
                     safe_serialization=True,
                 )
             else:
-                if self.dp_group.rank > 0:
+                if self.args.data_parallel_rank > 0 and self.args.use_expert_parallel:
                     self._save_ckpt_func(
                         self._filter_moe_no_sync_optimizer_params(), os.path.join(output_dir, OPTIMIZER_NAME)
                     )
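
The save-side guard previously fired for every non-zero rank of `self.dp_group`; it now also requires expert parallelism, so dense (non-MoE) runs no longer take the MoE-filtered save path on non-zero data-parallel ranks. A minimal sketch of the new gating as a standalone predicate (hypothetical helper name; the trainer inlines this condition):

def should_save_filtered_moe_state(args) -> bool:
    # For dense models, optimizer state is replicated across data-parallel
    # ranks, so only dp rank 0 needs to write it. With expert parallelism,
    # each dp rank owns state for its own experts and must persist a
    # filtered copy as well.
    return args.data_parallel_rank > 0 and args.use_expert_parallel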

@@ -2525,7 +2525,9 @@ def _load_optimizer_and_scheduler(self, checkpoint):
         if self.args.local_rank != -1:
             dist.barrier()
         if self.args.use_expert_parallel:
-            opt_state_dict = broadcast_moe_optimizer(opt_state_dict)
+            opt_state_dict = broadcast_moe_optimizer(
+                opt_state_dict, broadcast_dp=not self.args.should_load_sharding_stage1_model
+            )
         else:
             if not self.args.should_load_sharding_stage1_model:
                 opt_state_dict = broadcast_dp_optimizer(opt_state_dict)
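
In `_load_optimizer_and_scheduler`, the MoE path now derives `broadcast_dp` from the sharding configuration instead of unconditionally using the bespoke MoE broadcast (the flag is consumed in the helper.py hunk below). Expanded for readability, the call is equivalent to this sketch, where `args` stands for `self.args`:

if args.should_load_sharding_stage1_model:
    # The sharding stage 1 loader already places optimizer state on each
    # rank, so the extra data-parallel broadcast is skipped (no
    # shard-reshard round trip).
    opt_state_dict = broadcast_moe_optimizer(opt_state_dict, broadcast_dp=False)
else:
    # Plain checkpoint: replay the dp broadcast so every dp rank starts
    # from dp rank 0's copy of the shared (non-expert) state.
    opt_state_dict = broadcast_moe_optimizer(opt_state_dict, broadcast_dp=True)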
paddlenlp/trainer/utils/helper.py (5 additions & 2 deletions)
@@ -228,7 +228,7 @@ def broadcast_dp_optimizer(state_dict):
     return state_dict
 
 
-def broadcast_moe_optimizer(state_dict):
+def broadcast_moe_optimizer(state_dict, broadcast_dp=True):
 
     try:
         hcg = fleet.get_hybrid_communicate_group()

@@ -270,7 +270,10 @@ def _broadcast_moe_optimizer_state(state_dict):
         base_state_dict.update(buf[2])
         return base_state_dict
 
-    base_state_dict = _broadcast_moe_optimizer_state(state_dict)
+    if broadcast_dp:
+        base_state_dict = broadcast_dp_optimizer(state_dict)
+    else:
+        base_state_dict = _broadcast_moe_optimizer_state(state_dict)
     if data_parallel_rank > 0:
         master_weight = state_dict.pop("master_weights", {})
         base_state_dict.update(state_dict)
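
With `broadcast_dp=True`, the shared portion of the state now flows through `broadcast_dp_optimizer` instead of the MoE-specific broadcast. For orientation, a minimal sketch of a data-parallel broadcast over an optimizer state dict, assuming a flat dict of `paddle.Tensor` values and Paddle's hybrid communicate group APIs (the real `broadcast_dp_optimizer` in this file also handles nested entries and master weights):

import paddle.distributed as dist
from paddle.distributed import fleet

def broadcast_dp_state(state_dict):
    # Sketch only, not the library function: broadcast every tensor from
    # the data-parallel source rank so all dp ranks end up with the same
    # copy of the shared optimizer state.
    hcg = fleet.get_hybrid_communicate_group()
    dp_group = hcg.get_data_parallel_group()
    src_rank = hcg.get_data_parallel_group_src_rank()
    for _name, tensor in state_dict.items():
        dist.broadcast(tensor, src=src_rank, group=dp_group)
    return state_dict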