Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unified Checkpoint] update async_save_info in develop #9173

Merged
merged 2 commits into from
Sep 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions paddlenlp/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2297,16 +2297,7 @@
self.model_wrapped.get_all_parameters(convert2cpu=True)

if self.args.should_save_model_state:
unified_checkpoint_config_backup = self.args.unified_checkpoint_config
# backup and remove unified_checkpoint_config for not trine stage
if not self.is_in_train:
self.args.unified_checkpoint_config = []

self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel)

# recover unified_checkpoint_config for not trine stage
if not self.is_in_train:
self.args.unified_checkpoint_config = unified_checkpoint_config_backup
else:
if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
os.makedirs(output_dir, exist_ok=True)
Expand Down Expand Up @@ -2584,10 +2575,9 @@
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`

local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))
if (
strtobool(os.getenv("FLAG_LLM_PDC", "False"))
and local_rank == 0
and paddle.distributed.get_rank() == 0
and self.args.unified_checkpoint
and "async_save" in self.args.unified_checkpoint_config
):
Expand All @@ -2598,9 +2588,10 @@
"ignore_save_lr_and_optim": self.args.ignore_save_lr_and_optim,
"skip_save_model_weight": "skip_save_model_weight" in self.args.unified_checkpoint_config,
}
if not os.path.exists(os.path.join(self.args.logging_dir, "async_save_info.json")):
with open(os.path.join(self.args.logging_dir, "async_save_info.json"), "w") as f:
json.dump(save_info, f)
if os.path.exists(os.path.join(self.args.logging_dir, "async_save_info.json")): # afs cannot overwrite
os.remove(os.path.join(self.args.logging_dir, "async_save_info.json"))
with open(os.path.join(self.args.logging_dir, "async_save_info.json"), "w") as f:
json.dump(save_info, f)

Check warning on line 2594 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L2591-L2594

Added lines #L2591 - L2594 were not covered by tests

if self.args.should_save:
if self.tokenizer is not None:
Expand All @@ -2609,7 +2600,17 @@
paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

if self.args.unified_checkpoint:
unified_checkpoint_config_backup = self.args.unified_checkpoint_config

Check warning on line 2603 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L2603

Added line #L2603 was not covered by tests
# backup and remove unified_checkpoint_config for not trine stage
if not self.is_in_train:
self.args.unified_checkpoint_config = []

Check warning on line 2606 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L2605-L2606

Added lines #L2605 - L2606 were not covered by tests

self.unified_checkpoint_handler.save_unified_checkpoint(self.model, self.optimizer, output_dir)

# recover unified_checkpoint_config for not trine stage
if not self.is_in_train:
self.args.unified_checkpoint_config = unified_checkpoint_config_backup

Check warning on line 2612 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L2611-L2612

Added lines #L2611 - L2612 were not covered by tests

return

merge_tensor_parallel = merge_tensor_parallel and self.args.use_hybrid_parallel
Expand Down
Loading