[Trainer] update remove_master_weight (#9640)
DesmonDay authored Dec 20, 2024
1 parent cae6398 commit 9b62712
Showing 1 changed file with 8 additions and 2 deletions.
paddlenlp/trainer/trainer.py (8 additions, 2 deletions)
@@ -2557,7 +2557,10 @@ def _save_checkpoint(self, model, metrics=None):
             global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
             os.makedirs(signal_dir, exist_ok=True)
             paddle.save(global_rank, os.path.join(signal_dir, f".optimizer_weight.done.{global_rank}"))
-            if "skip_save_model_weight" not in self.args.unified_checkpoint_config:
+            if (
+                "skip_save_model_weight" not in self.args.unified_checkpoint_config
+                or "remove_master_weight" not in self.args.unified_checkpoint_config
+            ):
                 paddle.save(global_rank, os.path.join(signal_dir, f".master_weight.done.{global_rank}"))
         if self.args.should_save or self.args.use_expert_parallel:
             if not self.args.use_hybrid_parallel:
@@ -2594,7 +2597,10 @@ def _save_checkpoint(self, model, metrics=None):
             global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
             os.makedirs(signal_dir, exist_ok=True)
             paddle.save(global_rank, os.path.join(signal_dir, f".optimizer_weight.done.{global_rank}"))
-            if "skip_save_model_weight" not in self.args.unified_checkpoint_config:
+            if (
+                "skip_save_model_weight" not in self.args.unified_checkpoint_config
+                or "remove_master_weight" not in self.args.unified_checkpoint_config
+            ):
                 paddle.save(global_rank, os.path.join(signal_dir, f".master_weight.done.{global_rank}"))

         self.runtime_timer.stop()
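Effect of the new condition: because the two `not in` checks are joined by `or`, the `.master_weight.done.{global_rank}` signal file is now skipped only when both `skip_save_model_weight` and `remove_master_weight` appear in `unified_checkpoint_config`. A minimal sketch of that truth table, using a hypothetical helper name and assuming the parsed config behaves like a collection of flag strings (both assumptions, not part of the commit):

# Hypothetical helper mirroring the committed condition; not from the commit.
# `unified_checkpoint_config` is assumed to support `in` membership tests
# on flag strings, matching the checks in the diff above.
def writes_master_weight_signal(unified_checkpoint_config):
    # By De Morgan's law, an `or` of two `not in` checks is False only
    # when BOTH flags are present in the config.
    return (
        "skip_save_model_weight" not in unified_checkpoint_config
        or "remove_master_weight" not in unified_checkpoint_config
    )

# The signal file is written in every case except when both flags are set.
assert writes_master_weight_signal([])
assert writes_master_weight_signal(["skip_save_model_weight"])
assert writes_master_weight_signal(["remove_master_weight"])
assert not writes_master_weight_signal(
    ["skip_save_model_weight", "remove_master_weight"]
)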
