Skip to content

Commit

Permalink
updates to dvc exp run output warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
dberenbaum committed Jul 31, 2023
2 parents 388e16a + 331972f commit 036d1ee
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 8 deletions.
1 change: 1 addition & 0 deletions src/dvclive/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
from pytorch_lightning.utilities.logger import _scan_checkpoints
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.logger import _scan_checkpoints
from torch import is_tensor
Expand Down
30 changes: 22 additions & 8 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,16 +447,30 @@ def log_artifact(

def cache(self, path):
if self._inside_dvc_exp:
from dvc.exceptions import OutputNotFoundError

msg = f"Skipping dvc add {path} because `dvc exp run` is running."
try:
self._dvc_repo.find_outs_by_path(path)
msg += " It is already being tracked automatically."
path_stage = None
for stage in self._dvc_repo.stage.collect():
for out in stage.outs:
if out.fspath == str(Path(path).absolute()):
path_stage = stage
break
if not path_stage:
msg += (
"\nTo track it automatically during `dvc exp run`, "
"add it as an output of the pipeline stage."
)
logger.warning(msg)
elif path_stage.cmd:
msg += "\nIt is already being tracked automatically."
logger.info(msg)
except OutputNotFoundError:
msg += " Add it as a pipeline output to track it."
logger.warn(msg)
else:
msg += (
"\nTo track it automatically during `dvc exp run`:"
f"\n1. Run `dvc exp remove {path_stage.addressing}` "
"to stop tracking it outside the pipeline."
"\n2. Add it as an output of the pipeline stage."
)
logger.warning(msg)
return

try:
Expand Down
42 changes: 42 additions & 0 deletions tests/test_log_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@
from dvclive import Live
from dvclive.serialize import load_yaml

# Minimal dvc.yaml used by the tests below: a single `train` stage that
# declares `data` as a pipeline output. The nesting is significant: `train`
# must sit under `stages:` and `cmd`/`outs` under `train:`, otherwise no stage
# is defined at all.
dvcyaml = """
stages:
  train:
    cmd: python train.py
    outs:
    - data
"""


@pytest.mark.parametrize("cache", [True, False])
def test_log_artifact(tmp_dir, dvc_repo, cache):
Expand Down Expand Up @@ -224,3 +232,37 @@ def test_log_artifact_inside_exp(tmp_dir, mocked_dvc_repo):
live._inside_dvc_exp = True
live.log_artifact("data")
mocked_dvc_repo.add.assert_not_called()


@pytest.mark.parametrize("tracked", ["data_source", "stage", None])
def test_log_artifact_inside_exp_logger(tmp_dir, mocker, dvc_repo, tracked):
    """Check the message logged when `log_artifact` runs inside `dvc exp run`.

    The `tracked` parameter selects how the `data` path is known to DVC
    before `log_artifact` is called:

    * ``"data_source"`` -- `data` was added with `dvc_repo.add`; a warning
      explaining how to move it into the pipeline is expected.
    * ``"stage"`` -- `data` is an output of the `train` stage written to
      `dvc.yaml`; an info message is expected.
    * ``None`` -- `data` is not tracked; a warning suggesting to add it as a
      pipeline output is expected.
    """
    # Replace the module logger so the emitted messages can be inspected.
    logger = mocker.patch("dvclive.live.logger")

    if tracked == "data_source":
        data = tmp_dir / "data"
        data.touch()
        dvc_repo.add(data)
    elif tracked == "stage":
        dvcyaml_path = tmp_dir / "dvc.yaml"
        with open(dvcyaml_path, "w") as f:
            f.write(dvcyaml)

    with Live() as live:
        # Pretend we are running under `dvc exp run`.
        live._inside_dvc_exp = True
        live.log_artifact("data")

    # NOTE(review): nesting was reconstructed from a flattened diff view;
    # confirm the assertions below belong after the `with` block.
    msg = "Skipping dvc add data because `dvc exp run` is running."
    if tracked == "data_source":
        # The trailing space after the closing backtick matters: the message
        # is built from adjacent string literals and must match the text
        # emitted by `Live.cache` exactly.
        msg += (
            "\nTo track it automatically during `dvc exp run`:"
            "\n1. Run `dvc exp remove data.dvc` "
            "to stop tracking it outside the pipeline."
            "\n2. Add it as an output of the pipeline stage."
        )
        logger.warning.assert_called_with(msg)
    elif tracked == "stage":
        msg += "\nIt is already being tracked automatically."
        logger.info.assert_called_with(msg)
    else:
        msg += (
            "\nTo track it automatically during `dvc exp run`, "
            "add it as an output of the pipeline stage."
        )
        logger.warning.assert_called_with(msg)

0 comments on commit 036d1ee

Please sign in to comment.