Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dvclive: lightning log_model #4714

Merged
merged 5 commits into from
Jul 31, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 30 additions & 18 deletions content/docs/dvclive/ml-frameworks/pytorch-lightning.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,23 @@ checkpointing at all as described in the

- `prefix` - (`None` by default) - string that adds to each metric name.

- `log_model` - (`False` by default) - use
[`live.log_artifact()`](/doc/dvclive/live/log_artifact) to log checkpoints
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

log checkpoints - does it means that we might have multiple records? per checkpoint?

created by
[`ModelCheckpoint`](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do I need to use ModelCheckpoint along side? Ot it's being used internally?

and annotate the best checkpoint with `type=model` and `name=best` for use in
[Studio model registry]. DVCLive will <abbr>cache</abbr> the checkpoint
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It took me a while to understand the part about cache and being able to recover it. I think most of the ppl will be confused by this (I expect they don't know much about DVC mechanics). Two suggestions:

  • remove it
  • try to be more explicit: DVCLive saves checkpoints with DVC - it means they are cached in DVC cache, also could be pushed to remote storage, etc. It also means that DVCLive drops the directory (it's safe) when it runs a new experiments. Model weights for any experiment are stored in DVC and can be restored ...

Btw, what happens, if save_dvc_exp is False? How do I recover those?

directory and delete checkpoints from previous <abbr>experiments</abbr>, but
you can recover the checkpoints from any experiment using DVC.

- if `log_model == 'all'`, checkpoints are logged during training.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all checkpoints are logged during training?

also, what does it practically mean? does it create a record per each checkpoint in dvc.yaml?


- if `log_model == True`, checkpoints are logged at the end of training,
except when `save_top_k == -1` which also logs every checkpoint during
training.

- if `log_model == False` (default), no checkpoint is logged.

- `experiment` - (`None` by default) - [`Live`](/doc/dvclive/live) object to be
used instead of initializing a new one.

Expand All @@ -69,6 +86,17 @@ checkpointing at all as described in the

## Examples

- Using `log_model` to save the checkpoints and annotate the best one for use in
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we show an output / end result of it?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

like actual dvc.yaml and directory structure?

[Studio model registry].

```python
from dvclive.lightning import DVCLiveLogger

trainer = Trainer(
daavoo marked this conversation as resolved.
Show resolved Hide resolved
logger=DVCLiveLogger(save_dvc_exp=True, log_model=True))
trainer.fit(model)
```

- Using `experiment` to pass an existing [`Live`] instance.

```python
Expand All @@ -93,24 +121,6 @@ trainer = Trainer(
trainer.fit(model)
```

- Using [`live.log_artifact()`](/doc/dvclive/live/log_artifact) to save the
[best checkpoint](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html).

```python
with Live(save_dvc_exp=True) as live:
checkpoint = ModelCheckpoint(dirpath="mymodel")
trainer = Trainer(
logger=DVCLiveLogger(experiment=live),
callbacks=checkpoint
)
trainer.fit(model)
live.log_artifact(
checkpoint.best_model_path,
type="model",
name="lightning-model"
)
```

## Output format

Each metric will be logged to:
Expand Down Expand Up @@ -140,3 +150,5 @@ dvclive/metrics/train/epoch/metric.tsv
```

[`live`]: /doc/dvclive/live
[studio model registry]:
/doc/studio/user-guide/model-registry/what-is-a-model-registry
13 changes: 3 additions & 10 deletions content/docs/start/experiments/experiment-tracking.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,13 @@ from dvclive import Live
from dvclive.lightning import DVCLiveLogger

...
with Live(save_dvc_exp=True) as live:
checkpoint = ModelCheckpoint(dirpath="mymodel")
trainer = Trainer(
logger=DVCLiveLogger(
experiment=live
),
callbacks=checkpoint
save_dvc_exp=True,
log_model=True
)
)
trainer.fit(model)
live.log_artifact(
checkpoint.best_model_path,
type="model",
name="lightning-model"
)
```

</tab>
Expand Down