From dbc446ee2aa89dd0f93cad0b66b0f938cc2b81d7 Mon Sep 17 00:00:00 2001 From: Dave Berenbaum Date: Mon, 31 Jul 2023 11:59:14 -0400 Subject: [PATCH] dvclive: lightning log_model (#4714) * dvclive: lightning log_model * expand lightning log_model examples * drop mention of save_dvc_exp in log_artifact * tweak log_artifact language * Update content/docs/dvclive/live/log_artifact.md --------- Co-authored-by: David de la Iglesia Castro --- content/docs/dvclive/live/log_artifact.md | 23 ++++-- .../ml-frameworks/pytorch-lightning.md | 79 ++++++++++++++----- .../start/experiments/experiment-tracking.md | 13 +-- 3 files changed, 79 insertions(+), 36 deletions(-) diff --git a/content/docs/dvclive/live/log_artifact.md b/content/docs/dvclive/live/log_artifact.md index b0c736766c..6dab250c52 100644 --- a/content/docs/dvclive/live/log_artifact.md +++ b/content/docs/dvclive/live/log_artifact.md @@ -36,9 +36,13 @@ with Live() as live: ## Description -Uses `dvc add` to track `path` with DVC, generating a `{path}.dvc` file. When -combined with [`save_dvc_exp=True`](/doc/dvclive#initialize-dvclive), it will -ensure that `{path}.dvc` is included in the experiment. +Log `path`, saving its contents to DVC storage. Also annotate with any included +metadata fields (for example, to be consumed in [Studio model registry] or +automation scenarios). + +If `cache=True` (which is the default), uses `dvc add` to [track] `path` with +DVC, saving it to the DVC cache and generating a `{path}.dvc` file +that acts as a pointer to the cached data. If `Live` was initialized with `dvcyaml=True` (which is the default) and you include any of the optional metadata fields (`type`, `name`, `desc`, `labels`, @@ -46,7 +50,7 @@ include any of the optional metadata fields (`type`, `name`, `desc`, `labels`, [artifact](/doc/user-guide/project-structure/dvcyaml-files#artifacts) and all the metadata passed as arguments to the corresponding `dvc.yaml`. Passing `type="model"` will mark it as a `model` for DVC and will make it appear in -[Studio Model Registry](/doc/studio). +[Studio model registry]. ## Parameters @@ -71,12 +75,15 @@ the metadata passed as arguments to the corresponding `dvc.yaml`. Passing artifact. Useful if you don't want to track the original path in your repo (for example, it is outside the repo or in a Git-ignored directory). -- `cache` - cache the files with DVC to - [track](/doc/dvclive/how-it-works#track-large-artifacts-with-dvc) them outside - of Git. Defaults to `True`, but set to `False` if you want to annotate - metadata about the artifact without storing a copy in the DVC cache. +- `cache` - cache the files with DVC to [track] them outside of + Git. Defaults to `True`, but set to `False` if you want to annotate metadata + about the artifact without storing a copy in the DVC cache. ## Exceptions - `dvclive.error.InvalidDataTypeError` - thrown if the provided `path` does not have a supported type. + +[track]: /doc/dvclive/how-it-works#track-large-artifacts-with-dvc +[Studio model registry]: + /doc/studio/user-guide/model-registry/what-is-a-model-registry diff --git a/content/docs/dvclive/ml-frameworks/pytorch-lightning.md b/content/docs/dvclive/ml-frameworks/pytorch-lightning.md index efff941dec..af2fa5e08a 100644 --- a/content/docs/dvclive/ml-frameworks/pytorch-lightning.md +++ b/content/docs/dvclive/ml-frameworks/pytorch-lightning.md @@ -61,6 +61,19 @@ checkpointing at all as described in the - `prefix` - (`None` by default) - string that adds to each metric name. +- `log_model` - (`False` by default) - use + [`live.log_artifact()`](/doc/dvclive/live/log_artifact) to log checkpoints + created by [`ModelCheckpoint`]. See + [Log model checkpoints](#log-model-checkpoints). + + - if `log_model == 'all'`, checkpoints are logged during training. + + - if `log_model == True`, checkpoints are logged at the end of training, + except when `save_top_k == -1` which also logs every checkpoint during + training. + + - if `log_model == False` (default), no checkpoint is logged. + - `experiment` - (`None` by default) - [`Live`](/doc/dvclive/live) object to be used instead of initializing a new one. @@ -69,6 +82,50 @@ checkpointing at all as described in the ## Examples +### Log model checkpoints + +Use `log_model` to save the checkpoints (it will use `Live.log_artifact()` +internally to save those). At the end of training, DVCLive will annotate the +[`best_model_path`][`ModelCheckpoint`] with name `best` (for example, to be +consumed in [Studio model registry] or automation scenarios). + +- Save updates to the checkpoints directory at the end of training: + +```python +from dvclive.lightning import DVCLiveLogger + +logger = DVCLiveLogger(save_dvc_exp=True, log_model=True) +trainer = Trainer(logger=logger) +trainer.fit(model) +``` + +- Save updates to the checkpoints directory whenever a new checkpoint is saved: + +```python +from dvclive.lightning import DVCLiveLogger + +logger = DVCLiveLogger(save_dvc_exp=True, log_model="all") +trainer = Trainer(logger=logger) +trainer.fit(model) +``` + +- Use a custom `ModelCheckpoint`: + +```python +from dvclive.lightning import DVCLiveLogger + +logger = DVCLiveLogger(save_dvc_exp=True, log_model=True), +checkpoint_callback = ModelCheckpoint( + dirpath="model", + monitor="val_acc", + mode="max", +) +trainer = Trainer(logger=logger, callbacks=[checkpoint_callback]) +trainer.fit(model) +``` + +### Passing additional DVCLive arguments + - Using `experiment` to pass an existing [`Live`] instance. ```python @@ -93,24 +150,6 @@ trainer = Trainer( trainer.fit(model) ``` -- Using [`live.log_artifact()`](/doc/dvclive/live/log_artifact) to save the - [best checkpoint](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html). - -```python -with Live(save_dvc_exp=True) as live: - checkpoint = ModelCheckpoint(dirpath="mymodel") - trainer = Trainer( - logger=DVCLiveLogger(experiment=live), - callbacks=checkpoint - ) - trainer.fit(model) - live.log_artifact( - checkpoint.best_model_path, - type="model", - name="lightning-model" - ) -``` - ## Output format Each metric will be logged to: @@ -140,3 +179,7 @@ dvclive/metrics/train/epoch/metric.tsv ``` [`live`]: /doc/dvclive/live +[studio model registry]: + /doc/studio/user-guide/model-registry/what-is-a-model-registry +[`ModelCheckpoint`]: + https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html diff --git a/content/docs/start/experiments/experiment-tracking.md b/content/docs/start/experiments/experiment-tracking.md index efde7dcb3f..9a9c871e36 100644 --- a/content/docs/start/experiments/experiment-tracking.md +++ b/content/docs/start/experiments/experiment-tracking.md @@ -39,20 +39,13 @@ from dvclive import Live from dvclive.lightning import DVCLiveLogger ... -with Live(save_dvc_exp=True) as live: - checkpoint = ModelCheckpoint(dirpath="mymodel") trainer = Trainer( logger=DVCLiveLogger( - experiment=live - ), - callbacks=checkpoint + save_dvc_exp=True, + log_model=True + ) ) trainer.fit(model) - live.log_artifact( - checkpoint.best_model_path, - type="model", - name="lightning-model" - ) ```