diff --git a/src/dvclive/dvc.py b/src/dvclive/dvc.py index 2e6511d3..ada2cb0d 100644 --- a/src/dvclive/dvc.py +++ b/src/dvclive/dvc.py @@ -147,8 +147,12 @@ def mark_dvclive_only_ended(): def get_random_exp_name(scm, baseline_rev): + from dvc.repo.experiments.utils import gen_random_name from dvc.repo.experiments.utils import ( get_random_exp_name as dvc_get_random_exp_name, ) - return dvc_get_random_exp_name(scm, baseline_rev) + if scm and baseline_rev: + return dvc_get_random_exp_name(scm, baseline_rev) + # TODO: ping studio for list of existing names to check against + return gen_random_name() diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 963a834a..d28da142 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -130,6 +130,9 @@ def _init_dvc(self): self._exp_name = os.getenv(env.DVC_EXP_NAME) self._baseline_rev = os.getenv(env.DVC_EXP_BASELINE_REV) + if not self._exp_name: + scm = self._dvc_repo.scm if self._dvc_repo else None + self._exp_name = get_random_exp_name(scm, self._baseline_rev) if self._dvc_repo and self._baseline_rev and self._exp_name: # `dvc exp` execution @@ -162,8 +165,8 @@ def _init_dvc(self): return self._baseline_rev = self._dvc_repo.scm.get_rev() + if self._save_dvc_exp: - self._exp_name = get_random_exp_name(self._dvc_repo.scm, self._baseline_rev) mark_dvclive_only_started() self._include_untracked.append(self.dir) diff --git a/tests/test_dvc.py b/tests/test_dvc.py index 529c5571..970ff4ff 100644 --- a/tests/test_dvc.py +++ b/tests/test_dvc.py @@ -138,7 +138,7 @@ def test_exp_save_on_end(tmp_dir, save, mocked_dvc_repo): ) else: assert live._baseline_rev is not None - assert live._exp_name is None + assert live._exp_name is not None mocked_dvc_repo.experiments.save.assert_not_called() diff --git a/tests/test_studio.py b/tests/test_studio.py index 5b3aedb3..1aa0dc85 100644 --- a/tests/test_studio.py +++ b/tests/test_studio.py @@ -397,11 +397,6 @@ def test_post_to_studio_inside_subdir_dvc_exp( ) -def test_post_to_studio_requires_exp(tmp_dir, mocked_dvc_repo, mocked_studio_post): - assert Live()._studio_events_to_skip == {"start", "data", "done"} - assert not Live(save_dvc_exp=True)._studio_events_to_skip - - def test_get_dvc_studio_config_none(mocker): mocker.patch("dvclive.live.get_dvc_repo", return_value=None) live = Live() @@ -486,11 +481,13 @@ def test_post_to_studio_message(tmp_dir, mocked_dvc_repo, mocked_studio_post): ) -def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post): +@pytest.mark.parametrize("exp_name", [True, False]) +def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post, exp_name): monkeypatch.setenv(DVC_STUDIO_TOKEN, "STUDIO_TOKEN") monkeypatch.setenv(DVC_STUDIO_REPO_URL, "STUDIO_REPO_URL") monkeypatch.setenv(DVC_EXP_BASELINE_REV, "f" * 40) - monkeypatch.setenv(DVC_EXP_NAME, "bar") + if exp_name: + monkeypatch.setenv(DVC_EXP_NAME, "bar") live = Live(save_dvc_exp=True) live.log_param("fooparam", 1)