From 74f06094e8073e15a0d014436e28f3f465aa0191 Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Mon, 3 Jul 2023 15:12:38 -0400 Subject: [PATCH 1/8] auto save model in lightning --- src/dvclive/lightning.py | 19 ++++++++++++++++++- src/dvclive/live.py | 40 ++++++++++++++++++++++------------------ 2 files changed, 40 insertions(+), 19 deletions(-) diff --git a/src/dvclive/lightning.py b/src/dvclive/lightning.py index cfb8375e..968bfb99 100644 --- a/src/dvclive/lightning.py +++ b/src/dvclive/lightning.py @@ -1,12 +1,13 @@ # ruff: noqa: ARG002 import inspect -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union from lightning.fabric.utilities.logger import ( _convert_params, _sanitize_callable_params, _sanitize_params, ) +from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment from lightning.pytorch.utilities import rank_zero_only from torch import is_tensor @@ -38,6 +39,7 @@ def __init__( self, run_name: Optional[str] = "dvclive_run", prefix="", + log_model: Union[str, bool] = False, experiment=None, dir: Optional[str] = None, # noqa: A002 resume: bool = False, @@ -60,6 +62,8 @@ def __init__( if report == "notebook": # Force Live instantiation self.experiment # noqa: B018 + self._log_model = log_model + self._checkpoint_callback: Optional[ModelCheckpoint] = None @property def name(self): @@ -119,6 +123,19 @@ def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None): self.experiment._latest_studio_step -= 1 # noqa: SLF001 self.experiment.next_step() + def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: + self._checkpoint_callback = checkpoint_callback + if self._log_model == "all": + self.experiment.log_artifact(checkpoint_callback.dirpath) + @rank_zero_only def finalize(self, status: str) -> None: + checkpoint_callback = self._checkpoint_callback + # Save model checkpoints. + if self._log_model is True: + self.experiment.log_artifact(checkpoint_callback.dirpath) + # Log best model. + if self._log_model in (True, "all"): + best_model_path = checkpoint_callback.best_model_path + self.experiment.log_artifact(best_model_path, name="best", cache=False) self.experiment.end() diff --git a/src/dvclive/live.py b/src/dvclive/live.py index d70b6a23..deb3b968 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -410,10 +410,11 @@ def log_artifact( path: StrPath, type: Optional[str] = None, # noqa: A002 name: Optional[str] = None, - desc: Optional[str] = None, # noqa: ARG002 - labels: Optional[List[str]] = None, # noqa: ARG002 - meta: Optional[Dict[str, Any]] = None, # noqa: ARG002 + desc: Optional[str] = None, + labels: Optional[List[str]] = None, + meta: Optional[Dict[str, Any]] = None, copy: bool = False, + cache: bool = True, ): """Tracks a local file or directory with DVC""" if not isinstance(path, (str, Path)): @@ -425,21 +426,24 @@ def log_artifact( if copy: path = clean_and_copy_into(path, self.artifacts_dir) - self.cache(path) - - name = name or Path(path).stem - if name_is_compatible(name): - self._artifacts[name] = { - k: v - for k, v in locals().items() - if k in ("path", "type", "desc", "labels", "meta") and v is not None - } - else: - logger.warning( - "Can't use '%s' as artifact name (ID)." - " It will not be included in the `artifacts` section.", - name, - ) + if cache: + self.cache(path) + + if any((type, name, desc, labels, meta)): + name = name or Path(path).stem + if name_is_compatible(name): + self._artifacts[name] = { + k: v + for k, v in locals().items() + if k in ("path", "type", "desc", "labels", "meta") + and v is not None + } + else: + logger.warning( + "Can't use '%s' as artifact name (ID)." + " It will not be included in the `artifacts` section.", + name, + ) def cache(self, path): try: From dbd354b95278ed7509d016d560bfad34bb8e93d7 Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Tue, 4 Jul 2023 14:53:41 -0400 Subject: [PATCH 2/8] lightning: save model at each checkpoint if save_top_k == -1 --- src/dvclive/lightning.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/dvclive/lightning.py b/src/dvclive/lightning.py index 968bfb99..43f6c747 100644 --- a/src/dvclive/lightning.py +++ b/src/dvclive/lightning.py @@ -125,7 +125,9 @@ def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None): def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: self._checkpoint_callback = checkpoint_callback - if self._log_model == "all": + if self._log_model == "all" or ( + self._log_model is True and checkpoint_callback.save_top_k == -1 + ): self.experiment.log_artifact(checkpoint_callback.dirpath) @rank_zero_only From 15932a8400a5202f0d7aa92159dccd9e927e092f Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Wed, 5 Jul 2023 08:18:39 -0400 Subject: [PATCH 3/8] add type: model to lightning artifact --- src/dvclive/lightning.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/dvclive/lightning.py b/src/dvclive/lightning.py index 43f6c747..4aaff1b0 100644 --- a/src/dvclive/lightning.py +++ b/src/dvclive/lightning.py @@ -139,5 +139,7 @@ def finalize(self, status: str) -> None: # Log best model. if self._log_model in (True, "all"): best_model_path = checkpoint_callback.best_model_path - self.experiment.log_artifact(best_model_path, name="best", cache=False) + self.experiment.log_artifact( + best_model_path, name="best", type="model", cache=False + ) self.experiment.end() From 110b9aa575f15b802e6c9effb43375ded4761935 Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Tue, 11 Jul 2023 14:34:20 -0400 Subject: [PATCH 4/8] lightning: drop unused checkpoints --- src/dvclive/lightning.py | 31 +++++++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/src/dvclive/lightning.py b/src/dvclive/lightning.py index 8f33936b..39361fa7 100644 --- a/src/dvclive/lightning.py +++ b/src/dvclive/lightning.py @@ -1,6 +1,7 @@ # ruff: noqa: ARG002 import inspect -from typing import Any, Dict, Optional, Union +from pathlib import Path +from typing import Any, Dict, List, Optional, Union from lightning.fabric.utilities.logger import ( _convert_params, @@ -9,6 +10,7 @@ ) from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment +from lightning.pytorch.loggers.utilities import _scan_checkpoints from lightning.pytorch.utilities import rank_zero_only from torch import is_tensor @@ -35,7 +37,7 @@ def _should_call_next_step(): class DVCLiveLogger(Logger): - def __init__( + def __init__( # noqa: PLR0913 self, run_name: Optional[str] = "dvclive_run", prefix="", @@ -65,7 +67,9 @@ def __init__( # Force Live instantiation self.experiment # noqa: B018 self._log_model = log_model + self._logged_model_time: Dict[str, float] = {} self._checkpoint_callback: Optional[ModelCheckpoint] = None + self._all_checkpoint_paths: List[str] = [] @property def name(self): @@ -130,14 +134,14 @@ def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: if self._log_model == "all" or ( self._log_model is True and checkpoint_callback.save_top_k == -1 ): - self.experiment.log_artifact(checkpoint_callback.dirpath) + self._save_checkpoints(checkpoint_callback) @rank_zero_only def finalize(self, status: str) -> None: checkpoint_callback = self._checkpoint_callback # Save model checkpoints. if self._log_model is True: - self.experiment.log_artifact(checkpoint_callback.dirpath) + self._save_checkpoints(checkpoint_callback) # Log best model. if self._log_model in (True, "all"): best_model_path = checkpoint_callback.best_model_path @@ -145,3 +149,22 @@ def finalize(self, status: str) -> None: best_model_path, name="best", type="model", cache=False ) self.experiment.end() + + def _scan_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None: + # get checkpoints to be saved with associated score + checkpoints = _scan_checkpoints(checkpoint_callback, self._logged_model_time) + + # update model time and append path to list of all checkpoints + for t, p, _, _ in checkpoints: + self._logged_model_time[p] = t + self._all_checkpoint_paths.append(p) + + def _save_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None: + # drop unused checkpoints + if not self._experiment._resume: # noqa: SLF001 + for p in Path(checkpoint_callback.dirpath).iterdir(): + if str(p) not in self._all_checkpoint_paths: + p.unlink(missing_ok=True) + + # save directory + self.experiment.log_artifact(checkpoint_callback.dirpath) From 48309f5ec54bc779f0d791507d0a85f1aae6b28a Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Wed, 12 Jul 2023 18:30:14 -0400 Subject: [PATCH 5/8] lightning: add tests for log_model --- src/dvclive/lightning.py | 14 ++++----- tests/test_frameworks/test_lightning.py | 39 +++++++++++++++++++++++-- 2 files changed, 43 insertions(+), 10 deletions(-) diff --git a/src/dvclive/lightning.py b/src/dvclive/lightning.py index 39361fa7..fef5de26 100644 --- a/src/dvclive/lightning.py +++ b/src/dvclive/lightning.py @@ -130,7 +130,9 @@ def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None): self.experiment.next_step() def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: - self._checkpoint_callback = checkpoint_callback + if self._log_model in [True, "all"]: + self._checkpoint_callback = checkpoint_callback + self._scan_checkpoints(checkpoint_callback) if self._log_model == "all" or ( self._log_model is True and checkpoint_callback.save_top_k == -1 ): @@ -138,13 +140,11 @@ def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: @rank_zero_only def finalize(self, status: str) -> None: - checkpoint_callback = self._checkpoint_callback - # Save model checkpoints. - if self._log_model is True: - self._save_checkpoints(checkpoint_callback) # Log best model. - if self._log_model in (True, "all"): - best_model_path = checkpoint_callback.best_model_path + if self._checkpoint_callback: + self._scan_checkpoints(self._checkpoint_callback) + self._save_checkpoints(self._checkpoint_callback) + best_model_path = self._checkpoint_callback.best_model_path self.experiment.log_artifact( best_model_path, name="best", type="model", cache=False ) diff --git a/tests/test_frameworks/test_lightning.py b/tests/test_frameworks/test_lightning.py index 8aa4fbf5..e572b235 100644 --- a/tests/test_frameworks/test_lightning.py +++ b/tests/test_frameworks/test_lightning.py @@ -8,8 +8,9 @@ try: import torch - from pytorch_lightning import LightningModule - from pytorch_lightning.trainer import Trainer + from lightning import LightningModule + from lightning.pytorch import Trainer + from lightning.pytorch.callbacks import ModelCheckpoint from torch import nn from torch.nn import functional as F # noqa: N812 from torch.optim import SGD, Adam @@ -18,7 +19,7 @@ from dvclive import Live from dvclive.lightning import DVCLiveLogger except ImportError: - pytest.skip("skipping pytorch_lightning tests", allow_module_level=True) + pytest.skip("skipping lightning tests", allow_module_level=True) class XORDataset(Dataset): @@ -161,6 +162,38 @@ def test_lightning_kwargs(tmp_dir): assert dvclive_logger.experiment._cache_images is True +@pytest.mark.parametrize("log_model", [False, True, "all"]) +@pytest.mark.parametrize("save_top_k", [1, -1]) +def test_lightning_log_model(tmp_dir, mocker, log_model, save_top_k): + model = LitXOR() + dvclive_logger = DVCLiveLogger(dir="dir", log_model=log_model) + checkpoint = ModelCheckpoint(dirpath="model", save_top_k=save_top_k) + trainer = Trainer( + logger=dvclive_logger, + max_epochs=2, + log_every_n_steps=1, + callbacks=[checkpoint], + ) + log_artifact = mocker.patch.object(dvclive_logger.experiment, "log_artifact") + trainer.fit(model) + + # Check that log_artifact is called. + if log_model is False: + log_artifact.assert_not_called() + elif (log_model is True) and (save_top_k != -1): + # called once to cache, then again to log best artifact + assert log_artifact.call_count == 2 + else: + # once per epoch plus two calls at the end (see above) + assert log_artifact.call_count == 4 + + # Check that checkpoint files does not grow with each run. + num_checkpoints = len(os.listdir(tmp_dir / "model")) + if log_model in [True, "all"]: + trainer.fit(model) + assert len(os.listdir(tmp_dir / "model")) == num_checkpoints + + def test_lightning_steps(tmp_dir, mocker): model = LitXOR() # Handle kwargs passed to Live. From c3a798ad0cad3cb46d98a4c07e74f38e83781cd9 Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Wed, 19 Jul 2023 17:30:34 -0400 Subject: [PATCH 6/8] clean up warnings and move some to info --- src/dvclive/live.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 309cbcd7..3d305025 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -43,7 +43,6 @@ formatter = logging.Formatter("%(levelname)s:%(name)s:%(message)s") handler.setFormatter(formatter) logger.addHandler(handler) -logger.setLevel(os.getenv(env.DVCLIVE_LOGLEVEL, "INFO").upper()) ParamLike = Union[int, float, str, bool, List["ParamLike"], Dict[str, "ParamLike"]] @@ -133,9 +132,7 @@ def _init_dvc(self): self._exp_name = os.getenv(env.DVC_EXP_NAME, "") self._inside_dvc_exp = True if self._save_dvc_exp: - logger.warning( - "Ignoring `_save_dvc_exp` because `dvc exp run` is running" - ) + logger.info("Ignoring `_save_dvc_exp` because `dvc exp run` is running") self._save_dvc_exp = False self._dvc_repo = get_dvc_repo() @@ -449,6 +446,19 @@ def log_artifact( ) def cache(self, path): + if self._inside_dvc_exp: + from dvc.exceptions import OutputNotFoundError + + msg = f"Skipping dvc add {path} because `dvc exp run` is running." + try: + self._dvc_repo.find_outs_by_path(path) + msg += " It is already being tracked automatically." + logger.info(msg) + except OutputNotFoundError: + msg += " Add it as a pipeline output to track it." + logger.warn(msg) + return + try: stage = self._dvc_repo.add(str(path)) except Exception as e: # noqa: BLE001 From 0e561fb959bac8a785b352e28a36257992313abb Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Wed, 19 Jul 2023 17:40:54 -0400 Subject: [PATCH 7/8] add test --- tests/test_log_artifact.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/test_log_artifact.py b/tests/test_log_artifact.py index a1721a9b..84e534eb 100644 --- a/tests/test_log_artifact.py +++ b/tests/test_log_artifact.py @@ -215,3 +215,12 @@ def test_log_artifact_type_model_when_dvc_add_fails(tmp_dir, mocker, mocked_dvc_ assert load_yaml(live.dvc_file) == { "artifacts": {"model": {"path": "../model.pth", "type": "model"}} } + + +def test_log_artifact_inside_exp(tmp_dir, mocked_dvc_repo): + data = tmp_dir / "data" + data.touch() + with Live() as live: + live._inside_dvc_exp = True + live.log_artifact("data") + mocked_dvc_repo.add.assert_not_called() From 331972f48f7193877d36abba23f08e1d6cf7e9f3 Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Sat, 22 Jul 2023 09:15:36 -0400 Subject: [PATCH 8/8] adds tests for log_artifact logger messages --- src/dvclive/live.py | 30 +++++++++++++++++++-------- tests/test_log_artifact.py | 42 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 8 deletions(-) diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 5f0a06ed..7f3df0de 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -447,16 +447,30 @@ def log_artifact( def cache(self, path): if self._inside_dvc_exp: - from dvc.exceptions import OutputNotFoundError - msg = f"Skipping dvc add {path} because `dvc exp run` is running." - try: - self._dvc_repo.find_outs_by_path(path) - msg += " It is already being tracked automatically." + path_stage = None + for stage in self._dvc_repo.stage.collect(): + for out in stage.outs: + if out.fspath == str(Path(path).absolute()): + path_stage = stage + break + if not path_stage: + msg += ( + "\nTo track it automatically during `dvc exp run`, " + "add it as an output of the pipeline stage." + ) + logger.warning(msg) + elif path_stage.cmd: + msg += "\nIt is already being tracked automatically." logger.info(msg) - except OutputNotFoundError: - msg += " Add it as a pipeline output to track it." - logger.warn(msg) + else: + msg += ( + "\nTo track it automatically during `dvc exp run`:" + f"\n1. Run `dvc exp remove {path_stage.addressing}` " + "to stop tracking it outside the pipeline." + "\n2. Add it as an output of the pipeline stage." + ) + logger.warning(msg) return try: diff --git a/tests/test_log_artifact.py b/tests/test_log_artifact.py index 84e534eb..f38bace3 100644 --- a/tests/test_log_artifact.py +++ b/tests/test_log_artifact.py @@ -6,6 +6,14 @@ from dvclive import Live from dvclive.serialize import load_yaml +dvcyaml = """ +stages: + train: + cmd: python train.py + outs: + - data +""" + @pytest.mark.parametrize("cache", [True, False]) def test_log_artifact(tmp_dir, dvc_repo, cache): @@ -224,3 +232,37 @@ def test_log_artifact_inside_exp(tmp_dir, mocked_dvc_repo): live._inside_dvc_exp = True live.log_artifact("data") mocked_dvc_repo.add.assert_not_called() + + +@pytest.mark.parametrize("tracked", ["data_source", "stage", None]) +def test_log_artifact_inside_exp_logger(tmp_dir, mocker, dvc_repo, tracked): + logger = mocker.patch("dvclive.live.logger") + if tracked == "data_source": + data = tmp_dir / "data" + data.touch() + dvc_repo.add(data) + elif tracked == "stage": + dvcyaml_path = tmp_dir / "dvc.yaml" + with open(dvcyaml_path, "w") as f: + f.write(dvcyaml) + with Live() as live: + live._inside_dvc_exp = True + live.log_artifact("data") + msg = "Skipping dvc add data because `dvc exp run` is running." + if tracked == "data_source": + msg += ( + "\nTo track it automatically during `dvc exp run`:" + "\n1. Run `dvc exp remove data.dvc`" + "to stop tracking it outside the pipeline." + "\n2. Add it as an output of the pipeline stage." + ) + logger.warning.assert_called_with(msg) + elif tracked == "stage": + msg += "\nIt is already being tracked automatically." + logger.info.assert_called_with(msg) + else: + msg += ( + "\nTo track it automatically during `dvc exp run`, " + "add it as an output of the pipeline stage." + ) + logger.warning.assert_called_with(msg)