auto save model in lightning #613

Closed
wants to merge 8 commits into from

Collaborator

@dberenbaum dberenbaum commented Jul 3, 2023

This PR auto logs models with dvclive.lightning.DVCLiveLogger(log_model=True):

trainer = pl.Trainer(logger=[DVCLiveLogger(log_model=True)])

log_model follows the conventions in mlflow and wandb:

  • False saves no models (this is the default).
  • True saves all model checkpoints at the end of training.
  • "all" saves all model checkpoints whenever a model checkpoint is saved.

If log_model is True or "all", dvclive caches the entire checkpoints folder.
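
A minimal sketch of the log_model conventions listed above, written as plain Python so it runs without dvclive or lightning installed. The function name and the event labels ("checkpoint_saved", "train_end") are hypothetical and only illustrate the intended behavior; this is not the code in the PR.

```python
from typing import Union


def should_cache_checkpoints(log_model: Union[bool, str], event: str) -> bool:
    """Decide whether the checkpoints folder is cached for a given event.

    `event` is a hypothetical label: "checkpoint_saved" or "train_end".
    """
    if log_model is False:
        # Default: never save models.
        return False
    if log_model is True:
        # Save all checkpoints once, at the end of training.
        return event == "train_end"
    if log_model == "all":
        # Save whenever a checkpoint is written.
        return event == "checkpoint_saved"
    raise ValueError(f"log_model must be False, True, or 'all', got {log_model!r}")
```

With log_model="all" the last checkpoint is already covered by the final "checkpoint_saved" event, so the sketch does nothing extra at "train_end".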

DVCLive will also add a model artifact named "best" at the end of training that references the best model checkpoint. (edit: this resembles the best alias in wandb)

Edit: example dvclive/dvc.yaml output:

artifacts:
  best:
    path: ../DvcLiveLogger/dvclive_run/checkpoints/epoch=1-step=4-v1.ckpt
    type: model

To support this, log_artifact was also changed:

  • Artifacts will only be added to dvc.yaml:artifacts if some metadata is provided (type, name, desc, labels, or meta). This is a breaking change, but I can't see how anyone is making use of this without any metadata since it won't be used by the model registry.
  • Added cache kwarg to log_artifact (defaults to True) so that it's possible to add the artifact metadata without caching the object.
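
To make the two log_artifact changes concrete, here is a rough model of the decision they describe. It is a standalone sketch, not dvclive's implementation: plan_log_artifact and the returned dict are hypothetical, and only the keyword names (type, name, desc, labels, meta, cache) come from the description above.

```python
from typing import Any, Dict, List, Optional


def plan_log_artifact(
    path: str,
    type: Optional[str] = None,
    name: Optional[str] = None,
    desc: Optional[str] = None,
    labels: Optional[List[str]] = None,
    meta: Optional[Dict[str, Any]] = None,
    cache: bool = True,
) -> Dict[str, Any]:
    """Describe what would happen to an artifact without touching any files."""
    candidates = {"type": type, "name": name, "desc": desc, "labels": labels, "meta": meta}
    # Keep only the metadata fields the caller actually filled in.
    metadata = {key: value for key, value in candidates.items() if value}
    return {
        "path": path,
        "cache": cache,  # whether the file or folder itself is cached
        "add_to_dvcyaml": bool(metadata),  # only written to artifacts if metadata exists
        "metadata": metadata,
    }
```

For example, plan_log_artifact("checkpoints") would cache the folder but leave dvc.yaml untouched, while plan_log_artifact("best.ckpt", name="best", type="model", cache=False) would only record the metadata.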

@dberenbaum
Collaborator Author

Overall, I felt opening a PR with the desired behavior would be more effective than explaining and discussing in an issue. I hope this will help resolve some of the rough edges around saving models and that we can work through the other framework callbacks to implement similar functionality that works with the existing framework conventions and resembles mlflow, wandb, etc.

@dberenbaum
Collaborator Author

One thing to note: lightning will not overwrite existing files or clean up between runs. Instead, it will append a version number, so if you run the same code repeatedly, you will end up with a directory that tracks your entire history of model checkpoints instead of only the latest run:

$ tree DvcLiveLogger
DvcLiveLogger
└── dvclive_run
    ├── checkpoints
    │   ├── epoch=0-step=2-v1.ckpt
    │   ├── epoch=0-step=2-v2.ckpt
    │   ├── epoch=0-step=2.ckpt
    │   ├── epoch=1-step=32-v1.ckpt
    │   ├── epoch=1-step=32-v2.ckpt
    │   ├── epoch=1-step=32.ckpt
    │   ├── epoch=1-step=4.ckpt
    │   ├── epoch=2-step=6.ckpt
    │   ├── epoch=4-step=10-v1.ckpt
    │   ├── epoch=4-step=10-v2.ckpt
    │   ├── epoch=4-step=10-v3.ckpt
    │   ├── epoch=4-step=10-v4.ckpt
    │   ├── epoch=4-step=10-v5.ckpt
    │   ├── epoch=4-step=10-v6.ckpt
    │   └── epoch=4-step=10.ckpt
    └── checkpoints.dvc

If you are running a pipeline, this is probably fine since you can control if you want to delete that directory each time. We might also want to consider dropping the existing checkpoints directory in the dvclive callback if resume=False.
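
For anyone wondering where the -v1, -v2 suffixes in the tree come from, this is a simplified imitation of the naming behavior described above. It is not lightning's code; the function name is hypothetical and it works on a set of names rather than a real directory.

```python
from typing import Set


def next_checkpoint_name(stem: str, existing: Set[str], ext: str = ".ckpt") -> str:
    """Return a file name that does not collide with anything in `existing`."""
    candidate = f"{stem}{ext}"
    version = 1
    # Never overwrite: keep appending a version suffix until the name is free.
    while candidate in existing:
        candidate = f"{stem}-v{version}{ext}"
        version += 1
    return candidate


# Three runs of the same code, without cleaning up in between.
saved: Set[str] = set()
for _ in range(3):
    saved.add(next_checkpoint_name("epoch=0-step=2", saved))
print(sorted(saved))
```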

@daavoo

This comment was marked as resolved.

@dberenbaum

This comment was marked as resolved.

@dberenbaum
Collaborator Author

I think this gets us to a good place with logging models in lightning. In fact, comparing to other trackers, it feels a bit easier to manage the models this way in dvc. On a different machine, you can pull the lightning checkpoints dir and keep using lightning methods to load those checkpoints. With other trackers, once you are on a different machine, the only way to load models is using the experiment tracker's api.

# Save model checkpoints.
if self._log_model is True:
    self.experiment.log_artifact(checkpoint_callback.dirpath)
# Log best model.
Contributor

WDYT of creating a copy in "dvclive" folder (or in the checkpoints folder itself), at least for the best?

With the current behavior, it seems that the path of the registered model would change between experiments.

Collaborator Author

Yes, which I guess is also what other trackers do AFAICT? Do you think the path matters? Maybe it makes it easier to dvc get later, although we could make that work by the artifact name. No strong opinion from me.

@daavoo
Contributor

daavoo commented Jul 5, 2023

> I think this gets us to a good place with logging models in lightning. In fact, comparing to other trackers, it feels a bit easier to manage the models this way in dvc. On a different machine, you can pull the lightning checkpoints dir and keep using lightning methods to load those checkpoints. With other trackers, once you are on a different machine, the only way to load models is using the experiment tracker's api.

I am ok with moving forward in this direction and prioritizing similar behavior in the other (most used) frameworks.

We should invest some time in properly documenting the behavior and the expected workflow (how to use the dvc-tracked artifacts later), though.

@daavoo
Contributor

daavoo commented Jul 5, 2023

> If you are running a pipeline, this is probably fine since you can control if you want to delete that directory each time. We might also want to consider dropping the existing checkpoints directory in the dvclive callback if resume=False.

Didn't look in detail, but it seems like the other loggers use _scan_checkpoints to track only the checkpoints related to the current experiment.

@dberenbaum
Collaborator Author

> Didn't look in detail, but it seems like the other loggers use _scan_checkpoints to track only the checkpoints related to the current experiment.

Great idea. I'll look into it.

@dberenbaum
Collaborator Author

Lightning warns if the directory is not empty:

UserWarning: Checkpoint directory /Users/dave/Code/lstm_seq2seq/model exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

_scan_checkpoints only returns the latest checkpoint even if you use something like ModelCheckpoint(save_top_k=-1), so I had to save the checkpoint from each scan and then drop the rest of the files in the directory (unless resume=True).

Removing the files only happens after the checkpoint is saved, so sometimes the first checkpoint will still get a version number like:

$ tree model
model
├── epoch=0-step=2-v1.ckpt # saved this checkpoint before previous one was dropped
├── epoch=1-step=4.ckpt
├── epoch=2-step=6.ckpt
├── epoch=3-step=8.ckpt
└── epoch=4-step=10.ckpt

Overall, this works and probably meets most users' expectations, so I think we should keep it, but I don't feel strongly that it outweighs the added complexity or the potential surprise that dvclive is deleting model checkpoints.
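
The cleanup described above can be sketched roughly as follows. This is an illustration under stated assumptions, not the PR's code: drop_stale_checkpoints and its arguments are hypothetical, `keep` stands in for the checkpoint paths collected from each scan, and `resume` mirrors the resume=True case where nothing is removed.

```python
from pathlib import Path
from typing import Iterable, List


def drop_stale_checkpoints(dirpath: Path, keep: Iterable[str], resume: bool = False) -> List[str]:
    """Delete files in `dirpath` that were not saved by the current run.

    Returns the sorted names of the removed files.
    """
    if resume:
        # When resuming, earlier checkpoints belong to the same experiment.
        return []
    keep_set = set(keep)
    removed = []
    for p in Path(dirpath).iterdir():
        if p.is_file() and str(p) not in keep_set:
            # missing_ok avoids an error if the file disappears in the meantime.
            p.unlink(missing_ok=True)
            removed.append(p.name)
    return sorted(removed)
```

Because the removal runs only after a checkpoint has been written, a leftover file from a previous run can still force the first new checkpoint to take a -v1 suffix, as in the tree above.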

@dberenbaum dberenbaum marked this pull request as ready for review July 12, 2023 22:31
@dberenbaum dberenbaum requested a review from daavoo July 14, 2023 13:15
@dberenbaum
Collaborator Author

ping @daavoo

@codecov-commenter

Codecov Report

Patch coverage: 3.63% and project coverage change: -1.42 ⚠️

Comparison is base (469e39e) 89.47% compared to head (a9e4587) 88.06%.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #613      +/-   ##
==========================================
- Coverage   89.47%   88.06%   -1.42%     
==========================================
  Files          44       43       -1     
  Lines        2994     3042      +48     
  Branches      250      260      +10     
==========================================
  Hits         2679     2679              
- Misses        276      324      +48     
  Partials       39       39              
| Impacted Files | Coverage Δ |
|---|---|
| src/dvclive/lightning.py | 0.00% <0.00%> (ø) |
| tests/test_frameworks/test_lightning.py | 6.09% <9.09%> (-0.80%) ⬇️ |

... and 1 file with indirect coverage changes


Contributor

@daavoo daavoo left a comment

On a high level, the code and description make sense to me.

I have not actually tried the different options in a project, but the test looks reasonable, so I'm trusting that.

Comment on lines +178 to +179
if str(p) not in self._all_checkpoint_paths:
    p.unlink(missing_ok=True)
Contributor

Let's be clear about this in the docs

@daavoo
Contributor

daavoo commented Jul 21, 2023

I think I would like to have a best (or best_only) option for log_model, but we can discuss that separately.

@dberenbaum dberenbaum closed this Jul 25, 2023
@dberenbaum dberenbaum reopened this Jul 25, 2023
@dberenbaum dberenbaum closed this Jul 25, 2023
@dberenbaum dberenbaum deleted the lightning-model branch July 25, 2023 20:34
@dberenbaum dberenbaum restored the lightning-model branch July 25, 2023 20:35