Skip to content

Commit

Permalink
Add retry to shuffle broadcast
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed Oct 21, 2024
1 parent 48509b3 commit 04e74c9
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 4 deletions.
43 changes: 39 additions & 4 deletions distributed/shuffle/_scheduler_plugin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import contextlib
import itertools
import logging
Expand Down Expand Up @@ -96,10 +97,44 @@ async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
await self.scheduler.broadcast(
msg=msg,
workers=list(shuffle.participating_workers),
)
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise P2PConsistencyError(
"Unexpected error encountered during barrier"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
raise P2PConsistencyError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise P2PConsistencyError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing tcp.connect timeout may help."""
)

def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
Expand Down
58 changes: 58 additions & 0 deletions distributed/shuffle/tests/test_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import dask
from dask.utils import key_split

from distributed.comm.core import Comm
from distributed.shuffle._core import ShuffleId, ShuffleRun, barrier_key
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.worker import Status
Expand Down Expand Up @@ -3026,3 +3027,60 @@ async def test_workers_do_not_spam_get_requests(c, s, a, b):
await assert_worker_cleanup(a)
await assert_worker_cleanup(b)
await assert_scheduler_cleanup(s)


class BarrierInputsDoneOSErrorPlugin(ShuffleWorkerPlugin):
def __init__(
self,
failures: dict[str, tuple[int, type]] | None = None,
):
self.failures = failures or {}
super().__init__()

async def shuffle_inputs_done(self, comm: Comm, *args: Any, **kwargs: Any) -> None: # type: ignore
if self.worker.address in self.failures:
nfailures, exc_type = self.failures[self.worker.address]
if nfailures > 0:
nfailures -= 1
self.failures[self.worker.address] = nfailures, exc_type
if issubclass(exc_type, OSError):
# Aborting the Comm object triggers a different path in
# error handling that resembles a genuine connection failure
# like a timeout while an exception that is being raised by
# the handler will be serialized and sent to the scheduler
comm.abort()
raise exc_type # type: ignore
return await super().shuffle_inputs_done(*args, **kwargs)


@pytest.mark.parametrize(
"failures, expected_exc",
[
({}, None),
({0: (1, OSError)}, None),
({0: (1, RuntimeError)}, P2PConsistencyError),
({0: (1, OSError), 1: (1, OSError)}, None),
({0: (1, OSError), 1: (1, RuntimeError)}, P2PConsistencyError),
({0: (5, OSError)}, P2PConsistencyError),
({0: (5, OSError), 1: (1, OSError)}, P2PConsistencyError),
],
)
@gen_cluster(client=True)
async def test_flaky_broadcast(c, s, a, b, failures, expected_exc):
names_to_address = {w.name: w.address for w in [a, b]}
failures = {names_to_address[name]: failures for name, failures in failures.items()}
plugin = BarrierInputsDoneOSErrorPlugin(failures)
await c.register_plugin(plugin, name="shuffle")

if expected_exc:
ctx = pytest.raises(expected_exc)
else:
ctx = contextlib.nullcontext()
pdf = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})
ddf = dd.from_pandas(pdf, npartitions=2)
with dask.config.set({"dataframe.shuffle.method": "p2p"}):
shuffled = ddf.shuffle("x")

res = c.compute(shuffled)
with ctx:
await c.gather(res)

0 comments on commit 04e74c9

Please sign in to comment.