
Commit b81591b: Add test

fjetter committed Oct 21, 2024
1 parent f5a6ace commit b81591b
Showing 3 changed files with 76 additions and 9 deletions.
20 changes: 16 additions & 4 deletions distributed/shuffle/_scheduler_plugin.py
@@ -106,10 +106,20 @@ async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
                 on_error="return",
             )
             before = len(workers)
-            workers = [w for w, r in res.items() if r != "OK"]
+            workers = []
+            for w, r in res.items():
+                if r is None:
+                    continue
+                if isinstance(r, OSError):
+                    workers.append(w)
+                else:
+                    raise P2PConsistencyError(
+                        "Unexpected error encountered during barrier"
+                    )
             workers = [w for w, r in res.items() if r is not None]
             if workers:
                 logger.warning(
-                    "Failure during broadcast, retrying.",
+                    "Failure during broadcast of %s, retrying.",
+                    shuffle.id,
                 )
                 if any(w not in self.scheduler.workers for w in workers):
@@ -119,9 +129,11 @@ async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
                 await asyncio.sleep(0.1)
                 if len(workers) == before:
                     no_progress += 1
-                    if no_progress > 3:
+                    if no_progress >= 3:
                         raise P2PConsistencyError(
-                            f"Broadcast not making progress for {shuffle}"
+                            f"""Broadcast not making progress for {shuffle}.
+                            Aborting. This is possibly due to overloaded
+                            workers. Increasing tcp.connect timeout may help."""
                         )
 
     def restrict_task(
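The new abort message points at overloaded workers and suggests raising the TCP connect timeout. As an aside that is not part of this commit, the knob it most likely refers to lives under distributed.comm.timeouts in the Dask config; a minimal sketch, assuming those keys, would be:

    import dask

    # Assumed config keys: raise the connect/TCP timeouts before starting the cluster.
    dask.config.set(
        {
            "distributed.comm.timeouts.connect": "60s",  # time allowed to establish a connection
            "distributed.comm.timeouts.tcp": "60s",  # time an established comm may stay silent
        }
    )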
7 changes: 2 additions & 5 deletions distributed/shuffle/_worker_plugin.py
@@ -4,7 +4,7 @@
 import logging
 from collections.abc import Sequence
 from concurrent.futures import ThreadPoolExecutor
-from typing import TYPE_CHECKING, Any, Literal, overload
+from typing import TYPE_CHECKING, Any, overload
 
 import dask
 from dask.context import thread_state
@@ -327,16 +327,13 @@ async def shuffle_receive(
         except P2PConsistencyError as e:
             return error_message(e)
 
-    async def shuffle_inputs_done(
-        self, shuffle_id: ShuffleId, run_id: int
-    ) -> Literal["OK"]:
+    async def shuffle_inputs_done(self, shuffle_id: ShuffleId, run_id: int) -> None:
         """
         Handler: Inform the extension that all input partitions have been handed off to extensions.
         Using an unknown ``shuffle_id`` is an error.
         """
         shuffle_run = await self._get_shuffle_run(shuffle_id, run_id)
         await shuffle_run.inputs_done()
-        return "OK"
 
     def shuffle_fail(self, shuffle_id: ShuffleId, run_id: int, message: str) -> None:
         """Fails the shuffle run with the message as exception and triggers cleanup.
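A note on why these two changes fit together (an aside, not part of the diff): shuffle_inputs_done no longer returns "OK", so the scheduler-side barrier cannot identify success by comparing strings. With on_error="return", the broadcast hands back a dict mapping each worker address to the handler's return value, now None on success, or to the raised exception instance, which is what the isinstance(r, OSError) check above relies on. A small illustration with hypothetical addresses:

    # Illustrative shape of the broadcast result; addresses and errors are made up.
    res = {
        "tcp://10.0.0.1:40001": None,  # handler succeeded
        "tcp://10.0.0.2:40002": OSError("connect timed out"),  # transient, retried
        "tcp://10.0.0.3:40003": RuntimeError("boom"),  # unexpected, raises P2PConsistencyError
    }
    retry = [w for w, r in res.items() if isinstance(r, OSError)]
    assert retry == ["tcp://10.0.0.2:40002"]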
58 changes: 58 additions & 0 deletions distributed/shuffle/tests/test_shuffle.py
@@ -21,6 +21,7 @@
 import dask
 from dask.utils import key_split
 
+from distributed.comm.core import Comm
 from distributed.shuffle._core import ShuffleId, ShuffleRun, barrier_key
 from distributed.shuffle._disk import DiskShardsBuffer
 from distributed.worker import Status
@@ -2955,3 +2956,60 @@ async def test_dont_downscale_participating_workers(c, s, a, b):
 
     workers_to_close = s.workers_to_close(n=2)
     assert len(workers_to_close) == 2
+
+
+class BarrierInputsDoneOSErrorPlugin(ShuffleWorkerPlugin):
+    def __init__(
+        self,
+        failures: dict[str, tuple[int, type]] | None = None,
+    ):
+        self.failures = failures or {}
+        super().__init__()
+
+    async def shuffle_inputs_done(self, comm: Comm, *args: Any, **kwargs: Any) -> None:  # type: ignore
+        if self.worker.address in self.failures:
+            nfailures, exc_type = self.failures[self.worker.address]
+            if nfailures > 0:
+                nfailures -= 1
+                self.failures[self.worker.address] = nfailures, exc_type
+                if issubclass(exc_type, OSError):
+                    # Aborting the Comm object triggers a different path in
+                    # error handling that resembles a genuine connection failure
+                    # like a timeout while an exception that is being raised by
+                    # the handler will be serialized and sent to the scheduler
+                    comm.abort()
+                raise exc_type  # type: ignore
+        return await super().shuffle_inputs_done(*args, **kwargs)
+
+
+@pytest.mark.parametrize(
+    "failures, expected_exc",
+    [
+        ({}, None),
+        ({0: (1, OSError)}, None),
+        ({0: (1, RuntimeError)}, P2PConsistencyError),
+        ({0: (1, OSError), 1: (1, OSError)}, None),
+        ({0: (1, OSError), 1: (1, RuntimeError)}, P2PConsistencyError),
+        ({0: (5, OSError)}, P2PConsistencyError),
+        ({0: (5, OSError), 1: (1, OSError)}, P2PConsistencyError),
+    ],
+)
+@gen_cluster(client=True)
+async def test_flaky_broadcast(c, s, a, b, failures, expected_exc):
+    names_to_address = {w.name: w.address for w in [a, b]}
+    failures = {names_to_address[name]: failures for name, failures in failures.items()}
+    plugin = BarrierInputsDoneOSErrorPlugin(failures)
+    await c.register_plugin(plugin, name="shuffle")
+
+    if expected_exc:
+        ctx = pytest.raises(expected_exc)
+    else:
+        ctx = contextlib.nullcontext()
+    pdf = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})
+    ddf = dd.from_pandas(pdf, npartitions=2)
+    with dask.config.set({"dataframe.shuffle.method": "p2p"}):
+        shuffled = ddf.shuffle("x")
+
+        res = c.compute(shuffled)
+        with ctx:
+            await c.gather(res)
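How to read the parametrization (an aside, not part of the commit): each key is a worker name and each value is a pair of (number of consecutive failures, exception type). Given the scheduler-side changes, only OSError responses are retried, and the barrier gives up after three broadcast rounds without progress, so a single transient OSError recovers while repeated OSErrors or any non-OSError should surface as P2PConsistencyError. A rough predicate for the expected outcome, under the assumption that the retry budget is the no_progress >= 3 check:

    # Hypothetical helper; mirrors the expected_exc column of the parametrization above.
    def expect_consistency_error(failures: dict[int, tuple[int, type]]) -> bool:
        return any(
            not issubclass(exc_type, OSError) or nfailures >= 3
            for nfailures, exc_type in failures.values()
        )

    assert expect_consistency_error({0: (5, OSError)})
    assert not expect_consistency_error({0: (1, OSError), 1: (1, OSError)})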
