Skip to content

Commit

Permalink
dvclive: lightning log_model (#4714)
Browse files Browse the repository at this point in the history
* dvclive: lightning log_model

* expand lightning log_model examples

* drop mention of save_dvc_exp in log_artifact

* tweak log_artifact language

* Update content/docs/dvclive/live/log_artifact.md

---------

Co-authored-by: David de la Iglesia Castro <daviddelaiglesiacastro@gmail.com>
  • Loading branch information
Dave Berenbaum and daavoo authored Jul 31, 2023
1 parent bff0716 commit dbc446e
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 36 deletions.
23 changes: 15 additions & 8 deletions content/docs/dvclive/live/log_artifact.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,21 @@ with Live() as live:

## Description

Uses `dvc add` to track `path` with DVC, generating a `{path}.dvc` file. When
combined with [`save_dvc_exp=True`](/doc/dvclive#initialize-dvclive), it will
ensure that `{path}.dvc` is included in the experiment.
Log `path`, saving its contents to DVC storage. Also annotate with any included
metadata fields (for example, to be consumed in [Studio model registry] or
automation scenarios).

If `cache=True` (which is the default), uses `dvc add` to [track] `path` with
DVC, saving it to the DVC <abbr>cache</abbr> and generating a `{path}.dvc` file
that acts as a pointer to the cached data.

If `Live` was initialized with `dvcyaml=True` (which is the default) and you
include any of the optional metadata fields (`type`, `name`, `desc`, `labels`,
`meta`), it will add an
[artifact](/doc/user-guide/project-structure/dvcyaml-files#artifacts) and all
the metadata passed as arguments to the corresponding `dvc.yaml`. Passing
`type="model"` will mark it as a `model` for DVC and will make it appear in
[Studio Model Registry](/doc/studio).
[Studio model registry].

## Parameters

Expand All @@ -71,12 +75,15 @@ the metadata passed as arguments to the corresponding `dvc.yaml`. Passing
artifact. Useful if you don't want to track the original path in your repo
(for example, it is outside the repo or in a Git-ignored directory).

- `cache` - <abbr>cache</abbr> the files with DVC to
[track](/doc/dvclive/how-it-works#track-large-artifacts-with-dvc) them outside
of Git. Defaults to `True`, but set to `False` if you want to annotate
metadata about the artifact without storing a copy in the DVC cache.
- `cache` - <abbr>cache</abbr> the files with DVC to [track] them outside of
Git. Defaults to `True`, but set to `False` if you want to annotate metadata
about the artifact without storing a copy in the DVC cache.

## Exceptions

- `dvclive.error.InvalidDataTypeError` - thrown if the provided `path` does not
have a supported type.

[track]: /doc/dvclive/how-it-works#track-large-artifacts-with-dvc
[Studio model registry]:
/doc/studio/user-guide/model-registry/what-is-a-model-registry
79 changes: 61 additions & 18 deletions content/docs/dvclive/ml-frameworks/pytorch-lightning.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,19 @@ checkpointing at all as described in the

- `prefix` - (`None` by default) - string that adds to each metric name.

- `log_model` - (`False` by default) - use
[`live.log_artifact()`](/doc/dvclive/live/log_artifact) to log checkpoints
created by [`ModelCheckpoint`]. See
[Log model checkpoints](#log-model-checkpoints).

- if `log_model == 'all'`, checkpoints are logged during training.

- if `log_model == True`, checkpoints are logged at the end of training,
except when `save_top_k == -1` which also logs every checkpoint during
training.

- if `log_model == False` (default), no checkpoint is logged.

- `experiment` - (`None` by default) - [`Live`](/doc/dvclive/live) object to be
used instead of initializing a new one.

Expand All @@ -69,6 +82,50 @@ checkpointing at all as described in the

## Examples

### Log model checkpoints

Use `log_model` to save the checkpoints (it will use `Live.log_artifact()`
internally to save those). At the end of training, DVCLive will annotate the
[`best_model_path`][`ModelCheckpoint`] with name `best` (for example, to be
consumed in [Studio model registry] or automation scenarios).

- Save updates to the checkpoints directory at the end of training:

```python
from dvclive.lightning import DVCLiveLogger

logger = DVCLiveLogger(save_dvc_exp=True, log_model=True)
trainer = Trainer(logger=logger)
trainer.fit(model)
```

- Save updates to the checkpoints directory whenever a new checkpoint is saved:

```python
from dvclive.lightning import DVCLiveLogger

logger = DVCLiveLogger(save_dvc_exp=True, log_model="all")
trainer = Trainer(logger=logger)
trainer.fit(model)
```

- Use a custom `ModelCheckpoint`:

```python
from dvclive.lightning import DVCLiveLogger

logger = DVCLiveLogger(save_dvc_exp=True, log_model=True),
checkpoint_callback = ModelCheckpoint(
dirpath="model",
monitor="val_acc",
mode="max",
)
trainer = Trainer(logger=logger, callbacks=[checkpoint_callback])
trainer.fit(model)
```

### Passing additional DVCLive arguments

- Using `experiment` to pass an existing [`Live`] instance.

```python
Expand All @@ -93,24 +150,6 @@ trainer = Trainer(
trainer.fit(model)
```

- Using [`live.log_artifact()`](/doc/dvclive/live/log_artifact) to save the
[best checkpoint](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html).

```python
with Live(save_dvc_exp=True) as live:
checkpoint = ModelCheckpoint(dirpath="mymodel")
trainer = Trainer(
logger=DVCLiveLogger(experiment=live),
callbacks=checkpoint
)
trainer.fit(model)
live.log_artifact(
checkpoint.best_model_path,
type="model",
name="lightning-model"
)
```

## Output format

Each metric will be logged to:
Expand Down Expand Up @@ -140,3 +179,7 @@ dvclive/metrics/train/epoch/metric.tsv
```

[`live`]: /doc/dvclive/live
[studio model registry]:
/doc/studio/user-guide/model-registry/what-is-a-model-registry
[`ModelCheckpoint`]:
https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
13 changes: 3 additions & 10 deletions content/docs/start/experiments/experiment-tracking.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,13 @@ from dvclive import Live
from dvclive.lightning import DVCLiveLogger

...
with Live(save_dvc_exp=True) as live:
checkpoint = ModelCheckpoint(dirpath="mymodel")
trainer = Trainer(
logger=DVCLiveLogger(
experiment=live
),
callbacks=checkpoint
save_dvc_exp=True,
log_model=True
)
)
trainer.fit(model)
live.log_artifact(
checkpoint.best_model_path,
type="model",
name="lightning-model"
)
```

</tab>
Expand Down

0 comments on commit dbc446e

Please sign in to comment.