Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

auto save model in lightning #613

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion src/dvclive/lightning.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# ruff: noqa: ARG002
import inspect
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union

from lightning.fabric.utilities.logger import (
_convert_params,
_sanitize_callable_params,
_sanitize_params,
)
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
from lightning.pytorch.utilities import rank_zero_only
from torch import is_tensor
Expand Down Expand Up @@ -38,6 +39,7 @@ def __init__(
self,
run_name: Optional[str] = "dvclive_run",
prefix="",
log_model: Union[str, bool] = False,
experiment=None,
dir: Optional[str] = None, # noqa: A002
resume: bool = False,
Expand All @@ -60,6 +62,8 @@ def __init__(
if report == "notebook":
# Force Live instantiation
self.experiment # noqa: B018
self._log_model = log_model
self._checkpoint_callback: Optional[ModelCheckpoint] = None

@property
def name(self):
Expand Down Expand Up @@ -119,6 +123,23 @@ def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None):
self.experiment._latest_studio_step -= 1 # noqa: SLF001
self.experiment.next_step()

def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
self._checkpoint_callback = checkpoint_callback
if self._log_model == "all" or (
self._log_model is True and checkpoint_callback.save_top_k == -1
):
self.experiment.log_artifact(checkpoint_callback.dirpath)

@rank_zero_only
def finalize(self, status: str) -> None:
checkpoint_callback = self._checkpoint_callback
# Save model checkpoints.
if self._log_model is True:
self.experiment.log_artifact(checkpoint_callback.dirpath)
# Log best model.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WDYT of creating a copy in "dvclive" folder (or in the checkpoints folder itself), at least for the best?

It seems that we would be changing the path of the registered model between experiments in the current behavior

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, which I guess is also what other trackers do AFAICT? Do you think the path matters? Maybe it makes it easier to dvc get later, although we could make that work by the artifact name. No strong opinion from me.

if self._log_model in (True, "all"):
best_model_path = checkpoint_callback.best_model_path
self.experiment.log_artifact(
best_model_path, name="best", type="model", cache=False
)
self.experiment.end()
40 changes: 22 additions & 18 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,10 +410,11 @@ def log_artifact(
path: StrPath,
type: Optional[str] = None, # noqa: A002
name: Optional[str] = None,
desc: Optional[str] = None, # noqa: ARG002
labels: Optional[List[str]] = None, # noqa: ARG002
meta: Optional[Dict[str, Any]] = None, # noqa: ARG002
desc: Optional[str] = None,
labels: Optional[List[str]] = None,
meta: Optional[Dict[str, Any]] = None,
copy: bool = False,
cache: bool = True,
):
"""Tracks a local file or directory with DVC"""
if not isinstance(path, (str, Path)):
Expand All @@ -425,21 +426,24 @@ def log_artifact(
if copy:
path = clean_and_copy_into(path, self.artifacts_dir)

self.cache(path)

name = name or Path(path).stem
if name_is_compatible(name):
self._artifacts[name] = {
k: v
for k, v in locals().items()
if k in ("path", "type", "desc", "labels", "meta") and v is not None
}
else:
logger.warning(
"Can't use '%s' as artifact name (ID)."
" It will not be included in the `artifacts` section.",
name,
)
if cache:
self.cache(path)

if any((type, name, desc, labels, meta)):
name = name or Path(path).stem
if name_is_compatible(name):
self._artifacts[name] = {
k: v
for k, v in locals().items()
if k in ("path", "type", "desc", "labels", "meta")
and v is not None
}
else:
logger.warning(
"Can't use '%s' as artifact name (ID)."
" It will not be included in the `artifacts` section.",
name,
)

def cache(self, path):
try:
Expand Down