diff --git a/distributed/shuffle/_worker_plugin.py b/distributed/shuffle/_worker_plugin.py
index 86667a87b2..0f6a1ecf25 100644
--- a/distributed/shuffle/_worker_plugin.py
+++ b/distributed/shuffle/_worker_plugin.py
@@ -2,6 +2,7 @@
 
 import asyncio
 import logging
+from collections import defaultdict
 from collections.abc import Sequence
 from concurrent.futures import ThreadPoolExecutor
 from typing import TYPE_CHECKING, Any, overload
@@ -39,6 +40,7 @@ class _ShuffleRunManager:
     closed: bool
     _active_runs: dict[ShuffleId, ShuffleRun]
     _runs: set[ShuffleRun]
+    _refresh_locks: defaultdict[ShuffleId, asyncio.Lock]
     #: Mapping of shuffle IDs to the largest stale run ID.
     #: This is used to prevent race conditions between fetching shuffle run data
     #: from the scheduler and failing a shuffle run.
@@ -51,6 +53,7 @@ def __init__(self, plugin: ShuffleWorkerPlugin) -> None:
         self.closed = False
         self._active_runs = {}
         self._runs = set()
+        self._refresh_locks = defaultdict(asyncio.Lock)
         self._stale_run_ids = {}
         self._runs_cleanup_condition = asyncio.Condition()
         self._plugin = plugin
@@ -117,20 +120,21 @@ async def get_with_run_id(self, shuffle_id: ShuffleId, run_id: int) -> ShuffleRu
         ShuffleClosedError
             If the run manager has been closed
         """
-        shuffle_run = self._active_runs.get(shuffle_id, None)
-        if shuffle_run is None or shuffle_run.run_id < run_id:
-            shuffle_run = await self._refresh(shuffle_id=shuffle_id)
-
-        if shuffle_run.run_id > run_id:
-            raise P2PConsistencyError(f"{run_id=} stale, got {shuffle_run}")
-        elif shuffle_run.run_id < run_id:
-            raise P2PConsistencyError(f"{run_id=} invalid, got {shuffle_run}")
-
-        if self.closed:
-            raise ShuffleClosedError(f"{self} has already been closed")
-        if shuffle_run._exception:
-            raise shuffle_run._exception
-        return shuffle_run
+        async with self._refresh_locks[shuffle_id]:
+            shuffle_run = self._active_runs.get(shuffle_id, None)
+            if shuffle_run is None or shuffle_run.run_id < run_id:
+                shuffle_run = await self._refresh(shuffle_id=shuffle_id)
+
+            if shuffle_run.run_id > run_id:
+                raise P2PConsistencyError(f"{run_id=} stale, got {shuffle_run}")
+            elif shuffle_run.run_id < run_id:
+                raise P2PConsistencyError(f"{run_id=} invalid, got {shuffle_run}")
+
+            if self.closed:
+                raise ShuffleClosedError(f"{self} has already been closed")
+            if shuffle_run._exception:
+                raise shuffle_run._exception
+            return shuffle_run
 
     async def get_or_create(self, spec: ShuffleSpec, key: Key) -> ShuffleRun:
         """Get or create a shuffle matching the ID and data spec.
@@ -144,13 +148,14 @@ async def get_or_create(self, spec: ShuffleSpec, key: Key) -> ShuffleRun:
         key:
             Task key triggering the function
         """
-        shuffle_run = self._active_runs.get(spec.id, None)
-        if shuffle_run is None:
-            shuffle_run = await self._refresh(
-                shuffle_id=spec.id,
-                spec=spec,
-                key=key,
-            )
+        async with self._refresh_locks[spec.id]:
+            shuffle_run = self._active_runs.get(spec.id, None)
+            if shuffle_run is None:
+                shuffle_run = await self._refresh(
+                    shuffle_id=spec.id,
+                    spec=spec,
+                    key=key,
+                )
 
         if self.closed:
             raise ShuffleClosedError(f"{self} has already been closed")
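The patch above serializes run lookups per shuffle ID: `_refresh_locks` hands out one `asyncio.Lock` per ID on demand, so when many tasks on a worker race to fetch the same shuffle run, only the first caller performs the scheduler round trip and the rest find the run in `_active_runs` once they acquire the lock. A minimal, self-contained sketch of that pattern follows; the `RunCache` names are hypothetical stand-ins, not the distributed API.

```python
# Sketch of the per-key lock pattern, assuming a toy cache in place of
# _ShuffleRunManager. Ten concurrent callers produce a single fetch.
import asyncio
from collections import defaultdict


class RunCache:
    def __init__(self) -> None:
        self._active: dict[str, int] = {}
        # defaultdict creates a lock lazily the first time an ID is seen
        self._locks: defaultdict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
        self.fetches = 0  # counts simulated scheduler requests

    async def _fetch(self, shuffle_id: str) -> int:
        self.fetches += 1
        await asyncio.sleep(0.01)  # simulated scheduler round trip
        return 1

    async def get(self, shuffle_id: str) -> int:
        # Without the lock, every concurrent caller sees the cache miss and
        # issues its own fetch; with it, followers wait and reuse the result.
        async with self._locks[shuffle_id]:
            run = self._active.get(shuffle_id)
            if run is None:
                run = await self._fetch(shuffle_id)
                self._active[shuffle_id] = run
            return run


async def main() -> None:
    cache = RunCache()
    await asyncio.gather(*(cache.get("shuffle-0") for _ in range(10)))
    assert cache.fetches == 1  # one request despite ten concurrent callers


asyncio.run(main())
```

Keeping one lock per ID rather than a single manager-wide lock means lookups for unrelated shuffles never block each other, which mirrors the `defaultdict(asyncio.Lock)` choice in the patch.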
diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py
index 1856d87f8d..615f83b6b0 100644
--- a/distributed/shuffle/tests/test_shuffle.py
+++ b/distributed/shuffle/tests/test_shuffle.py
@@ -2704,18 +2704,6 @@ async def test_unpack_gets_rescheduled_from_non_participating_worker(c, s, a):
     dd.assert_eq(result, expected)
 
 
-class BlockedBarrierShuffleSchedulerPlugin(ShuffleSchedulerPlugin):
-    def __init__(self, scheduler: Scheduler):
-        super().__init__(scheduler)
-        self.in_barrier = asyncio.Event()
-        self.block_barrier = asyncio.Event()
-
-    async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
-        self.in_barrier.set()
-        await self.block_barrier.wait()
-        return await super().barrier(id, run_id, consistent)
-
-
 class FlakyConnectionPool(ConnectionPool):
     def __init__(self, *args, failing_connects=0, **kwargs):
         self.attempts = 0
@@ -2955,3 +2943,86 @@ async def test_dont_downscale_participating_workers(c, s, a, b):
 
     workers_to_close = s.workers_to_close(n=2)
     assert len(workers_to_close) == 2
+
+
+class RequestCountingSchedulerPlugin(ShuffleSchedulerPlugin):
+    def __init__(self, scheduler):
+        super().__init__(scheduler)
+        self.counts = defaultdict(int)
+
+    def get(self, *args, **kwargs):
+        self.counts["get"] += 1
+        return super().get(*args, **kwargs)
+
+    def get_or_create(self, *args, **kwargs):
+        self.counts["get_or_create"] += 1
+        return super().get_or_create(*args, **kwargs)
+
+
+class PostFetchBlockingManager(_ShuffleRunManager):
+    def __init__(self, plugin):
+        super().__init__(plugin)
+        self.in_fetch = asyncio.Event()
+        self.block_fetch = asyncio.Event()
+
+    async def _fetch(self, *args, **kwargs):
+        result = await super()._fetch(*args, **kwargs)
+        self.in_fetch.set()
+        await self.block_fetch.wait()
+        return result
+
+
+@mock.patch(
+    "distributed.shuffle.ShuffleSchedulerPlugin",
+    RequestCountingSchedulerPlugin,
+)
+@mock.patch(
+    "distributed.shuffle._worker_plugin._ShuffleRunManager",
+    PostFetchBlockingManager,
+)
+@gen_cluster(
+    client=True,
+    nthreads=[("", 2)] * 2,
+    config={
+        "distributed.scheduler.allowed-failures": 0,
+        "distributed.p2p.comm.message-size-limit": "10 B",
+    },
+)
+async def test_workers_do_not_spam_get_requests(c, s, a, b):
+    df = dask.datasets.timeseries(
+        start="2000-01-01",
+        end="2000-02-01",
+        dtypes={"x": float, "y": float},
+        freq="10 s",
+    )
+    s.remove_plugin("shuffle")
+    shuffle_extS = RequestCountingSchedulerPlugin(s)
+    shuffle_extA = a.plugins["shuffle"]
+    shuffle_extB = b.plugins["shuffle"]
+
+    with dask.config.set({"dataframe.shuffle.method": "p2p"}):
+        out = df.shuffle("x", npartitions=100)
+        out = c.compute(out.x.size)
+
+    shuffle_id = await wait_until_new_shuffle_is_initialized(s)
+    key = barrier_key(shuffle_id)
+    await shuffle_extA.shuffle_runs.in_fetch.wait()
+    await shuffle_extB.shuffle_runs.in_fetch.wait()
+
+    shuffle_extA.shuffle_runs.block_fetch.set()
+
+    barrier_task = s.tasks[key]
+    while any(
+        ts.state not in ("processing", "memory") for ts in barrier_task.dependencies
+    ):
+        await asyncio.sleep(0.1)
+    shuffle_extB.shuffle_runs.block_fetch.set()
+    await out
+
+    assert sum(shuffle_extS.counts.values()) == 2
+
+    del out
+
+    await assert_worker_cleanup(a)
+    await assert_worker_cleanup(b)
+    await assert_scheduler_cleanup(s)
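The test makes the race reproducible by patching in two instrumented subclasses: `RequestCountingSchedulerPlugin` counts `get`/`get_or_create` requests on the scheduler, while `PostFetchBlockingManager` parks each worker inside `_fetch` on a pair of `asyncio.Event`s until the test releases it, so both workers are provably in-flight before either may proceed. Releasing worker A first and worker B only once the barrier's dependencies are processing pushes B's later lookups through the new lock, and the scheduler ends up seeing exactly two requests in total. Below is a stripped-down sketch of that event-gating technique, assuming a toy `Component` class rather than the real plugin machinery.

```python
# Sketch of gating an async method with two events so a test can observe
# the call and control when it completes (hypothetical names).
import asyncio


class Component:
    async def fetch(self) -> str:
        return "result"


class BlockingComponent(Component):
    def __init__(self) -> None:
        self.in_call = asyncio.Event()
        self.block_call = asyncio.Event()

    async def fetch(self) -> str:
        result = await super().fetch()
        self.in_call.set()            # tell the test the call happened
        await self.block_call.wait()  # park until the test releases us
        return result


async def main() -> None:
    comp = BlockingComponent()
    task = asyncio.create_task(comp.fetch())
    await comp.in_call.wait()  # deterministic: fetch ran and is now held
    assert not task.done()
    comp.block_call.set()      # open the gate
    assert await task == "result"


asyncio.run(main())
```

Blocking after the wrapped call returns, as `PostFetchBlockingManager._fetch` does, holds the fetched result in flight, which is exactly the window in which a second caller would previously have fired a duplicate scheduler request.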