Skip to content

Commit

Permalink
post to studio in thread to avoid blocking (#814)
Browse files Browse the repository at this point in the history
* post to studio in thread to avoid blocking

* queue for studio data posts

* fix test_post_to_studio_if_done_skipped

* catch and warn in src/dvclive/studio.py:post_to_studio
  • Loading branch information
Dave Berenbaum authored Apr 19, 2024
1 parent edb5ee3 commit 228e9a8
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 30 deletions.
25 changes: 19 additions & 6 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import math
import os
import shutil
import queue
import tempfile
import threading

from pathlib import Path, PurePath
from typing import Any, Dict, List, Optional, Set, Tuple, Union, TYPE_CHECKING, Literal
Expand Down Expand Up @@ -171,6 +173,7 @@ def __init__(
self._studio_events_to_skip: Set[str] = set()
self._dvc_studio_config: Dict[str, Any] = {}
self._num_points_sent_to_studio: Dict[str, int] = {}
self._studio_queue = None
self._init_studio()

self._system_monitor: Optional[_SystemMonitor] = None # Monitoring thread
Expand Down Expand Up @@ -296,7 +299,7 @@ def _init_studio(self):
self._studio_events_to_skip.add("start")
self._studio_events_to_skip.add("done")
else:
self.post_to_studio("start")
post_to_studio(self, "start")

def _init_report(self):
if self._report_mode not in {None, "html", "notebook", "md"}:
Expand Down Expand Up @@ -428,7 +431,7 @@ def sync(self):

self.make_report()

self.post_to_studio("data")
self.post_data_to_studio()

def next_step(self):
"""
Expand Down Expand Up @@ -880,9 +883,19 @@ def make_dvcyaml(self):
"""
make_dvcyaml(self)

@catch_and_warn(DvcException, logger)
def post_to_studio(self, event: Literal["start", "data", "done"]):
post_to_studio(self, event)
def post_data_to_studio(self):
if not self._studio_queue:
self._studio_queue = queue.Queue()

def worker():
while True:
item = self._studio_queue.get()
post_to_studio(item, "data")
self._studio_queue.task_done()

threading.Thread(target=worker, daemon=True).start()

self._studio_queue.put(self)

def end(self):
"""
Expand Down Expand Up @@ -926,7 +939,7 @@ def end(self):
self.save_dvc_exp()

# Mark experiment as done
self.post_to_studio("done")
post_to_studio(self, "done")

cleanup_dvclive_step_completed()

Expand Down
4 changes: 4 additions & 0 deletions src/dvclive/studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@
from pathlib import PureWindowsPath
from typing import TYPE_CHECKING, Literal, Mapping

from dvc.exceptions import DvcException
from dvc_studio_client.config import get_studio_config
from dvc_studio_client.post_live_metrics import post_live_metrics

from .utils import catch_and_warn

if TYPE_CHECKING:
from dvclive.live import Live
from dvclive.serialize import load_yaml
Expand Down Expand Up @@ -96,6 +99,7 @@ def increment_num_points_sent_to_studio(live, plots):
return live


@catch_and_warn(DvcException, logger)
def post_to_studio(live: Live, event: Literal["start", "data", "done"]): # noqa: C901
if event in live._studio_events_to_skip:
return
Expand Down
78 changes: 54 additions & 24 deletions tests/test_post_to_studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dvclive import Live
from dvclive.env import DVC_EXP_BASELINE_REV, DVC_EXP_NAME, DVC_ROOT
from dvclive.plots import Image, Metric
from dvclive.studio import _adapt_image, get_dvc_studio_config
from dvclive.studio import _adapt_image, get_dvc_studio_config, post_to_studio


def get_studio_call(event_type, exp_name, **kwargs):
Expand Down Expand Up @@ -46,7 +46,9 @@ def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post):
)

live.log_metric("foo", 1)
live.next_step()
live.step = 0
live.make_summary()
post_to_studio(live, "data")

mocked_post.assert_called_with(
"https://0.0.0.0/api/live",
Expand All @@ -58,8 +60,10 @@ def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post):
),
)

live.step += 1
live.log_metric("foo", 2)
live.next_step()
live.make_summary()
post_to_studio(live, "data")

mocked_post.assert_called_with(
"https://0.0.0.0/api/live",
Expand All @@ -72,7 +76,8 @@ def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post):
)

mocked_post.reset_mock()
live.end()
live.save_dvc_exp()
post_to_studio(live, "done")

mocked_post.assert_called_with(
"https://0.0.0.0/api/live",
Expand Down Expand Up @@ -118,11 +123,15 @@ def test_post_to_studio_failed_data_request(
error_response.status_code = 400
mocker.patch("requests.post", return_value=error_response)
live.log_metric("foo", 1)
live.next_step()
live.step = 0
live.make_summary()
post_to_studio(live, "data")

mocked_post = mocker.patch("requests.post", return_value=valid_response)
live.step += 1
live.log_metric("foo", 2)
live.next_step()
live.make_summary()
post_to_studio(live, "data")
mocked_post.assert_called_with(
"https://0.0.0.0/api/live",
**get_studio_call(
Expand Down Expand Up @@ -154,6 +163,7 @@ def test_post_to_studio_failed_start_request(
live.next_step()

assert mocked_post.call_count == 1
assert live._studio_events_to_skip == {"start", "data", "done"}


def test_post_to_studio_done_only_once(tmp_dir, mocked_dvc_repo, mocked_studio_post):
Expand Down Expand Up @@ -210,7 +220,9 @@ def test_post_to_studio_dvc_studio_config(

with Live() as live:
live.log_metric("foo", 1)
live.next_step()
live.step = 0
live.make_summary()
post_to_studio(live, "data")

assert mocked_post.call_args.kwargs["headers"]["Authorization"] == "token token"

Expand All @@ -231,7 +243,9 @@ def test_post_to_studio_skip_if_no_token(

with Live() as live:
live.log_metric("foo", 1)
live.next_step()
live.step = 0
live.make_summary()
post_to_studio(live, "data")

assert mocked_post.call_count == 0

Expand All @@ -241,7 +255,8 @@ def test_post_to_studio_shorten_names(tmp_dir, mocked_dvc_repo, mocked_studio_po

live = Live()
live.log_metric("eval/loss", 1)
live.next_step()
live.make_summary()
post_to_studio(live, "data")

plots_path = Path(live.plots_dir)
loss_path = (plots_path / Metric.subfolder / "eval/loss.tsv").as_posix()
Expand Down Expand Up @@ -269,7 +284,9 @@ def test_post_to_studio_inside_dvc_exp(

with Live() as live:
live.log_metric("foo", 1)
live.next_step()
live.step = 0
live.make_summary()
post_to_studio(live, "data")

call_types = [call.kwargs["json"]["type"] for call in mocked_post.call_args_list]
assert "start" not in call_types
Expand All @@ -287,7 +304,8 @@ def test_post_to_studio_inside_subdir(

live = Live()
live.log_metric("foo", 1)
live.next_step()
live.make_summary()
post_to_studio(live, "data")

foo_path = (Path(live.plots_dir) / Metric.subfolder / "foo.tsv").as_posix()

Expand Down Expand Up @@ -317,7 +335,8 @@ def test_post_to_studio_inside_subdir_dvc_exp(

live = Live()
live.log_metric("foo", 1)
live.next_step()
live.make_summary()
post_to_studio(live, "data")

foo_path = (Path(live.plots_dir) / Metric.subfolder / "foo.tsv").as_posix()

Expand Down Expand Up @@ -370,7 +389,9 @@ def test_post_to_studio_images(tmp_dir, mocked_dvc_repo, mocked_studio_post):

live = Live()
live.log_image("foo.png", ImagePIL.new("RGB", (10, 10), (0, 0, 0)))
live.next_step()
live.step = 0
live.make_summary()
post_to_studio(live, "data")

foo_path = (Path(live.plots_dir) / Image.subfolder / "foo.png").as_posix()

Expand Down Expand Up @@ -409,11 +430,13 @@ def test_post_to_studio_name(tmp_dir, mocked_dvc_repo, mocked_studio_post):


def test_post_to_studio_if_done_skipped(tmp_dir, mocked_dvc_repo, mocked_studio_post):
live = Live()
live._studio_events_to_skip.add("start")
live._studio_events_to_skip.add("done")
live.log_metric("foo", 1)
live.end()
with Live() as live:
live._studio_events_to_skip.add("start")
live._studio_events_to_skip.add("done")
live.log_metric("foo", 1)
live.step = 0
live.make_summary()
post_to_studio(live, "data")

mocked_post, _ = mocked_studio_post
call_types = [call.kwargs["json"]["type"] for call in mocked_post.call_args_list]
Expand All @@ -439,8 +462,9 @@ def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post):
)

live.log_metric("foo", 1)
live.make_summary()
post_to_studio(live, "data")

live.next_step()
mocked_post.assert_called_with(
"https://0.0.0.0/api/live",
**get_studio_call(
Expand All @@ -452,9 +476,11 @@ def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post):
),
)

live.step += 1
live.log_metric("foo", 2)
live.make_summary()
post_to_studio(live, "data")

live.next_step()
mocked_post.assert_called_with(
"https://0.0.0.0/api/live",
**get_studio_call(
Expand All @@ -466,7 +492,7 @@ def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post):
),
)

live.end()
post_to_studio(live, "done")
mocked_post.assert_called_with(
"https://0.0.0.0/api/live",
**get_studio_call("done", baseline_sha="0" * 40, exp_name=live._exp_name),
Expand All @@ -485,7 +511,9 @@ def test_post_to_studio_skip_if_no_repo_url(

with Live() as live:
live.log_metric("foo", 1)
live.next_step()
live.step = 0
live.make_summary()
post_to_studio(live, "data")

assert mocked_post.call_count == 0

Expand All @@ -503,7 +531,8 @@ def test_post_to_studio_repeat_step(tmp_dir, mocked_dvc_repo, mocked_studio_post
live.step = 0
live.log_metric("foo", 1)
live.log_metric("bar", 0.1)
live.sync()
live.make_summary()
post_to_studio(live, "data")

mocked_post.assert_called_with(
"https://0.0.0.0/api/live",
Expand All @@ -521,7 +550,8 @@ def test_post_to_studio_repeat_step(tmp_dir, mocked_dvc_repo, mocked_studio_post
live.log_metric("foo", 2)
live.log_metric("foo", 3)
live.log_metric("bar", 0.2)
live.sync()
live.make_summary()
post_to_studio(live, "data")

mocked_post.assert_called_with(
"https://0.0.0.0/api/live",
Expand Down

0 comments on commit 228e9a8

Please sign in to comment.