Skip to content

Commit

Permalink
clean up warnings and move some to info (#631)
Browse files Browse the repository at this point in the history
* auto save model in lightning

* lightning: save model at each checkpoint if save_top_k == -1

* add type: model to lightning artifact

* lightning: drop unused checkpoints

* lightning: add tests for log_model

* clean up warnings and move some to info

* add test

* clean up warnings and move some to info

* add test

* adds tests for log_artifact logger messages

* remove extraneous import

* Update tests/test_log_artifact.py

* Update src/dvclive/live.py

---------

Co-authored-by: David de la Iglesia Castro <daviddelaiglesiacastro@gmail.com>
  • Loading branch information
Dave Berenbaum and daavoo authored Jul 31, 2023
1 parent 7d29b35 commit 6fd4c4b
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 0 deletions.
27 changes: 27 additions & 0 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,33 @@ def log_artifact(
)

def cache(self, path):
if self._inside_dvc_exp:
msg = f"Skipping dvc add {path} because `dvc exp run` is running."
path_stage = None
for stage in self._dvc_repo.index.stages:
for out in stage.outs:
if out.fspath == str(Path(path).absolute()):
path_stage = stage
break
if not path_stage:
msg += (
"\nTo track it automatically during `dvc exp run`, "
"add it as an output of the pipeline stage."
)
logger.warning(msg)
elif path_stage.cmd:
msg += "\nIt is already being tracked automatically."
logger.info(msg)
else:
msg += (
"\nTo track it automatically during `dvc exp run`:"
f"\n1. Run `dvc exp remove {path_stage.addressing}` "
"to stop tracking it outside the pipeline."
"\n2. Add it as an output of the pipeline stage."
)
logger.warning(msg)
return

try:
stage = self._dvc_repo.add(str(path))
except Exception as e: # noqa: BLE001
Expand Down
51 changes: 51 additions & 0 deletions tests/test_log_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@
from dvclive import Live
from dvclive.serialize import load_yaml

dvcyaml = """
stages:
train:
cmd: python train.py
outs:
- data
"""


@pytest.mark.parametrize("cache", [True, False])
def test_log_artifact(tmp_dir, dvc_repo, cache):
Expand Down Expand Up @@ -215,3 +223,46 @@ def test_log_artifact_type_model_when_dvc_add_fails(tmp_dir, mocker, mocked_dvc_
assert load_yaml(live.dvc_file) == {
"artifacts": {"model": {"path": "../model.pth", "type": "model"}}
}


def test_log_artifact_inside_exp(tmp_dir, mocked_dvc_repo):
data = tmp_dir / "data"
data.touch()
with Live() as live:
live._inside_dvc_exp = True
live.log_artifact("data")
mocked_dvc_repo.add.assert_not_called()


@pytest.mark.parametrize("tracked", ["data_source", "stage", None])
def test_log_artifact_inside_exp_logger(tmp_dir, mocker, dvc_repo, tracked):
logger = mocker.patch("dvclive.live.logger")
if tracked == "data_source":
data = tmp_dir / "data"
data.touch()
dvc_repo.add(data)
elif tracked == "stage":
dvcyaml_path = tmp_dir / "dvc.yaml"
with open(dvcyaml_path, "w") as f:
f.write(dvcyaml)
with Live() as live:
live._inside_dvc_exp = True
live.log_artifact("data")
msg = "Skipping dvc add data because `dvc exp run` is running."
if tracked == "data_source":
msg += (
"\nTo track it automatically during `dvc exp run`:"
"\n1. Run `dvc exp remove data.dvc` "
"to stop tracking it outside the pipeline."
"\n2. Add it as an output of the pipeline stage."
)
logger.warning.assert_called_with(msg)
elif tracked == "stage":
msg += "\nIt is already being tracked automatically."
logger.info.assert_called_with(msg)
else:
msg += (
"\nTo track it automatically during `dvc exp run`, "
"add it as an output of the pipeline stage."
)
logger.warning.assert_called_with(msg)

0 comments on commit 6fd4c4b

Please sign in to comment.