From 9b6271271b875ac252cda46cb0090ae21df2a5b2 Mon Sep 17 00:00:00 2001
From: Siming Dai <908660116@qq.com>
Date: Fri, 20 Dec 2024 13:25:55 +0800
Subject: [PATCH] [Trainer] update remove_master_weight (#9640)

---
 paddlenlp/trainer/trainer.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index d4f832820b3a..306d0ab4a45e 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -2557,7 +2557,10 @@ def _save_checkpoint(self, model, metrics=None):
             global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
             os.makedirs(signal_dir, exist_ok=True)
             paddle.save(global_rank, os.path.join(signal_dir, f".optimizer_weight.done.{global_rank}"))
-            if "skip_save_model_weight" not in self.args.unified_checkpoint_config:
+            if (
+                "skip_save_model_weight" not in self.args.unified_checkpoint_config
+                or "remove_master_weight" not in self.args.unified_checkpoint_config
+            ):
                 paddle.save(global_rank, os.path.join(signal_dir, f".master_weight.done.{global_rank}"))
         if self.args.should_save or self.args.use_expert_parallel:
             if not self.args.use_hybrid_parallel:
@@ -2594,7 +2597,10 @@ def _save_checkpoint(self, model, metrics=None):
             global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
             os.makedirs(signal_dir, exist_ok=True)
             paddle.save(global_rank, os.path.join(signal_dir, f".optimizer_weight.done.{global_rank}"))
-            if "skip_save_model_weight" not in self.args.unified_checkpoint_config:
+            if (
+                "skip_save_model_weight" not in self.args.unified_checkpoint_config
+                or "remove_master_weight" not in self.args.unified_checkpoint_config
+            ):
                 paddle.save(global_rank, os.path.join(signal_dir, f".master_weight.done.{global_rank}"))

         self.runtime_timer.stop()
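
Note on the changed condition (a minimal standalone sketch, not part of the patch): with this change, the ".master_weight.done.{global_rank}" signal file is still written whenever either option is absent from unified_checkpoint_config, and is skipped only when both "skip_save_model_weight" and "remove_master_weight" are set. The function and variable names below are illustrative and do not appear in trainer.py.

    # Illustrative sketch only; names here are hypothetical and assume
    # unified_checkpoint_config is a collection of option strings.
    def write_master_weight_signal(unified_checkpoint_config):
        # After this patch, the ".master_weight.done" signal is written
        # unless BOTH options are present in the config.
        return (
            "skip_save_model_weight" not in unified_checkpoint_config
            or "remove_master_weight" not in unified_checkpoint_config
        )

    assert write_master_weight_signal([]) is True
    assert write_master_weight_signal(["skip_save_model_weight"]) is True
    assert write_master_weight_signal(
        ["skip_save_model_weight", "remove_master_weight"]
    ) is False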