
Introduce optuna.artifacts to the PyTorch checkpoint example #280

Open · wants to merge 7 commits into base: main

Conversation

kAIto47802 (Contributor)

Motivation

Currently, the PyTorch checkpoint example uses the local file system to save and manage checkpoints; it does not yet reflect the recent optuna.artifacts functionality.

Description of the changes

  • Introduced optuna.artifacts.
  • Removed the use of local file system.

@kAIto47802 kAIto47802 changed the title Introduce artifact store to the PyTorch checkpoint example Introduce optuna.artifacts to the PyTorch checkpoint example Sep 24, 2024
HideakiImamura (Member)

@nabenabe0928 Could you review this PR?


github-actions bot commented Oct 7, 2024

This pull request has not seen any recent activity.

@github-actions github-actions bot added the stale Exempt from stale bot labeling. label Oct 7, 2024
@not522 not522 assigned not522 and unassigned nabenabe0928 Oct 9, 2024
not522 (Member) left a comment

Thank you for your PR! Could you check my comments?

pytorch/pytorch_checkpoint.py (outdated)
checkpoint = torch.load(checkpoint_path)
if trial_number is not None:
    study = optuna.load_study(study_name="pytorch_checkpoint", storage="sqlite:///example.db")
    artifact_id = study.trials[trial_number].user_attrs["artifact_id"]

If the process is terminated before the first checkpoint, the artifact will not be saved, so check if it exists.

@@ -159,9 +158,15 @@ def objective(trial):
                "optimizer_state_dict": optimizer.state_dict(),
                "accuracy": accuracy,
            },
-           tmp_checkpoint_path,
+           "./tmp_model.pt",

Could you change the path of checkpoint for each trial? If we run this script with multi-process, the saved models can be broken by other processes.
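A minimal way to satisfy this request might be the following (the helper name is hypothetical; the real fix would use trial.number inside the objective function):

```python
import os
import tempfile


def checkpoint_path_for(trial_number: int) -> str:
    # Embedding the trial number in the file name keeps concurrent worker
    # processes from clobbering each other's temporary checkpoints.
    return os.path.join(tempfile.gettempdir(), f"tmp_model_{trial_number}.pt")
```

Inside the objective, `torch.save(..., checkpoint_path_for(trial.number))` would then write each trial's checkpoint to its own file before uploading it to the artifact store.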

@github-actions github-actions bot removed the stale Exempt from stale bot labeling. label Oct 9, 2024

This pull request has not seen any recent activity.

@github-actions github-actions bot added the stale Exempt from stale bot labeling. label Oct 17, 2024
@nabenabe0928 nabenabe0928 removed the stale Exempt from stale bot labeling. label Oct 18, 2024
4 participants