Skip to content

Commit

Permalink
don't require exp name
Browse files Browse the repository at this point in the history
  • Loading branch information
dberenbaum committed Aug 4, 2023
1 parent 98ae256 commit df1a20b
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 10 deletions.
6 changes: 5 additions & 1 deletion src/dvclive/dvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,12 @@ def mark_dvclive_only_ended():


def get_random_exp_name(scm, baseline_rev):
from dvc.repo.experiments.utils import gen_random_name
from dvc.repo.experiments.utils import (
get_random_exp_name as dvc_get_random_exp_name,
)

return dvc_get_random_exp_name(scm, baseline_rev)
if scm and baseline_rev:
return dvc_get_random_exp_name(scm, baseline_rev)
# TODO: ping studio for list of existing names to check against
return gen_random_name()
5 changes: 4 additions & 1 deletion src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ def _init_dvc(self):

self._exp_name = os.getenv(env.DVC_EXP_NAME)
self._baseline_rev = os.getenv(env.DVC_EXP_BASELINE_REV)
if not self._exp_name:
scm = self._dvc_repo.scm if self._dvc_repo else None
self._exp_name = get_random_exp_name(scm, self._baseline_rev)

if self._dvc_repo and self._baseline_rev and self._exp_name:
# `dvc exp` execution
Expand Down Expand Up @@ -162,8 +165,8 @@ def _init_dvc(self):
return

self._baseline_rev = self._dvc_repo.scm.get_rev()

if self._save_dvc_exp:
self._exp_name = get_random_exp_name(self._dvc_repo.scm, self._baseline_rev)
mark_dvclive_only_started()
self._include_untracked.append(self.dir)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_dvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def test_exp_save_on_end(tmp_dir, save, mocked_dvc_repo):
)
else:
assert live._baseline_rev is not None
assert live._exp_name is None
assert live._exp_name is not None
mocked_dvc_repo.experiments.save.assert_not_called()


Expand Down
11 changes: 4 additions & 7 deletions tests/test_studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,11 +397,6 @@ def test_post_to_studio_inside_subdir_dvc_exp(
)


def test_post_to_studio_requires_exp(tmp_dir, mocked_dvc_repo, mocked_studio_post):
assert Live()._studio_events_to_skip == {"start", "data", "done"}
assert not Live(save_dvc_exp=True)._studio_events_to_skip


def test_get_dvc_studio_config_none(mocker):
mocker.patch("dvclive.live.get_dvc_repo", return_value=None)
live = Live()
Expand Down Expand Up @@ -486,11 +481,13 @@ def test_post_to_studio_message(tmp_dir, mocked_dvc_repo, mocked_studio_post):
)


def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post):
@pytest.mark.parametrize("exp_name", [True, False])
def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post, exp_name):
monkeypatch.setenv(DVC_STUDIO_TOKEN, "STUDIO_TOKEN")
monkeypatch.setenv(DVC_STUDIO_REPO_URL, "STUDIO_REPO_URL")
monkeypatch.setenv(DVC_EXP_BASELINE_REV, "f" * 40)
monkeypatch.setenv(DVC_EXP_NAME, "bar")
if exp_name:
monkeypatch.setenv(DVC_EXP_NAME, "bar")

live = Live(save_dvc_exp=True)
live.log_param("fooparam", 1)
Expand Down

0 comments on commit df1a20b

Please sign in to comment.