Add retry to shuffle broadcast #8900

Merged · 5 commits · Oct 25, 2024
45 changes: 41 additions & 4 deletions distributed/shuffle/_scheduler_plugin.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import asyncio
 import contextlib
 import itertools
 import logging
@@ -96,10 +97,46 @@ async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
                 stimulus_id=f"p2p-barrier-inconsistent-{time()}",
             )
         msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
-        await self.scheduler.broadcast(
-            msg=msg,
-            workers=list(shuffle.participating_workers),
-        )
+        workers = list(shuffle.participating_workers)
+        no_progress = 0
+        while workers:
+            res = await self.scheduler.broadcast(
+                msg=msg,
+                workers=workers,
+                on_error="return",
+            )
+            before = len(workers)
+            workers = []
+            for w, r in res.items():
+                if r is None:
+                    continue
+                if isinstance(r, OSError):
+                    workers.append(w)
+                else:
+                    raise RuntimeError(
+                        f"Unexpected error encountered during P2P barrier: {r!r}"
+                    )
+            workers = [w for w, r in res.items() if r is not None]
+            if workers:
+                logger.warning(
+                    "Failure during broadcast of %s, retrying.",
+                    shuffle.id,
+                )
+                if any(w not in self.scheduler.workers for w in workers):
+                    raise RuntimeError(
+                        f"Worker {workers} left during shuffle {shuffle}"
+                    )
+                await asyncio.sleep(0.1)
+                if len(workers) == before:
+                    no_progress += 1
+                    if no_progress >= 3:
+                        raise RuntimeError(
+                            f"""Broadcast not making progress for {shuffle}.
+                            Aborting. This is possibly due to overloaded
+                            workers. Increasing config
+                            `distributed.comm.timeouts.connect` timeout may
+                            help."""
+                        )
 
     def restrict_task(
         self, id: ShuffleId, run_id: int, key: Key, worker: str
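For context, the change above follows a common retry pattern: attempt the broadcast, keep only the workers that failed with a transient OSError, and abort after a few rounds in which the set of failing workers does not shrink. Below is a minimal, self-contained sketch of that pattern, decoupled from the scheduler plugin; the `broadcast_with_retry` helper, its `send` callable, and the default limits are illustrative assumptions, not the PR's API.

    import asyncio


    async def broadcast_with_retry(send, workers, max_stalled_rounds=3, backoff=0.1):
        """Illustrative sketch: retry a fan-out call for workers failing with OSError.

        ``send(workers)`` is assumed to return a mapping of worker -> None on
        success or an exception instance on failure, mirroring
        ``Scheduler.broadcast(..., on_error="return")``.
        """
        stalled = 0
        while workers:
            results = await send(workers)
            before = len(workers)
            # Keep only transient (OSError) failures for the next attempt;
            # anything else is unexpected and should surface immediately.
            retry = []
            for worker, result in results.items():
                if result is None:
                    continue
                if isinstance(result, OSError):
                    retry.append(worker)
                else:
                    raise RuntimeError(f"Unexpected error from {worker}: {result!r}")
            workers = retry
            if not workers:
                return
            await asyncio.sleep(backoff)
            # A round in which no worker recovered counts as "no progress".
            if len(workers) == before:
                stalled += 1
                if stalled >= max_stalled_rounds:
                    raise RuntimeError("Broadcast not making progress, aborting")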
59 changes: 59 additions & 0 deletions distributed/shuffle/tests/test_shuffle.py
@@ -21,6 +21,7 @@
 import dask
 from dask.utils import key_split
 
+from distributed.comm.core import Comm
 from distributed.shuffle._core import ShuffleId, ShuffleRun, barrier_key
 from distributed.shuffle._disk import DiskShardsBuffer
 from distributed.worker import Status
@@ -3034,3 +3035,61 @@ async def test_workers_do_not_spam_get_requests(c, s, a, b):
     await assert_worker_cleanup(a)
     await assert_worker_cleanup(b)
     await assert_scheduler_cleanup(s)
+
+
+class BarrierInputsDoneOSErrorPlugin(ShuffleWorkerPlugin):
+    def __init__(
+        self,
+        failures: dict[str, tuple[int, type]] | None = None,
+    ):
+        self.failures = failures or {}
+        super().__init__()
+
+    async def shuffle_inputs_done(self, comm: Comm, *args: Any, **kwargs: Any) -> None:  # type: ignore
+        if self.worker.address in self.failures:
+            nfailures, exc_type = self.failures[self.worker.address]
+            if nfailures > 0:
+                nfailures -= 1
+                self.failures[self.worker.address] = nfailures, exc_type
+                if issubclass(exc_type, OSError):
+                    # Aborting the Comm object triggers a different path in
+                    # error handling that resembles a genuine connection failure
+                    # like a timeout while an exception that is being raised by
+                    # the handler will be serialized and sent to the scheduler
+                    comm.abort()
+                raise exc_type  # type: ignore
+        return await super().shuffle_inputs_done(*args, **kwargs)
+
+
+@pytest.mark.parametrize(
+    "failures, expected_exc",
+    [
+        ({}, None),
+        ({0: (1, OSError)}, None),
+        ({0: (1, RuntimeError)}, RuntimeError),
+        ({0: (1, OSError), 1: (1, OSError)}, None),
+        ({0: (1, OSError), 1: (1, RuntimeError)}, RuntimeError),
+        ({0: (5, OSError)}, RuntimeError),
+        ({0: (5, OSError), 1: (1, OSError)}, RuntimeError),
+    ],
+)
+@pytest.mark.slow
Review comment (Member Author), on the `@pytest.mark.slow` marker:
All included, this is about 2.5s, so not incredibly slow. The backoff is just 0.1 after all. If that issue persists we can add configuration and an exponential backoff, etc., but I'd rather not increase complexity further (and would rather get rid of the broadcast instead).
+@gen_cluster(client=True)
+async def test_flaky_broadcast(c, s, a, b, failures, expected_exc):
+    names_to_address = {w.name: w.address for w in [a, b]}
+    failures = {names_to_address[name]: failures for name, failures in failures.items()}
+    plugin = BarrierInputsDoneOSErrorPlugin(failures)
+    await c.register_plugin(plugin, name="shuffle")
+
+    if expected_exc:
+        ctx = pytest.raises(expected_exc)
+    else:
+        ctx = contextlib.nullcontext()
+    pdf = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})
+    ddf = dd.from_pandas(pdf, npartitions=2)
+    with dask.config.set({"dataframe.shuffle.method": "p2p"}):
+        shuffled = ddf.shuffle("x")
+
+        res = c.compute(shuffled)
+        with ctx:
+            await c.gather(res)
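As the review comment above notes, an exponential backoff (and a config knob for it) could be added later if the fixed 0.1s sleep proves insufficient. A minimal sketch of what that delay schedule might look like; the helper name, base, and cap values are illustrative assumptions, not part of this PR.

    import asyncio


    async def sleep_with_backoff(attempt: int, base: float = 0.1, cap: float = 5.0) -> None:
        """Illustrative only: exponential backoff capped at ``cap`` seconds.

        attempt=0 sleeps 0.1s, attempt=1 sleeps 0.2s, attempt=2 sleeps 0.4s, ...
        """
        await asyncio.sleep(min(cap, base * (2 ** attempt)))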