diff --git a/src/dvclive/live.py b/src/dvclive/live.py
index 9d7b5fae..40bd836f 100644
--- a/src/dvclive/live.py
+++ b/src/dvclive/live.py
@@ -123,6 +123,26 @@ def _init_cleanup(self):
         if self.dvc_file and os.path.exists(self.dvc_file):
             os.remove(self.dvc_file)
 
+    def _init_check_dvcyaml_overlap(self):
+        """Warn if `self.dvc_file` is inside an output of any DVC stage.
+
+        Paths are compared by components rather than as substrings, so a
+        stage output `dvc` does not match `dvclive/dvc.yaml`.
+        """
+        # abspath also collapses `..`, keeping the comparison purely lexical.
+        dvcyaml_path = Path(os.path.abspath(self.dvc_file))
+        for stage in self._dvc_repo.index.stages:
+            for out in stage.outs:
+                out_path = Path(str(out.fs_path))
+                # Overlap: the output is the file itself or a parent directory.
+                if out_path == dvcyaml_path or out_path in dvcyaml_path.parents:
+                    msg = (
+                        f"'{self.dvc_file}' is in outputs of stage "
+                        f"'{stage.addressing}'.\n"
+                        "Remove it from outputs to make DVCLive work as expected."
+                    )
+                    logger.warning(msg)
+
     def _init_dvc(self):
         from dvc.scm import NoSCM
 
@@ -157,6 +177,9 @@ def _init_dvc(self):
                 self._save_dvc_exp = False
             return
 
+        if self._dvcyaml:
+            self._init_check_dvcyaml_overlap()
+
         if self._inside_dvc_exp:
             return
 
diff --git a/tests/test_dvc.py b/tests/test_dvc.py
index b945bf72..ef0d99f8 100644
--- a/tests/test_dvc.py
+++ b/tests/test_dvc.py
@@ -321,3 +321,37 @@ def test_no_scm_repo(tmp_dir, mocker):
 
     live = Live(save_dvc_exp=True)
     assert live._save_dvc_exp is False
+
+
+@pytest.mark.parametrize("dvcyaml", [True, False])
+def test_warn_on_dvcyaml_output_overlap(tmp_dir, mocker, mocked_dvc_repo, dvcyaml):
+    logger = mocker.patch("dvclive.live.logger")
+    dvc_stage = mocker.MagicMock()
+    dvc_stage.addressing = "train"
+    dvc_out = mocker.MagicMock()
+    dvc_out.fs_path = tmp_dir / "dvclive"
+    dvc_stage.outs = [dvc_out]
+    mocked_dvc_repo.index.stages = [dvc_stage]
+    live = Live(dvcyaml=dvcyaml)
+
+    if dvcyaml:
+        msg = f"'{live.dvc_file}' is in outputs of stage 'train'.\n"
+        msg += "Remove it from outputs to make DVCLive work as expected."
+        logger.warning.assert_called_with(msg)
+    else:
+        logger.warning.assert_not_called()
+
+
+def test_no_warn_on_output_name_prefix(tmp_dir, mocker, mocked_dvc_repo):
+    logger = mocker.patch("dvclive.live.logger")
+    dvc_stage = mocker.MagicMock()
+    dvc_stage.addressing = "train"
+    dvc_out = mocker.MagicMock()
+    # Textual prefix of the dvc.yaml path, but not one of its parent dirs.
+    dvc_out.fs_path = tmp_dir / "dvc"
+    dvc_stage.outs = [dvc_out]
+    mocked_dvc_repo.index.stages = [dvc_stage]
+    Live(dvcyaml=True)
+
+    for call in logger.warning.call_args_list:
+        assert "is in outputs of stage" not in str(call)