45 fail, 110 skipped, 3 975 pass in 10h 10m 18s
25 files 25 suites 10h 10m 18s ⏱️
4 130 tests 3 975 ✅ 110 💤 45 ❌
47 692 runs 45 131 ✅ 2 121 💤 440 ❌
Results for commit bb25546.
Annotations
Check warning on line 0 in distributed.comm.tests.test_comms
github-actions / Unit Test Results
3 out of 12 runs failed: test_tls_comm_closed_implicit[tornado] (distributed.comm.tests.test_comms)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
ssl.SSLError: [SYS] unknown error (_ssl.c:2570)
tcp = <module 'distributed.comm.tcp' from '/home/runner/work/distributed/distributed/distributed/comm/tcp.py'>
@gen_test()
async def test_tls_comm_closed_implicit(tcp):
> await check_comm_closed_implicit("tls://127.0.0.1", **tls_kwargs)
distributed/comm/tests/test_comms.py:777:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/comm/tests/test_comms.py:763: in check_comm_closed_implicit
await comm.read()
distributed/comm/tcp.py:225: in read
frames_nosplit_nbytes_bin = await stream.read_bytes(fmt_size)
../../../miniconda3/envs/dask-distributed/lib/python3.12/site-packages/tornado/iostream.py:422: in read_bytes
self._try_inline_read()
../../../miniconda3/envs/dask-distributed/lib/python3.12/site-packages/tornado/iostream.py:836: in _try_inline_read
pos = self._read_to_buffer_loop()
../../../miniconda3/envs/dask-distributed/lib/python3.12/site-packages/tornado/iostream.py:750: in _read_to_buffer_loop
if self._read_to_buffer() == 0:
../../../miniconda3/envs/dask-distributed/lib/python3.12/site-packages/tornado/iostream.py:861: in _read_to_buffer
bytes_read = self.read_from_fd(buf)
../../../miniconda3/envs/dask-distributed/lib/python3.12/site-packages/tornado/iostream.py:1552: in read_from_fd
return self.socket.recv_into(buf, len(buf))
../../../miniconda3/envs/dask-distributed/lib/python3.12/ssl.py:1251: in recv_into
return self.read(nbytes, buffer)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <ssl.SSLSocket [closed] fd=-1, family=2, type=1, proto=0>, len = 65536
buffer = bytearray(b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x...0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')
def read(self, len=1024, buffer=None):
"""Read up to LEN bytes and return them.
Return zero-length string on EOF."""
self._checkClosed()
if self._sslobj is None:
raise ValueError("Read on closed or unwrapped SSL socket.")
try:
if buffer is not None:
> return self._sslobj.read(len, buffer)
E ssl.SSLError: [SYS] unknown error (_ssl.c:2570)
../../../miniconda3/envs/dask-distributed/lib/python3.12/ssl.py:1103: SSLError
Check warning on line 0 in distributed.diagnostics.tests.test_task_stream
github-actions / Unit Test Results
1 out of 12 runs failed: test_get_task_stream_save (distributed.diagnostics.tests.test_task_stream)
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 4s]
Raw output
assert 'inc' in '<!DOCTYPE html>\n<html lang="en">\n <head>\n <meta charset="utf-8">\n <title>Dask Task Stream</title>\n <style>\n html, body {\n box-sizing: border-box;\n display: flow-root;\n height: 100%;\n margin: 0;\n padding: 0;\n }\n </style>\n <script type="text/javascript" src="https://cdn.bokeh.org/bokeh/release/bokeh-3.6.1.min.js"></script>\n <script type="text/javascript">\n Bokeh.set_log_level("info");\n </script>\n </head>\n <body>\n <div id="f0a86e3b-74df-4f2c-a395-e9669ab90e03" data-root-id="p23085" style="display: contents;"></div>\n \n <script type="application/json" id="ec151569-8f8d-4c98-ac24-bcdf0ce11ac5">\n {"52e17bdd-6795-4092-8917-677256f3eddd":{"version":"3.6.1","title":"Bokeh Application","roots":[{"type":"object","name":"Figure","id":"p23085","attributes":{"name":"task_stream","sizing_mode":"stretch_both","x_range":{"type":"object","name":"DataRange1d","id":"p23083","attributes":{"range_padding":0}},"y_range":{"type":"object","name":"DataRange1d","id":"p23084","attributes":{"range_padding":0}},"x_scale":{"type":"object","name":"LinearScale","id":"p23095"},"y_scale":{"type":"object",...bcdf0ce11ac5\').textContent;\n const render_items = [{"docid":"52e17bdd-6795-4092-8917-677256f3eddd","roots":{"p23085":"f0a86e3b-74df-4f2c-a395-e9669ab90e03"},"root_ids":["p23085"]}];\n root.Bokeh.embed.embed_items(docs_json, render_items);\n }\n if (root.Bokeh !== undefined) {\n embed_document(root);\n } else {\n let attempts = 0;\n const timer = setInterval(function(root) {\n if (root.Bokeh !== undefined) {\n clearInterval(timer);\n embed_document(root);\n } else {\n attempts++;\n if (attempts > 100) {\n clearInterval(timer);\n console.log("Bokeh: ERROR: Unable to run BokehJS code because BokehJS library is missing");\n }\n }\n }, 10, root)\n }\n })(window);\n });\n };\n if (document.readyState != "loading") fn();\n else document.addEventListener("DOMContentLoaded", fn);\n })();\n </script>\n </body>\n</html>'
client = <Client: 'tcp://127.0.0.1:53059' processes=2 threads=2, memory=32.00 GiB>
tmp_path = WindowsPath('C:/Users/runneradmin/AppData/Local/Temp/pytest-of-runneradmin/pytest-0/test_get_task_stream_save0')
def test_get_task_stream_save(client, tmp_path):
bkm = pytest.importorskip("bokeh.models")
tmpdir = str(tmp_path)
fn = os.path.join(tmpdir, "foo.html")
with get_task_stream(plot="save", filename=fn) as ts:
wait(client.map(inc, range(10)))
with open(fn) as f:
data = f.read()
> assert "inc" in data
E assert 'inc' in '<!DOCTYPE html>\n<html lang="en">\n <head>\n <meta charset="utf-8">\n <title>Dask Task Stream</title>\n <style>\n html, body {\n box-sizing: border-box;\n display: flow-root;\n height: 100%;\n margin: 0;\n padding: 0;\n }\n </style>\n <script type="text/javascript" src="https://cdn.bokeh.org/bokeh/release/bokeh-3.6.1.min.js"></script>\n <script type="text/javascript">\n Bokeh.set_log_level("info");\n </script>\n </head>\n <body>\n <div id="f0a86e3b-74df-4f2c-a395-e9669ab90e03" data-root-id="p23085" style="display: contents;"></div>\n \n <script type="application/json" id="ec151569-8f8d-4c98-ac24-bcdf0ce11ac5">\n {"52e17bdd-6795-4092-8917-677256f3eddd":{"version":"3.6.1","title":"Bokeh Application","roots":[{"type":"object","name":"Figure","id":"p23085","attributes":{"name":"task_stream","sizing_mode":"stretch_both","x_range":{"type":"object","name":"DataRange1d","id":"p23083","attributes":{"range_padding":0}},"y_range":{"type":"object","name":"DataRange1d","id":"p23084","attributes":{"range_padding":0}},"x_scale":{"type":"object","name":"LinearScale","id":"p23095"},"y_scale":{"type":"object",...bcdf0ce11ac5\').textContent;\n const render_items = [{"docid":"52e17bdd-6795-4092-8917-677256f3eddd","roots":{"p23085":"f0a86e3b-74df-4f2c-a395-e9669ab90e03"},"root_ids":["p23085"]}];\n root.Bokeh.embed.embed_items(docs_json, render_items);\n }\n if (root.Bokeh !== undefined) {\n embed_document(root);\n } else {\n let attempts = 0;\n const timer = setInterval(function(root) {\n if (root.Bokeh !== undefined) {\n clearInterval(timer);\n embed_document(root);\n } else {\n attempts++;\n if (attempts > 100) {\n clearInterval(timer);\n console.log("Bokeh: ERROR: Unable to run BokehJS code because BokehJS library is missing");\n }\n }\n }, 10, root)\n }\n })(window);\n });\n };\n if (document.readyState != "loading") fn();\n else document.addEventListener("DOMContentLoaded", fn);\n })();\n </script>\n </body>\n</html>'
distributed\diagnostics\tests\test_task_stream.py:159: AssertionError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
All 11 runs failed: test_rechunk_configuration[p2p-tasks] (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 103353343764f68a051e1ccb8959794b failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: '103353343764f68a051e1ccb8959794b'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and… shuffle_run = await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:43009', workers: 0, cores: 0, tasks: 0>
config_value = 'tasks', keyword = 'p2p'
ws = (<Worker 'tcp://127.0.0.1:35513', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:42035', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[0.98768471, 0.17582926, 0.92605057, 0.16072492, 0.16546501,
0.28120577, 0.43820435, 0.02956598, 0.3916..., 0.98722643, 0.53573731, 0.70246876, 0.39640673,
0.76411575, 0.65548619, 0.35407073, 0.19018565, 0.25111724]])
@pytest.mark.parametrize("config_value", ["tasks", "p2p", None])
@pytest.mark.parametrize("keyword", ["tasks", "p2p", None])
@gen_cluster(client=True)
async def test_rechunk_configuration(c, s, *ws, config_value, keyword):
"""Try rechunking a random 1d matrix
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_1d
"""
a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
x = da.from_array(a, chunks=(10, 1))
new = ((1,) * 10, (10,))
config = {"array.rechunk.method": config_value} if config_value is not None else {}
with dask.config.set(config):
x2 = rechunk(x, chunks=new, method=keyword)
expected_algorithm = keyword if keyword is not None else config_value
if expected_algorithm == "p2p":
assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
elif expected_algorithm == "tasks":
assert not any(
key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
)
# Neither is specified, so we choose the best one (see test_rechunk_heuristic for a full test of the heuristic)
else:
assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
assert x2.chunks == new
> assert np.all(await c.compute(x2) == a)
distributed/shuffle/tests/test_rechunk.py:210:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P 103353343764f68a051e1ccb8959794b failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
All 11 runs failed: test_rechunk_configuration[p2p-p2p] (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 7139bca66b75640feaa6b48f75426176 failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: '7139bca66b75640feaa6b48f75426176'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and… shuffle_run = await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:45201', workers: 0, cores: 0, tasks: 0>
config_value = 'p2p', keyword = 'p2p'
ws = (<Worker 'tcp://127.0.0.1:34133', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:39907', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[0.67263164, 0.99099891, 0.23391163, 0.0527319 , 0.37175087,
0.10399512, 0.30617677, 0.6728578 , 0.3489..., 0.36582142, 0.29896918, 0.57915701, 0.61579285,
0.63909114, 0.30842371, 0.44911196, 0.22216101, 0.92145394]])
@pytest.mark.parametrize("config_value", ["tasks", "p2p", None])
@pytest.mark.parametrize("keyword", ["tasks", "p2p", None])
@gen_cluster(client=True)
async def test_rechunk_configuration(c, s, *ws, config_value, keyword):
"""Try rechunking a random 1d matrix
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_1d
"""
a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
x = da.from_array(a, chunks=(10, 1))
new = ((1,) * 10, (10,))
config = {"array.rechunk.method": config_value} if config_value is not None else {}
with dask.config.set(config):
x2 = rechunk(x, chunks=new, method=keyword)
expected_algorithm = keyword if keyword is not None else config_value
if expected_algorithm == "p2p":
assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
elif expected_algorithm == "tasks":
assert not any(
key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
)
# Neither is specified, so we choose the best one (see test_rechunk_heuristic for a full test of the heuristic)
else:
assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
assert x2.chunks == new
> assert np.all(await c.compute(x2) == a)
distributed/shuffle/tests/test_rechunk.py:210:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P 7139bca66b75640feaa6b48f75426176 failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
All 11 runs failed: test_rechunk_configuration[p2p-None] (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 7ab19aa6fa0aa500648d26d50ba3cc34 failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: '7ab19aa6fa0aa500648d26d50ba3cc34'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and… shuffle_run = await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:45609', workers: 0, cores: 0, tasks: 0>
config_value = None, keyword = 'p2p'
ws = (<Worker 'tcp://127.0.0.1:37985', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:39669', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[0.67946287, 0.49174119, 0.57030695, 0.28143769, 0.93976304,
0.30630708, 0.070039 , 0.95549053, 0.1707..., 0.27496663, 0.06050454, 0.44717105, 0.51656972,
0.10041995, 0.26876562, 0.52153445, 0.59627235, 0.26186335]])
@pytest.mark.parametrize("config_value", ["tasks", "p2p", None])
@pytest.mark.parametrize("keyword", ["tasks", "p2p", None])
@gen_cluster(client=True)
async def test_rechunk_configuration(c, s, *ws, config_value, keyword):
"""Try rechunking a random 1d matrix
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_1d
"""
a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
x = da.from_array(a, chunks=(10, 1))
new = ((1,) * 10, (10,))
config = {"array.rechunk.method": config_value} if config_value is not None else {}
with dask.config.set(config):
x2 = rechunk(x, chunks=new, method=keyword)
expected_algorithm = keyword if keyword is not None else config_value
if expected_algorithm == "p2p":
assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
elif expected_algorithm == "tasks":
assert not any(
key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
)
# Neither is specified, so we choose the best one (see test_rechunk_heuristic for a full test of the heuristic)
else:
assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
assert x2.chunks == new
> assert np.all(await c.compute(x2) == a)
distributed/shuffle/tests/test_rechunk.py:210:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P 7ab19aa6fa0aa500648d26d50ba3cc34 failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
All 11 runs failed: test_rechunk_configuration[None-p2p] (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P f2e1cf610172a1187382e9e34d16198d failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: 'f2e1cf610172a1187382e9e34d16198d'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and… shuffle_run = await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:36603', workers: 0, cores: 0, tasks: 0>
config_value = 'p2p', keyword = None
ws = (<Worker 'tcp://127.0.0.1:46485', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:40951', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[0.76190721, 0.18421237, 0.32989737, 0.72996263, 0.7042225 ,
0.62118023, 0.80522491, 0.4608107 , 0.0737..., 0.61995515, 0.14544883, 0.11477718, 0.08683625,
0.15604437, 0.68832368, 0.58877247, 0.85997467, 0.76332141]])
@pytest.mark.parametrize("config_value", ["tasks", "p2p", None])
@pytest.mark.parametrize("keyword", ["tasks", "p2p", None])
@gen_cluster(client=True)
async def test_rechunk_configuration(c, s, *ws, config_value, keyword):
"""Try rechunking a random 1d matrix
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_1d
"""
a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
x = da.from_array(a, chunks=(10, 1))
new = ((1,) * 10, (10,))
config = {"array.rechunk.method": config_value} if config_value is not None else {}
with dask.config.set(config):
x2 = rechunk(x, chunks=new, method=keyword)
expected_algorithm = keyword if keyword is not None else config_value
if expected_algorithm == "p2p":
assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
elif expected_algorithm == "tasks":
assert not any(
key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
)
# Neither is specified, so we choose the best one (see test_rechunk_heuristic for a full test of the heuristic)
else:
assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
assert x2.chunks == new
> assert np.all(await c.compute(x2) == a)
distributed/shuffle/tests/test_rechunk.py:210:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P f2e1cf610172a1187382e9e34d16198d failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
All 11 runs failed: test_rechunk_configuration[None-None] (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P e0eda9e4c6daa20db3080bf5f59ffb75 failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: 'e0eda9e4c6daa20db3080bf5f59ffb75'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and…
shuffle_run = await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:41971', workers: 0, cores: 0, tasks: 0>
config_value = None, keyword = None
ws = (<Worker 'tcp://127.0.0.1:40879', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:46379', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[3.12737881e-01, 4.60867855e-01, 9.73577295e-01, 3.15562528e-01,
9.61726398e-01, 7.24117584e-01, 6.9431...e-01,
4.99845514e-01, 7.43740630e-01, 5.46473780e-01, 4.42117179e-02,
1.22833879e-01, 3.27800041e-01]])
@pytest.mark.parametrize("config_value", ["tasks", "p2p", None])
@pytest.mark.parametrize("keyword", ["tasks", "p2p", None])
@gen_cluster(client=True)
async def test_rechunk_configuration(c, s, *ws, config_value, keyword):
"""Try rechunking a random 1d matrix
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_1d
"""
a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
x = da.from_array(a, chunks=(10, 1))
new = ((1,) * 10, (10,))
config = {"array.rechunk.method": config_value} if config_value is not None else {}
with dask.config.set(config):
x2 = rechunk(x, chunks=new, method=keyword)
expected_algorithm = keyword if keyword is not None else config_value
if expected_algorithm == "p2p":
assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
elif expected_algorithm == "tasks":
assert not any(
key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
)
# Neither is specified, so we choose the best one (see test_rechunk_heuristic for a full test of the heuristic)
else:
assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
assert x2.chunks == new
> assert np.all(await c.compute(x2) == a)
distributed/shuffle/tests/test_rechunk.py:210:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P e0eda9e4c6daa20db3080bf5f59ffb75 failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
All 11 runs failed: test_rechunk_heuristic[new0-p2p] (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 90af7b3bcf0d14cd2d09aae30fd664cf failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: '90af7b3bcf0d14cd2d09aae30fd664cf'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and…ate_shuffle(id)
distributed/shuffle/_worker_plugin.py:411: in get_or_create_shuffle
return sync(
distributed/utils.py:439: in sync
raise error
distributed/utils.py:413: in f
result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.12/site-packages/tornado/gen.py:766: in run
value = future.result()
distributed/shuffle/_worker_plugin.py:145: in get_or_create
shuffle_run = await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:46527', workers: 0, cores: 0, tasks: 0>
a = array([[0.09938915, 0.49322808, 0.60598825, ..., 0.4802486 , 0.23827154,
0.81719055],
[0.5558567 , 0.27...52,
0.69894025],
[0.25175323, 0.82966441, 0.95636455, ..., 0.76615423, 0.78213394,
0.65567332]])
b = <Worker 'tcp://127.0.0.1:38815', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>
new = ((1, 1, 1, 1, 1, 1, ...), (100,)), expected_algorithm = 'p2p'
@pytest.mark.parametrize(
["new", "expected_algorithm"],
[
# All-to-all rechunking defaults to P2P
(((1,) * 100, (100,)), "p2p"),
# Localized rechunking defaults to tasks
(((50, 50), (2,) * 50), "tasks"),
# Less local rechunking first defaults to tasks,
(((25, 25, 25, 25), (4,) * 25), "tasks"),
# then switches to p2p
(((10,) * 10, (10,) * 10), "p2p"),
],
)
@gen_cluster(client=True)
async def test_rechunk_heuristic(c, s, a, b, new, expected_algorithm):
a = np.random.default_rng().uniform(0, 1, 10000).reshape((100, 100))
x = da.from_array(a, chunks=(100, 1))
x2 = rechunk(x, chunks=new)
if expected_algorithm == "p2p":
assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
else:
assert not any(
key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
)
assert x2.chunks == new
> assert np.all(await c.compute(x2) == a)
distributed/shuffle/tests/test_rechunk.py:239:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P 90af7b3bcf0d14cd2d09aae30fd664cf failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
All 11 runs failed: test_rechunk_heuristic[new3-p2p] (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 8bebd7a505fd78dd1b1dcb229dea41a8 failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: '8bebd7a505fd78dd1b1dcb229dea41a8'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and…shuffle/_worker_plugin.py:411: in get_or_create_shuffle
return sync(
distributed/utils.py:439: in sync
raise error
distributed/utils.py:413: in f
result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.12/site-packages/tornado/gen.py:766: in run
value = future.result()
distributed/shuffle/_worker_plugin.py:145: in get_or_create
shuffle_run = await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:38873', workers: 0, cores: 0, tasks: 0>
a = array([[0.17099363, 0.84982634, 0.68269121, ..., 0.58142871, 0.20141963,
0.70392131],
[0.93343985, 0.96...35,
0.77540941],
[0.75007872, 0.23444326, 0.77904639, ..., 0.47615695, 0.47363149,
0.06353038]])
b = <Worker 'tcp://127.0.0.1:35891', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>
new = ((10, 10, 10, 10, 10, 10, ...), (10, 10, 10, 10, 10, 10, ...))
expected_algorithm = 'p2p'
@pytest.mark.parametrize(
["new", "expected_algorithm"],
[
# All-to-all rechunking defaults to P2P
(((1,) * 100, (100,)), "p2p"),
# Localized rechunking defaults to tasks
(((50, 50), (2,) * 50), "tasks"),
# Less local rechunking first defaults to tasks,
(((25, 25, 25, 25), (4,) * 25), "tasks"),
# then switches to p2p
(((10,) * 10, (10,) * 10), "p2p"),
],
)
@gen_cluster(client=True)
async def test_rechunk_heuristic(c, s, a, b, new, expected_algorithm):
a = np.random.default_rng().uniform(0, 1, 10000).reshape((100, 100))
x = da.from_array(a, chunks=(100, 1))
x2 = rechunk(x, chunks=new)
if expected_algorithm == "p2p":
assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
else:
assert not any(
key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
)
assert x2.chunks == new
> assert np.all(await c.compute(x2) == a)
distributed/shuffle/tests/test_rechunk.py:239:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P 8bebd7a505fd78dd1b1dcb229dea41a8 failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
All 11 runs failed: test_cull_p2p_rechunk_independent_partitions (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
assert 57 < (228 / 4)
+ where 57 = len(HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x7fc8fda72900>\n 0. getitem-7d63926232ec9c9ad09ece51f431ba62\n)
+ where HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x7fc8fda72900>\n 0. getitem-7d63926232ec9c9ad09ece51f431ba62\n = dask.array<getitem, shape=(5, 2, 10), dtype=float64, chunksize=(5, 1, 10), chunktype=numpy.ndarray>.dask
+ and 228 = len(HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x7fc8fd666e40>\n 0. rechunk-p2p-b4936ca757976705fec9ffc12017b806\n)
+ where HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x7fc8fd666e40>\n 0. rechunk-p2p-b4936ca757976705fec9ffc12017b806\n = dask.array<rechunk-p2p, shape=(10, 10, 10), dtype=float64, chunksize=(5, 1, 10), chunktype=numpy.ndarray>.dask
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:35601', workers: 0, cores: 0, tasks: 0>
ws = (<Worker 'tcp://127.0.0.1:34703', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:44387', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[[0.599171 , 0.17502798, 0.88209158, 0.21115568, 0.46142343,
0.03521772, 0.44125943, 0.5110587 , 0.64...0.8850746 , 0.48018381, 0.09855244, 0.71004424,
0.51442559, 0.78168598, 0.52234215, 0.70874625, 0.70225798]]])
x = dask.array<array, shape=(10, 10, 10), dtype=float64, chunksize=(1, 5, 1), chunktype=numpy.ndarray>
new = (5, 1, -1)
@gen_cluster(client=True)
async def test_cull_p2p_rechunk_independent_partitions(c, s, *ws):
a = np.random.default_rng().uniform(0, 1, 1000).reshape((10, 10, 10))
x = da.from_array(a, chunks=(1, 5, 1))
new = (5, 1, -1)
rechunked = rechunk(x, chunks=new, method="p2p")
(dsk,) = dask.optimize(rechunked)
culled = rechunked[:5, :2]
(dsk_culled,) = dask.optimize(culled)
# The culled graph requires only 1/2 of the input tasks
n_inputs = len(
[1 for key in dsk.dask.get_all_dependencies() if key[0].startswith("array-")]
)
n_culled_inputs = len(
[
1
for key in dsk_culled.dask.get_all_dependencies()
if key[0].startswith("array-")
]
)
assert n_culled_inputs == n_inputs / 4
# The culled graph should also have less than 1/4 the tasks
> assert len(dsk_culled.dask) < len(dsk.dask) / 4
E assert 57 < (228 / 4)
E + where 57 = len(HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x7fc8fda72900>\n 0. getitem-7d63926232ec9c9ad09ece51f431ba62\n)
E + where HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x7fc8fda72900>\n 0. getitem-7d63926232ec9c9ad09ece51f431ba62\n = dask.array<getitem, shape=(5, 2, 10), dtype=float64, chunksize=(5, 1, 10), chunktype=numpy.ndarray>.dask
E + and 228 = len(HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x7fc8fd666e40>\n 0. rechunk-p2p-b4936ca757976705fec9ffc12017b806\n)
E + where HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x7fc8fd666e40>\n 0. rechunk-p2p-b4936ca757976705fec9ffc12017b806\n = dask.array<rechunk-p2p, shape=(10, 10, 10), dtype=float64, chunksize=(5, 1, 10), chunktype=numpy.ndarray>.dask
distributed/shuffle/tests/test_rechunk.py:265: AssertionError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
All 11 runs failed: test_rechunk_expand (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 16159121174d0a3cf3a9d8fa1f4d944f failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: '16159121174d0a3cf3a9d8fa1f4d944f'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and…ib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
> yield
distributed/shuffle/_core.py:523:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/shuffle/_rechunk.py:170: in rechunk_transfer
return get_worker_plugin().add_partition(
distributed/shuffle/_worker_plugin.py:348: in add_partition
shuffle_run = self.get_or_create_shuffle(id)
distributed/shuffle/_worker_plugin.py:411: in get_or_create_shuffle
return sync(
distributed/utils.py:439: in sync
raise error
distributed/utils.py:413: in f
result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.12/site-packages/tornado/gen.py:766: in run
value = future.result()
distributed/shuffle/_worker_plugin.py:145: in get_or_create
shuffle_run = await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:39835', workers: 0, cores: 0, tasks: 0>
ws = (<Worker 'tcp://127.0.0.1:45529', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:46747', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[0.03029662, 0.67497307, 0.54696325, 0.5664905 , 0.27160949,
0.56769018, 0.65528007, 0.42471041, 0.7749..., 0.04748013, 0.53053804, 0.17881593, 0.05720855,
0.00615196, 0.73041112, 0.96641799, 0.98946603, 0.54375914]])
x = dask.array<array, shape=(10, 10), dtype=float64, chunksize=(5, 5), chunktype=numpy.ndarray>
y = dask.array<rechunk-p2p, shape=(10, 10), dtype=float64, chunksize=(3, 3), chunktype=numpy.ndarray>
@gen_cluster(client=True)
async def test_rechunk_expand(c, s, *ws):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_expand
"""
a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
x = da.from_array(a, chunks=(5, 5))
y = x.rechunk(chunks=((3, 3, 3, 1), (3, 3, 3, 1)), method="p2p")
> assert np.all(await c.compute(y) == a)
distributed/shuffle/tests/test_rechunk.py:377:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P 16159121174d0a3cf3a9d8fa1f4d944f failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
All 11 runs failed: test_rechunk_expand2 (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P bc0b6e3c445f0bbb5a37bf95be33b322 failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: 'bc0b6e3c445f0bbb5a37bf95be33b322'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and…le_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
> yield
distributed/shuffle/_core.py:523:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/shuffle/_rechunk.py:170: in rechunk_transfer
return get_worker_plugin().add_partition(
distributed/shuffle/_worker_plugin.py:348: in add_partition
shuffle_run = self.get_or_create_shuffle(id)
distributed/shuffle/_worker_plugin.py:411: in get_or_create_shuffle
return sync(
distributed/utils.py:439: in sync
raise error
distributed/utils.py:413: in f
result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.12/site-packages/tornado/gen.py:766: in run
value = future.result()
distributed/shuffle/_worker_plugin.py:145: in get_or_create
shuffle_run = await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:43843', workers: 0, cores: 0, tasks: 0>
ws = (<Worker 'tcp://127.0.0.1:37269', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:33729', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = 3, b = 2
orig = array([[0.42729343, 0.06459062, 0.34592701],
[0.86826575, 0.07615254, 0.2244809 ],
[0.0676913 , 0.43340265, 0.93535141]])
@gen_cluster(client=True)
async def test_rechunk_expand2(c, s, *ws):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_expand2
"""
(a, b) = (3, 2)
orig = np.random.default_rng().uniform(0, 1, a**b).reshape((a,) * b)
for off, off2 in product(range(1, a - 1), range(1, a - 1)):
old = ((a - off, off),) * b
x = da.from_array(orig, chunks=old)
new = ((a - off2, off2),) * b
assert np.all(await c.compute(x.rechunk(chunks=new, method="p2p")) == orig)
if a - off - off2 > 0:
new = ((off, a - off2 - off, off2),) * b
> y = await c.compute(x.rechunk(chunks=new, method="p2p"))
distributed/shuffle/tests/test_rechunk.py:396:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P bc0b6e3c445f0bbb5a37bf95be33b322 failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
10 out of 11 runs failed: test_rechunk_unknown_from_pandas (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 427d2d9ead2057e7af39f331aad1826a failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: '427d2d9ead2057e7af39f331aad1826a'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and…le/_worker_plugin.py:411: in get_or_create_shuffle
return sync(
distributed/utils.py:439: in sync
raise error
distributed/utils.py:413: in f
result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.12/site-packages/tornado/gen.py:766: in run
value = future.result()
distributed/shuffle/_worker_plugin.py:145: in get_or_create
shuffle_run = await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
    """Attach the plugin to ``scheduler`` and create empty bookkeeping.

    Registers the ``shuffle_barrier``, ``shuffle_get``,
    ``shuffle_get_or_create`` and ``shuffle_restrict_task`` handlers on
    ``scheduler.handlers`` and adds this object to the scheduler as the
    plugin named ``"shuffle"``.
    """
    self.scheduler = scheduler
    rpc_handlers = {
        "shuffle_barrier": self.barrier,
        "shuffle_get": self.get,
        "shuffle_get_or_create": self.get_or_create,
        "shuffle_restrict_task": self.restrict_task,
    }
    self.scheduler.handlers.update(rpc_handlers)
    # Two-level auto-creating mapping; missing entries appear on first access.
    self.heartbeats = defaultdict(lambda: defaultdict(dict))
    self.active_shuffles = dict()
    self.scheduler.add_plugin(self, name="shuffle")
    self._shuffles = defaultdict(set)
    self._archived_by_stimulus = defaultdict(set)
    self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
    """Register a new ``ShuffleWorkerPlugin`` via the scheduler.

    The plugin is serialised with ``dumps`` and passed to
    ``self.scheduler.register_worker_plugin`` under the name ``"shuffle"``
    with ``idempotent=False``. The ``scheduler`` argument itself is not
    used; ``self.scheduler`` is.
    """
    serialized_plugin = dumps(ShuffleWorkerPlugin())
    await self.scheduler.register_worker_plugin(
        None, serialized_plugin, name="shuffle", idempotent=False
    )
def shuffle_ids(self) -> set[ShuffleId]:
    """Return a new set containing the ids of all active shuffles."""
    return {shuffle_id for shuffle_id in self.active_shuffles}
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
    """Handle a ``shuffle_barrier`` request for the active shuffle ``id``.

    If ``consistent`` is false the shuffle is restarted. Otherwise a
    ``shuffle_inputs_done`` message is broadcast to every participating
    worker and re-sent to those workers whose reply was an ``OSError``
    until none are left.

    Raises
    ------
    KeyError
        If ``id`` is not an active shuffle.
    ValueError
        If ``run_id`` differs from the run id of the active shuffle.
    P2PIllegalStateError
        If a worker that still has to answer is unknown to the scheduler
        while the shuffle has not been archived.
    RuntimeError
        If a worker replies with anything other than ``None`` or an
        ``OSError``, if a pending worker is unknown to the scheduler, or
        if the number of pending workers failed to shrink three times.
    """
    # NOTE(review): nesting reconstructed from a listing without indentation;
    # confirm against the original file.
    shuffle = self.active_shuffles[id]
    if shuffle.run_id != run_id:
        raise ValueError(f"{run_id=} does not match {shuffle}")
    if not consistent:
        logger.warning(
            "Shuffle %s restarted due to data inconsistency during barrier",
            shuffle.id,
        )
        return self._restart_shuffle(
            shuffle.id,
            self.scheduler,
            stimulus_id=f"p2p-barrier-inconsistent-{time()}",
        )
    msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
    workers = list(shuffle.participating_workers)
    no_progress = 0
    while workers:
        res = await self.scheduler.broadcast(
            msg=msg,
            workers=workers,
            on_error="return",
        )
        before = len(workers)
        # Keep only the workers that must be retried. ``None`` means the
        # worker answered; an ``OSError`` is retried; anything else aborts.
        # The loop already yields the final list, so no second pass over
        # ``res`` is needed.
        workers = []
        for w, r in res.items():
            if r is None:
                continue
            if not isinstance(r, OSError):
                raise RuntimeError(
                    f"Unexpected error encountered during P2P barrier: {r!r}"
                )
            workers.append(w)
        if workers:
            logger.warning(
                "Failure during broadcast of %s, retrying.",
                shuffle.id,
            )
            if any(w not in self.scheduler.workers for w in workers):
                if not shuffle.archived:
                    # If the shuffle is not yet archived, this could mean that the barrier task fails
                    # before the P2P restarting mechanism can kick in.
                    raise P2PIllegalStateError(
                        "Expected shuffle to be archived if participating worker is not known by scheduler"
                    )
                raise RuntimeError(
                    f"Worker {workers} left during shuffle {shuffle}"
                )
            await asyncio.sleep(0.1)
            if len(workers) == before:
                no_progress += 1
                if no_progress >= 3:
                    raise RuntimeError(
                        f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
                    )
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
    """Merge a worker's per-shuffle heartbeat payload into ``self.heartbeats``.

    ``data`` maps shuffle ids to dicts. For every id that is contained in
    ``self.shuffle_ids()`` the dict is merged into
    ``self.heartbeats[shuffle_id][ws.address]``; all other ids are ignored.
    """
    # Query the ids once instead of once per entry of ``data``.
    active_ids = self.shuffle_ids()
    for shuffle_id, d in data.items():
        if shuffle_id in active_ids:
            self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
    """Handle a ``shuffle_get`` request from ``worker``.

    On success the reply is an OK message whose ``run_spec`` is the result
    of ``self._get`` wrapped in ``ToPickle``. A ``KeyError`` is converted
    into a ``P2PConsistencyError``; any ``P2PConsistencyError`` is returned
    through ``error_message`` instead of being raised.
    """
    try:
        try:
            spec = self._get(id, worker)
            return {"status": "OK", "run_spec": ToPickle(spec)}
        except KeyError as missing:
            not_found = f"No active shuffle with {id=!r} found"
            raise P2PConsistencyError(not_found) from missing
    except P2PConsistencyError as inconsistency:
        return error_message(inconsistency)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
    """Return the run spec of active shuffle ``id`` and record ``worker``.

    ``worker`` is added to the shuffle's ``participating_workers``.

    Raises
    ------
    P2PConsistencyError
        If ``worker`` is not in ``self.scheduler.workers``.
    KeyError
        If ``id`` is not an active shuffle.
    """
    known_workers = self.scheduler.workers
    if worker not in known_workers:
        # This should never happen
        raise P2PConsistencyError(
            f"Scheduler is unaware of this worker {worker!r}"
        )  # pragma: nocover
    shuffle_state = self.active_shuffles[id]
    shuffle_state.participating_workers.add(worker)
    return shuffle_state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:33963', workers: 0, cores: 0, tasks: 0>
ws = (<Worker 'tcp://127.0.0.1:41295', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:39915', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
pd = <module 'pandas' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.12/site-packages/pandas/__init__.py'>
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.12/site-packages/dask/dataframe/__init__.py'>
arr = array([[-1.10870194e-01, 1.24996567e+00, 9.07049911e-02,
1.24946836e+00, -5.82149888e-01, 1.31636619e+00,
... 9.54350494e-02, -1.54228652e-01,
1.41285044e+00, -3.73567619e-01, 2.63043451e-01,
-1.00381024e+00]])
@gen_cluster(client=True)
async def test_rechunk_unknown_from_pandas(c, s, *ws):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_unknown_from_pandas
"""
pd = pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
arr = np.random.default_rng().standard_normal((50, 10))
x = dd.from_pandas(pd.DataFrame(arr), 2).values
result = x.rechunk((None, (5, 5)), method="p2p")
assert np.isnan(x.chunks[0]).all()
assert np.isnan(result.chunks[0]).all()
assert result.chunks[1] == (5, 5)
expected = da.from_array(arr, chunks=((25, 25), (10,))).rechunk(
(None, (5, 5)), method="p2p"
)
> assert_eq(await c.compute(result), await c.compute(expected))
distributed/shuffle/tests/test_rechunk.py:706:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
    """``OKMessage`` extended with the run spec of a shuffle."""

    # Either the spec itself or the spec wrapped in ``ToPickle``.
    run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
# Set up one shuffle run: store the identifiers and collaborators passed in,
# create the disk (or in-memory) shard buffer and the comm shard buffer, and
# read the retry settings from the dask config.
# NOTE(review): this listing lost its indentation, so the exact extent of the
# nested `with` blocks below cannot be determined from here; confirm against
# the original file before restructuring.
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
# `disk` selects the buffer type: DiskShardsBuffer (uses `directory` and
# `memory_limiter_disk`) or MemoryShardsBuffer.
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
# Message size limit and concurrency of the comm buffer come from the
# `distributed.p2p.comm.*` config keys.
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
# Progress and lifecycle state of the run.
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
# Retry settings passed to `retry` by `send`; the delays are parsed with
# `parse_timedelta` using seconds as the default unit.
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
    """Return a debug representation listing the run's identifying fields."""
    cls_name = self.__class__.__name__
    shown = ("id", "run_id", "local_address", "closed", "transferred")
    fields = ", ".join(f"{attr}={getattr(self, attr)!r}" for attr in shown)
    return f"<{cls_name}: {fields}>"
def __str__(self) -> str:
    """Return ``<class name><id[run_id]> on <local address>``."""
    cls_name = self.__class__.__name__
    run_label = f"{self.id}[{self.run_id}]"
    return f"{cls_name}<{run_label}> on {self.local_address}"
def __hash__(self) -> int:
    """Use the run id as the hash value."""
    run_id = self.run_id
    return run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
    """Forward ``context_meter`` metrics to ``self.digest_metric`` as
    {('p2p', <span id>, 'foreground|background...', label, unit): value}

    A leading ``"p2p-"`` on the first label element is dropped.

    **Note 1:** A metric that is not logged by a background task
    (where='foreground') is additionally recorded under
    {('execute', <span id>, <task prefix>, label, unit): value}
    The duplication is intentional: it gives a holistic view of the whole
    shuffle process.

    **Note 2:** Metrics are written to Worker.digests right away.
    Keeping them on the ShuffleRun first would lose whatever is recorded
    between the heartbeat and the deletion of the ShuffleRun object at the
    end of a run.
    """
    prefix = "p2p-"

    def callback(label: Hashable, value: float, unit: str) -> None:
        parts = label if isinstance(label, tuple) else (label,)
        first = parts[0]
        if isinstance(first, str) and first.startswith(prefix):
            parts = (first[len(prefix) :], *parts[1:])
        self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

    offloadable = "background" in where
    with context_meter.add_callback(callback, allow_offload=offloadable):
        yield
async def barrier(self, run_ids: Sequence[int]) -> int:
    """Call the scheduler's ``shuffle_barrier`` for this run.

    ``consistent`` is sent as true only if every entry of ``run_ids``
    equals ``self.run_id``. Returns ``self.run_id``. Raises through
    ``raise_if_closed`` when the run is closed.
    """
    self.raise_if_closed()
    mismatch = any(other != self.run_id for other in run_ids)
    # TODO: Consider broadcast pinging once when the shuffle starts to warm
    # up the comm pool on scheduler side
    await self.scheduler.shuffle_barrier(
        id=self.id, run_id=self.run_id, consistent=not mismatch
    )
    return self.run_id
async def _send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Make a single ``shuffle_receive`` call to the peer at ``address``.

    ``shards`` is wrapped with ``to_serialize`` and sent together with this
    run's shuffle id and run id. Raises through ``raise_if_closed`` when the
    run is closed.
    """
    self.raise_if_closed()
    remote_receive = self.rpc(address).shuffle_receive
    payload = to_serialize(shards)
    return await remote_receive(
        data=payload,
        shuffle_id=self.id,
        run_id=self.run_id,
    )
async def send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to ``address`` through ``retry``.

    When ``_mean_shard_size(shards)`` is below 65536 the shards are pickled
    into a single ``bytes`` blob first; otherwise the list is sent as is.
    Retry count and delays come from ``RETRY_COUNT``, ``RETRY_DELAY_MIN``
    and ``RETRY_DELAY_MAX``.
    """
    # Don't send buffers individually over the tcp comms.
    # Instead, merge everything into an opaque bytes blob, send it all at once,
    # and unpickle it on the other side.
    # Performance tests informing the size threshold:
    # https://github.com/dask/distributed/pull/8318
    is_small = _mean_shard_size(shards) < 65536
    shards_or_bytes: list | bytes = pickle.dumps(shards) if is_small else shards

    def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
        return self._send(address, shards_or_bytes)

    retry_options = dict(
        count=self.RETRY_COUNT,
        delay_min=self.RETRY_DELAY_MIN,
        delay_max=self.RETRY_DELAY_MAX,
    )
    return await retry(_send, **retry_options)
async def offload(
    self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    """Run ``func(*args, **kwargs)`` via ``run_in_executor_with_context``.

    The call is made with ``self.executor`` inside a
    ``context_meter.meter("offload")`` block and its result is returned.
    Raises through ``raise_if_closed`` when the run is closed.
    """
    self.raise_if_closed()
    with context_meter.meter("offload"):
        outcome = await run_in_executor_with_context(
            self.executor, func, *args, **kwargs
        )
    return outcome
def heartbeat(self) -> dict[str, Any]:
    """Collect the heartbeat payload of this run.

    Returns a dict with the disk buffer heartbeat under ``"disk"``, the comm
    buffer heartbeat (with ``"read"`` set to ``self.total_recvd``) under
    ``"comm"`` and ``self.start_time`` under ``"start"``.
    """
    comm_stats = self._comm_buffer.heartbeat()
    comm_stats["read"] = self.total_recvd
    disk_stats = self._disk_buffer.heartbeat()
    return dict(disk=disk_stats, comm=comm_stats, start=self.start_time)
async def _write_to_comm(
    self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
    """Write ``data`` to the comm buffer after checking the run is open."""
    self.raise_if_closed()
    comm_buffer = self._comm_buffer
    await comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
    """Write ``data`` to the disk buffer, keyed by ``"_"``-joined indices.

    Each key is turned into a string by joining ``str()`` of its elements
    with underscores, e.g. ``(1, 2)`` becomes ``"1_2"``.
    """
    self.raise_if_closed()
    by_name = {}
    for index, shard in data.items():
        by_name["_".join(map(str, index))] = shard
    await self._disk_buffer.write(by_name)
def raise_if_closed(self) -> None:
    """Raise if this run is closed; do nothing otherwise.

    A closed run re-raises the stored ``self._exception`` when there is
    one, and a ``ShuffleClosedError`` otherwise.
    """
    if not self.closed:
        return
    stored = self._exception
    if stored:
        raise stored
    raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
    """Mark the run as transferred and flush the comm buffer.

    After the flush, ``self._comm_buffer.raise_on_exception()`` is called;
    an exception it raises is stored in ``self._exception`` and re-raised.
    """
    self.raise_if_closed()
    self.transferred = True
    await self._flush_comm()
    try:
        self._comm_buffer.raise_on_exception()
    except Exception as comm_error:
        self._exception = comm_error
        raise
async def _flush_comm(self) -> None:
    """Flush the comm buffer after checking the run is open."""
    self.raise_if_closed()
    comm_buffer = self._comm_buffer
    await comm_buffer.flush()
async def flush_receive(self) -> None:
    """Flush the disk buffer after checking the run is open."""
    self.raise_if_closed()
    disk_buffer = self._disk_buffer
    await disk_buffer.flush()
async def close(self) -> None:
    """Close the comm and disk buffers of this run.

    The first call marks the run as closed, closes the comm buffer, then
    the disk buffer, and finally sets ``_closed_event``. Calls made while
    or after that wait for ``_closed_event`` and return.

    The disk buffer is closed and the event is set even when closing a
    buffer raises, so a failure neither leaves the disk buffer open nor
    blocks concurrent ``close()`` callers forever. The exception still
    propagates to the first caller.
    """
    if self.closed:  # pragma: no cover
        await self._closed_event.wait()
        return
    self.closed = True
    try:
        try:
            await self._comm_buffer.close()
        finally:
            await self._disk_buffer.close()
    finally:
        self._closed_event.set()
def fail(self, exception: Exception) -> None:
    """Store ``exception`` on the run unless the run is already closed."""
    if self.closed:
        return
    self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
    """Read from the disk buffer under the ``"_"``-joined form of ``id``.

    Raises through ``raise_if_closed`` when the run is closed.
    """
    self.raise_if_closed()
    name = "_".join(map(str, id))
    return self._disk_buffer.read(name)
async def receive(
    self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Pass incoming shards on to ``self._receive``.

    ``bytes`` input is unpickled first (see ``send()``). Returns
    ``{"status": "OK"}`` on success; a ``P2PConsistencyError`` is returned
    through ``error_message`` instead of being raised.
    """
    # NOTE(review): ``pickle.loads`` runs on whatever arrives here; this
    # presumably relies on peers being trusted. Confirm.
    try:
        shards = data
        if isinstance(shards, bytes):
            # Unpack opaque blob. See send()
            shards = cast(list[tuple[_T_partition_id, Any]], pickle.loads(shards))
        await self._receive(shards)
    except P2PConsistencyError as inconsistency:
        return error_message(inconsistency)
    return {"status": "OK"}
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
    """Check that output partition ``i`` is assigned to this worker.

    Returns normally when the assigned worker equals ``self.local_address``.
    Otherwise the scheduler is asked via ``shuffle_restrict_task`` to
    restrict task ``key`` to the assigned worker, and ``Reschedule`` is
    raised. An ``"error"`` reply is raised as ``RuntimeError`` carrying the
    reply's message.
    """
    target = self._get_assigned_worker(i)
    if target == self.local_address:
        return
    reply = await self.scheduler.shuffle_restrict_task(
        id=self.id, run_id=self.run_id, key=key, worker=target
    )
    status = reply["status"]
    if status == "error":
        raise RuntimeError(reply["message"])
    assert status == "OK"
    raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
    """Return the address of the worker that output partition ``i`` is
    assigned to.

    Abstract; subclasses provide the implementation.
    """
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
    """Take in shards that belong to output partitions of this shuffle run.

    Abstract; subclasses provide the implementation.
    """
# Shard one input partition and write the shards to the comm buffer; returns
# this run's id. Raises through `raise_if_closed` when the run is closed and
# RuntimeError once `self.transferred` is set.
# NOTE(review): this listing lost its indentation, so it cannot be told from
# here whether `sync(...)` and `return` sit inside the `with` blocks; confirm
# against the original file before restructuring.
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
# `_write_to_comm` is a coroutine function; it is run on `self._loop` via `sync`.
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
    self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
    """Split an input partition into shards grouped by assigned output
    worker.

    Abstract; subclasses provide the implementation.
    """
def get_output_partition(
    self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
    """Return output partition ``partition_id`` of this run.

    Steps, in order: check the run is open, run ``_ensure_output_worker``
    on the event loop, require that ``self.transferred`` is set, run
    ``flush_receive`` on the event loop, then call
    ``_get_output_partition`` under the ``p2p-get-output-*`` meters.

    Raises
    ------
    RuntimeError
        If ``self.transferred`` is not set.
    """
    self.raise_if_closed()
    sync(self._loop, self._ensure_output_worker, partition_id, key)
    if not self.transferred:
        raise RuntimeError("`get_output_partition` called before barrier task")
    sync(self._loop, self.flush_receive)
    # Log metrics both in the "execute" and in the "p2p" contexts
    with self._capture_metrics("foreground"):
        with context_meter.meter("p2p-get-output-noncpu"):
            with context_meter.meter("p2p-get-output-cpu", func=thread_time):
                partition = self._get_output_partition(
                    partition_id, key, **kwargs
                )
                return partition
@abc.abstractmethod
def _get_output_partition(
    self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
    """Produce one output partition of the shuffle run.

    Abstract; subclasses provide the implementation.
    """
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
    """Load shards from disk at ``path``.

    Abstract; subclasses provide the implementation.
    """
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
    """Turn a serialized ``buffer`` back into shards.

    Abstract; subclasses provide the implementation.
    """
def validate_data(self, data: Any) -> None:
    """Hook for checking payload ``data`` before it is shuffled.

    This implementation accepts everything and does nothing.
    """
    return None
def get_worker_plugin() -> ShuffleWorkerPlugin:
    """Return the ``"shuffle"`` plugin of the worker given by ``get_worker()``.

    Raises
    ------
    RuntimeError
        If ``get_worker()`` raises ``ValueError``, or if the worker has no
        plugin registered under ``"shuffle"``.
    """
    from distributed import get_worker

    try:
        current_worker = get_worker()
    except ValueError as not_on_worker:
        no_worker_msg = (
            "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
            "please confirm that you've created a distributed Client and are submitting this computation through it."
        )
        raise RuntimeError(no_worker_msg) from not_on_worker
    try:
        return current_worker.plugins["shuffle"]  # type: ignore
    except KeyError as missing_plugin:
        no_plugin_msg = (
            f"The worker {current_worker.address} does not have a P2P shuffle plugin."
        )
        raise RuntimeError(no_plugin_msg) from missing_plugin
# Every barrier task key starts with this prefix.
_BARRIER_PREFIX = "shuffle-barrier-"


def barrier_key(shuffle_id: ShuffleId) -> str:
    """Return the barrier task key for ``shuffle_id``: prefix plus id."""
    return "".join((_BARRIER_PREFIX, shuffle_id))
def id_from_key(key: Key) -> ShuffleId | None:
    """Extract the shuffle id from a barrier task key.

    Returns ``None`` when ``key`` is not a string starting with
    ``_BARRIER_PREFIX``; otherwise the remainder after the prefix as a
    ``ShuffleId``.
    """
    if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
        suffix = key[len(_BARRIER_PREFIX) :]
        return ShuffleId(suffix)
    return None
class ShuffleType(Enum):
    """Closed set of shuffle kinds, each identified by a string value."""

    DATAFRAME = "DataFrameShuffle"
    ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
    """Immutable description of one run of a shuffle.

    ``run_id`` is not an ``__init__`` argument: each new instance takes the
    next value of a single counter that is created together with the class
    and starts at 1.
    """

    run_id: int = field(init=False, default_factory=itertools.count(1).__next__)
    spec: ShuffleSpec
    worker_for: dict[_T_partition_id, str]
    span_id: str | None

    @property
    def id(self) -> ShuffleId:
        """Id of the shuffle, taken from ``spec``."""
        shuffle_spec = self.spec
        return shuffle_spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
    """Abstract, immutable specification of a shuffle.

    Holds the shuffle ``id`` and the ``disk`` flag; concrete subclasses
    supply the abstract members below.
    """

    id: ShuffleId
    disk: bool

    @property
    @abc.abstractmethod
    def output_partitions(self) -> Generator[_T_partition_id]:
        """Ids of the output partitions (abstract)."""

    @abc.abstractmethod
    def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
        """Choose the worker for ``partition`` from ``workers`` (abstract)."""

    def create_new_run(
        self,
        worker_for: dict[_T_partition_id, str],
        span_id: str | None,
    ) -> SchedulerShuffleState:
        """Build the scheduler-side state for a new run of this spec.

        The participating workers start out as the set of values of
        ``worker_for``.
        """
        new_run_spec = ShuffleRunSpec(
            spec=self, worker_for=worker_for, span_id=span_id
        )
        participants = set(worker_for.values())
        return SchedulerShuffleState(
            run_spec=new_run_spec,
            participating_workers=participants,
        )

    @abc.abstractmethod
    def create_run_on_worker(
        self,
        run_id: int,
        span_id: str | None,
        worker_for: dict[_T_partition_id, str],
        plugin: ShuffleWorkerPlugin,
    ) -> ShuffleRun:
        """Instantiate the shuffle run on the worker side (abstract)."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P 427d2d9ead2057e7af39f331aad1826a failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x0-chunks0] (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P d9513d99181485ac4237b30c4b28d743 failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: 'd9513d99181485ac4237b30c4b28d743'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and…rker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:35113', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(50, 10), dtype=float64, chunksize=(25, 10), chunktype=numpy.ndarray>
chunks = (None, 5)
ws = (<Worker 'tcp://127.0.0.1:40495', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:40935', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.12/site-packages/dask/dataframe/__init__.py'>
@pytest.mark.parametrize(
"x, chunks",
[
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
(da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
(da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
],
)
@gen_cluster(client=True)
async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
"""
pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
y = dd.from_array(x).values
result = y.rechunk(chunks, method="p2p")
expected = x.rechunk(chunks, method="p2p")
assert_chunks_match(result.chunks, expected.chunks)
> assert_eq(await c.compute(result), await c.compute(expected))
distributed/shuffle/tests/test_rechunk.py:757:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P d9513d99181485ac4237b30c4b28d743 failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x1-chunks1] (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 0e8a9f04ceb745bcef4ba25a6e0f2f79 failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: '0e8a9f04ceb745bcef4ba25a6e0f2f79'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
    self,
    id: ShuffleId,
    run_id: int,
    span_id: str | None,
    local_address: str,
    directory: str,
    executor: ThreadPoolExecutor,
    rpc: Callable[[str], PooledRPCCall],
    digest_metric: Callable[[Hashable, float], None],
    scheduler: PooledRPCCall,
    memory_limiter_disk: ResourceLimiter,
    memory_limiter_comms: ResourceLimiter,
    disk: bool,
    loop: IOLoop,
):
    """Store the run's identity and collaborators and create its buffers.

    ``disk`` selects the receive-side buffer: a ``DiskShardsBuffer`` rooted at
    ``directory`` and bounded by ``memory_limiter_disk`` when true, otherwise
    a ``MemoryShardsBuffer``. Outgoing shards go through a
    ``CommShardsBuffer`` bounded by ``memory_limiter_comms``; its message
    size limit and concurrency are read from the
    ``distributed.p2p.comm.*`` configuration keys, as are the retry settings.
    """
    self.id = id
    self.run_id = run_id
    self.span_id = span_id
    self.local_address = local_address
    self.executor = executor
    self.rpc = rpc
    self.digest_metric = digest_metric
    self.scheduler = scheduler
    self.closed = False
    # Initialize buffers and start background tasks
    # Don't log metrics issued by the background tasks onto the dask task that
    # spawned this object
    with context_meter.clear_callbacks():
        with self._capture_metrics("background-disk"):
            if disk:
                self._disk_buffer = DiskShardsBuffer(
                    directory=directory,
                    read=self.read,
                    memory_limiter=memory_limiter_disk,
                )
            else:
                self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
        with self._capture_metrics("background-comms"):
            max_message_size = parse_bytes(
                dask.config.get("distributed.p2p.comm.message-bytes-limit")
            )
            concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
            # The comm buffer delivers shards through ``self.send``.
            self._comm_buffer = CommShardsBuffer(
                send=self.send,
                max_message_size=max_message_size,
                memory_limiter=memory_limiter_comms,
                concurrency_limit=concurrency_limit,
            )
    # TODO: reduce number of connections to number of workers
    # MultiComm.max_connections = min(10, n_workers)
    # Set to True by ``inputs_done``; ``add_partition`` refuses input afterwards.
    self.transferred = False
    self.received = set()
    self.total_recvd = 0
    self.start_time = time.time()
    # Populated by ``fail`` / ``inputs_done``; re-raised by ``raise_if_closed``.
    self._exception = None
    # Set once ``close`` has finished so concurrent ``close`` calls can wait.
    self._closed_event = asyncio.Event()
    self._loop = loop
    # Retry policy used by ``send``; delays are parsed with seconds as the
    # default unit.
    self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
    self.RETRY_DELAY_MIN = parse_timedelta(
        dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
    )
    self.RETRY_DELAY_MAX = parse_timedelta(
        dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
    )
def __repr__(self) -> str:
    """Return a debug representation including the closed/transferred flags."""
    cls_name = self.__class__.__name__
    return (
        f"<{cls_name}: id={self.id!r}, run_id={self.run_id!r}, "
        f"local_address={self.local_address!r}, closed={self.closed!r}, "
        f"transferred={self.transferred!r}>"
    )

def __str__(self) -> str:
    """Return ``ClassName<id[run_id]> on local_address``."""
    cls_name = self.__class__.__name__
    return f"{cls_name}<{self.id}[{self.run_id}]> on {self.local_address}"

def __hash__(self) -> int:
    """Hash a run by its ``run_id``."""
    return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
    """Forward context_meter metrics to ``digest_metric`` while active.

    Metrics are recorded as
        {('p2p', <span id>, 'foreground|background...', label, unit): value}
    where a leading ``"p2p-"`` is stripped from the first label element.

    **Note 1:** When the metric is not logged by a background task
    (where='foreground'), this produces a duplicated metric under
        {('execute', <span id>, <task prefix>, label, unit): value}
    This is by design so that one can have a holistic view of the whole
    shuffle process.

    **Note 2:** We're immediately writing to Worker.digests.
    We don't temporarily store metrics under ShuffleRun as we would lose
    those recorded between the heartbeat and when the ShuffleRun object is
    deleted at the end of a run.
    """
    prefix = "p2p-"
    # Offloading is only allowed for the background contexts.
    offloadable = "background" in where

    def callback(label: Hashable, value: float, unit: str) -> None:
        parts = label if isinstance(label, tuple) else (label,)
        head = parts[0]
        if isinstance(head, str) and head.startswith(prefix):
            parts = (head[len(prefix) :], *parts[1:])
        self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

    with context_meter.add_callback(callback, allow_offload=offloadable):
        yield
async def barrier(self, run_ids: Sequence[int]) -> int:
    """Report the barrier to the scheduler and return this run's ``run_id``.

    The ``consistent`` flag sent along is true only when every entry of
    ``run_ids`` equals this run's own ``run_id``.
    """
    self.raise_if_closed()
    own_run_id = self.run_id
    consistent = not any(other != own_run_id for other in run_ids)
    # TODO: Consider broadcast pinging once when the shuffle starts to warm
    # up the comm pool on scheduler side
    await self.scheduler.shuffle_barrier(
        id=self.id,
        run_id=self.run_id,
        consistent=consistent,
    )
    return self.run_id
async def _send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Make a single ``shuffle_receive`` call to the worker at ``address``.

    Raises first if this run has been closed.
    """
    self.raise_if_closed()
    peer = self.rpc(address)
    response = await peer.shuffle_receive(
        data=to_serialize(shards),
        shuffle_id=self.id,
        run_id=self.run_id,
    )
    return response
async def send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to ``address``, retrying per the ``RETRY_*`` settings.

    When the mean shard size is below 65536 the shards are pickled into a
    single bytes blob before sending; ``receive`` unpickles it on the other
    side. Larger shards are passed through unchanged.
    """
    # Don't send buffers individually over the tcp comms.
    # Instead, merge everything into an opaque bytes blob, send it all at once,
    # and unpickle it on the other side.
    # Performance tests informing the size threshold:
    # https://github.com/dask/distributed/pull/8318
    small_shards = _mean_shard_size(shards) < 65536
    shards_or_bytes: list | bytes = pickle.dumps(shards) if small_shards else shards

    def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
        return self._send(address, shards_or_bytes)

    return await retry(
        _send,
        count=self.RETRY_COUNT,
        delay_min=self.RETRY_DELAY_MIN,
        delay_max=self.RETRY_DELAY_MAX,
    )
async def offload(
    self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

    The wait is metered under the ``"offload"`` label. Raises first if this
    run has been closed.
    """
    self.raise_if_closed()
    with context_meter.meter("offload"):
        outcome = await run_in_executor_with_context(
            self.executor, func, *args, **kwargs
        )
    return outcome
def heartbeat(self) -> dict[str, Any]:
    """Return a snapshot with ``"disk"``, ``"comm"`` and ``"start"`` entries.

    The comm entry is the comm buffer's own heartbeat with ``"read"`` set to
    ``self.total_recvd``.
    """
    comm_stats = self._comm_buffer.heartbeat()
    comm_stats["read"] = self.total_recvd
    disk_stats = self._disk_buffer.heartbeat()
    return {"disk": disk_stats, "comm": comm_stats, "start": self.start_time}
async def _write_to_comm(
    self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
    """Hand ``data`` to the comm buffer; raises first if this run is closed."""
    self.raise_if_closed()
    await self._comm_buffer.write(data)

async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
    """Hand ``data`` to the disk buffer, keyed by ``"_"``-joined indices.

    Raises first if this run is closed.
    """
    self.raise_if_closed()
    keyed = {}
    for index, payload in data.items():
        keyed["_".join(map(str, index))] = payload
    await self._disk_buffer.write(keyed)
def raise_if_closed(self) -> None:
    """Raise if this run is closed; do nothing otherwise.

    A stored ``_exception`` takes precedence over the generic
    ``ShuffleClosedError``.
    """
    if not self.closed:
        return
    if self._exception:
        raise self._exception
    raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
    """Mark the run as transferred and flush the comm buffer.

    Any exception the comm buffer reports afterwards is stored in
    ``_exception`` and re-raised to the caller.
    """
    self.raise_if_closed()
    self.transferred = True
    await self._flush_comm()
    try:
        self._comm_buffer.raise_on_exception()
    except Exception as exc:
        # Remember the failure so raise_if_closed() can surface it later.
        self._exception = exc
        raise
async def _flush_comm(self) -> None:
    """Flush the comm buffer; raises first if this run is closed."""
    self.raise_if_closed()
    await self._comm_buffer.flush()

async def flush_receive(self) -> None:
    """Flush the disk buffer; raises first if this run is closed."""
    self.raise_if_closed()
    await self._disk_buffer.flush()
async def close(self) -> None:
    """Close the comm and disk buffers and mark this run as closed.

    A call made after another ``close`` has already started (``self.closed``
    is true) only waits for ``_closed_event`` and returns.

    The event is set in a ``finally`` clause: if closing a buffer raises, the
    exception still propagates to this caller, but callers waiting in the
    branch above are released instead of waiting forever.
    """
    if self.closed:  # pragma: no cover
        await self._closed_event.wait()
        return

    self.closed = True
    try:
        await self._comm_buffer.close()
        await self._disk_buffer.close()
    finally:
        self._closed_event.set()
def fail(self, exception: Exception) -> None:
    """Store ``exception`` on this run; ignored once the run is closed."""
    if self.closed:
        return
    self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
    """Read the disk-buffer entry stored under the ``"_"``-joined ``id``.

    Raises first if this run is closed.
    """
    self.raise_if_closed()
    name = "_".join(map(str, id))
    return self._disk_buffer.read(name)
async def receive(
    self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Pass received shards to ``_receive`` and report the outcome.

    Returns ``{"status": "OK"}`` on success. A ``P2PConsistencyError`` is
    converted into an error message instead of being raised; other
    exceptions propagate.
    """
    try:
        shards = data
        if isinstance(shards, bytes):
            # Unpack opaque blob. See send()
            # NOTE(review): this unpickles bytes that arrived from a peer;
            # confirm only trusted workers can reach this handler.
            shards = cast(list[tuple[_T_partition_id, Any]], pickle.loads(shards))
        await self._receive(shards)
    except P2PConsistencyError as e:
        return error_message(e)
    return {"status": "OK"}
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
    """Return normally only if partition ``i`` is assigned to this worker.

    Otherwise ask the scheduler to restrict task ``key`` to the assigned
    worker and raise ``Reschedule``; an ``"error"`` reply from the scheduler
    is raised as ``RuntimeError``.
    """
    assigned_worker = self._get_assigned_worker(i)
    if assigned_worker == self.local_address:
        return

    result = await self.scheduler.shuffle_restrict_task(
        id=self.id,
        run_id=self.run_id,
        key=key,
        worker=assigned_worker,
    )
    if result["status"] == "error":
        raise RuntimeError(result["message"])
    assert result["status"] == "OK"
    raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
    """Get the address of the worker assigned to the output partition

    ``_ensure_output_worker`` compares the returned address with
    ``self.local_address``.
    """

@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
    """Receive shards belonging to output partitions of this shuffle run

    Called by ``receive`` after a pickled payload, if any, has been unpacked.
    """
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and…_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
    """Install the shuffle handlers on ``scheduler`` and register the plugin.

    The plugin is added to the scheduler under the name ``"shuffle"``.
    """
    self.scheduler = scheduler
    shuffle_handlers = {
        "shuffle_barrier": self.barrier,
        "shuffle_get": self.get,
        "shuffle_get_or_create": self.get_or_create,
        "shuffle_restrict_task": self.restrict_task,
    }
    self.scheduler.handlers.update(shuffle_handlers)
    # shuffle id -> worker address -> latest heartbeat data
    self.heartbeats = defaultdict(lambda: defaultdict(dict))
    self.active_shuffles = {}
    self.scheduler.add_plugin(self, name="shuffle")
    self._shuffles = defaultdict(set)
    self._archived_by_stimulus = defaultdict(set)
    self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
    """Register a serialized ``ShuffleWorkerPlugin`` under the name "shuffle".

    Registration goes through ``self.scheduler.register_worker_plugin``.
    """
    payload = dumps(ShuffleWorkerPlugin())
    await self.scheduler.register_worker_plugin(
        None,
        payload,
        name="shuffle",
        idempotent=False,
    )

def shuffle_ids(self) -> set[ShuffleId]:
    """Return the ids of all active shuffles as a new set."""
    return set(self.active_shuffles.keys())
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
    """Handle the ``shuffle_barrier`` request for shuffle ``id``.

    If ``consistent`` is false the shuffle is restarted via
    ``_restart_shuffle``. Otherwise ``shuffle_inputs_done`` is broadcast to
    the participating workers; workers whose reply is an ``OSError`` are
    retried until none are left.

    Raises
    ------
    ValueError
        ``run_id`` does not match the active run of the shuffle.
    RuntimeError
        A worker replied with something other than ``None`` or an
        ``OSError``; a worker to retry is no longer known to the scheduler
        (and the shuffle is archived); or the number of pending workers
        failed to shrink in three rounds (the counter is never reset).
    P2PIllegalStateError
        A worker to retry is unknown to the scheduler while the shuffle is
        not archived.
    """
    shuffle = self.active_shuffles[id]
    if shuffle.run_id != run_id:
        raise ValueError(f"{run_id=} does not match {shuffle}")
    if not consistent:
        logger.warning(
            "Shuffle %s restarted due to data inconsistency during barrier",
            shuffle.id,
        )
        return self._restart_shuffle(
            shuffle.id,
            self.scheduler,
            stimulus_id=f"p2p-barrier-inconsistent-{time()}",
        )
    msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
    workers = list(shuffle.participating_workers)
    no_progress = 0
    while workers:
        res = await self.scheduler.broadcast(
            msg=msg,
            workers=workers,
            on_error="return",
        )
        before = len(workers)
        # Collect the workers to retry. Every non-None reply is either an
        # OSError (retry) or fatal, so this loop alone yields the retry list;
        # no second pass over ``res`` is needed.
        workers = []
        for w, r in res.items():
            if r is None:
                continue
            if isinstance(r, OSError):
                workers.append(w)
            else:
                raise RuntimeError(
                    f"Unexpected error encountered during P2P barrier: {r!r}"
                )
        if workers:
            logger.warning(
                "Failure during broadcast of %s, retrying.",
                shuffle.id,
            )
            if any(w not in self.scheduler.workers for w in workers):
                if not shuffle.archived:
                    # If the shuffle is not yet archived, this could mean that the barrier task fails
                    # before the P2P restarting mechanism can kick in.
                    raise P2PIllegalStateError(
                        "Expected shuffle to be archived if participating worker is not known by scheduler"
                    )
                raise RuntimeError(
                    f"Worker {workers} left during shuffle {shuffle}"
                )
            await asyncio.sleep(0.1)
            if len(workers) == before:
                no_progress += 1
                if no_progress >= 3:
                    raise RuntimeError(
                        f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
                    )
def restrict_task(
    self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
    """Restrict task ``key`` to ``worker`` for the active run of shuffle ``id``.

    Returns ``{"status": "OK"}`` on success. If ``run_id`` differs from the
    active run's id, the resulting ``P2PConsistencyError`` is returned as an
    error message rather than raised.
    """
    try:
        shuffle = self.active_shuffles[id]
        current_run_id = shuffle.run_id
        if current_run_id > run_id:
            raise P2PConsistencyError(
                f"Request stale, expected {run_id=} for {shuffle}"
            )
        if current_run_id < run_id:
            raise P2PConsistencyError(
                f"Request invalid, expected {run_id=} for {shuffle}"
            )
        task_state = self.scheduler.tasks[key]
        self._set_restriction(task_state, worker)
        return {"status": "OK"}
    except P2PConsistencyError as e:
        return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
    """Merge a worker's per-shuffle heartbeat ``data`` into ``self.heartbeats``.

    Entries for shuffle ids that are not currently active are ignored.
    """
    # Membership is tested on the dict itself: ``shuffle_ids()`` would build
    # a fresh set of the same keys on every iteration.
    active = self.active_shuffles
    for shuffle_id, d in data.items():
        if shuffle_id in active:
            self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
    """Return the run spec of the active shuffle ``id``, wrapped in ``ToPickle``.

    A missing shuffle (``KeyError``) is turned into a
    ``P2PConsistencyError``; every ``P2PConsistencyError`` is returned as an
    error message rather than raised.
    """
    try:
        try:
            spec = self._get(id, worker)
            reply: RunSpecMessage = {"status": "OK", "run_spec": ToPickle(spec)}
            return reply
        except KeyError as e:
            raise P2PConsistencyError(
                f"No active shuffle with {id=!r} found"
            ) from e
    except P2PConsistencyError as e:
        return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
    """Add ``worker`` to the participants of shuffle ``id``; return its run spec.

    Raises ``P2PConsistencyError`` if the scheduler does not know ``worker``
    and ``KeyError`` if no shuffle with this id is active.
    """
    worker_known = worker in self.scheduler.workers
    if not worker_known:
        # This should never happen
        raise P2PConsistencyError(
            f"Scheduler is unaware of this worker {worker!r}"
        )  # pragma: nocover
    state = self.active_shuffles[id]
    state.participating_workers.add(worker)
    return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:35473', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(50, 10), dtype=float64, chunksize=(25, 10), chunktype=numpy.ndarray>
chunks = {1: 5}
ws = (<Worker 'tcp://127.0.0.1:44921', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:40681', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.12/site-packages/dask/dataframe/__init__.py'>
@pytest.mark.parametrize(
"x, chunks",
[
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
(da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
(da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
],
)
@gen_cluster(client=True)
async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
"""
pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
y = dd.from_array(x).values
result = y.rechunk(chunks, method="p2p")
expected = x.rechunk(chunks, method="p2p")
assert_chunks_match(result.chunks, expected.chunks)
> assert_eq(await c.compute(result), await c.compute(expected))
distributed/shuffle/tests/test_rechunk.py:757:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
    """``OKMessage`` variant that additionally carries a ``run_spec`` entry."""

    # Either the run spec itself or the same spec wrapped in ``ToPickle``.
    run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
    self,
    id: ShuffleId,
    run_id: int,
    span_id: str | None,
    local_address: str,
    directory: str,
    executor: ThreadPoolExecutor,
    rpc: Callable[[str], PooledRPCCall],
    digest_metric: Callable[[Hashable, float], None],
    scheduler: PooledRPCCall,
    memory_limiter_disk: ResourceLimiter,
    memory_limiter_comms: ResourceLimiter,
    disk: bool,
    loop: IOLoop,
):
    """Store the run's identity and collaborators and create its buffers.

    ``disk`` selects the receive-side buffer: a ``DiskShardsBuffer`` rooted at
    ``directory`` and bounded by ``memory_limiter_disk`` when true, otherwise
    a ``MemoryShardsBuffer``. Outgoing shards go through a
    ``CommShardsBuffer`` bounded by ``memory_limiter_comms``; its message
    size limit and concurrency are read from the
    ``distributed.p2p.comm.*`` configuration keys, as are the retry settings.
    """
    self.id = id
    self.run_id = run_id
    self.span_id = span_id
    self.local_address = local_address
    self.executor = executor
    self.rpc = rpc
    self.digest_metric = digest_metric
    self.scheduler = scheduler
    self.closed = False
    # Initialize buffers and start background tasks
    # Don't log metrics issued by the background tasks onto the dask task that
    # spawned this object
    with context_meter.clear_callbacks():
        with self._capture_metrics("background-disk"):
            if disk:
                self._disk_buffer = DiskShardsBuffer(
                    directory=directory,
                    read=self.read,
                    memory_limiter=memory_limiter_disk,
                )
            else:
                self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
        with self._capture_metrics("background-comms"):
            max_message_size = parse_bytes(
                dask.config.get("distributed.p2p.comm.message-bytes-limit")
            )
            concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
            # The comm buffer delivers shards through ``self.send``.
            self._comm_buffer = CommShardsBuffer(
                send=self.send,
                max_message_size=max_message_size,
                memory_limiter=memory_limiter_comms,
                concurrency_limit=concurrency_limit,
            )
    # TODO: reduce number of connections to number of workers
    # MultiComm.max_connections = min(10, n_workers)
    # Set to True by ``inputs_done``; ``add_partition`` refuses input afterwards.
    self.transferred = False
    self.received = set()
    self.total_recvd = 0
    self.start_time = time.time()
    # Populated by ``fail`` / ``inputs_done``; re-raised by ``raise_if_closed``.
    self._exception = None
    # Set once ``close`` has finished so concurrent ``close`` calls can wait.
    self._closed_event = asyncio.Event()
    self._loop = loop
    # Retry policy used by ``send``; delays are parsed with seconds as the
    # default unit.
    self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
    self.RETRY_DELAY_MIN = parse_timedelta(
        dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
    )
    self.RETRY_DELAY_MAX = parse_timedelta(
        dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
    )
def __repr__(self) -> str:
    """Return a debug representation including the closed/transferred flags."""
    cls_name = self.__class__.__name__
    return (
        f"<{cls_name}: id={self.id!r}, run_id={self.run_id!r}, "
        f"local_address={self.local_address!r}, closed={self.closed!r}, "
        f"transferred={self.transferred!r}>"
    )

def __str__(self) -> str:
    """Return ``ClassName<id[run_id]> on local_address``."""
    cls_name = self.__class__.__name__
    return f"{cls_name}<{self.id}[{self.run_id}]> on {self.local_address}"

def __hash__(self) -> int:
    """Hash a run by its ``run_id``."""
    return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
    """Forward context_meter metrics to ``digest_metric`` while active.

    Metrics are recorded as
        {('p2p', <span id>, 'foreground|background...', label, unit): value}
    where a leading ``"p2p-"`` is stripped from the first label element.

    **Note 1:** When the metric is not logged by a background task
    (where='foreground'), this produces a duplicated metric under
        {('execute', <span id>, <task prefix>, label, unit): value}
    This is by design so that one can have a holistic view of the whole
    shuffle process.

    **Note 2:** We're immediately writing to Worker.digests.
    We don't temporarily store metrics under ShuffleRun as we would lose
    those recorded between the heartbeat and when the ShuffleRun object is
    deleted at the end of a run.
    """
    prefix = "p2p-"
    # Offloading is only allowed for the background contexts.
    offloadable = "background" in where

    def callback(label: Hashable, value: float, unit: str) -> None:
        parts = label if isinstance(label, tuple) else (label,)
        head = parts[0]
        if isinstance(head, str) and head.startswith(prefix):
            parts = (head[len(prefix) :], *parts[1:])
        self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

    with context_meter.add_callback(callback, allow_offload=offloadable):
        yield
async def barrier(self, run_ids: Sequence[int]) -> int:
    """Report the barrier to the scheduler and return this run's ``run_id``.

    The ``consistent`` flag sent along is true only when every entry of
    ``run_ids`` equals this run's own ``run_id``.
    """
    self.raise_if_closed()
    own_run_id = self.run_id
    consistent = not any(other != own_run_id for other in run_ids)
    # TODO: Consider broadcast pinging once when the shuffle starts to warm
    # up the comm pool on scheduler side
    await self.scheduler.shuffle_barrier(
        id=self.id,
        run_id=self.run_id,
        consistent=consistent,
    )
    return self.run_id
async def _send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Make a single ``shuffle_receive`` call to the worker at ``address``.

    Raises first if this run has been closed.
    """
    self.raise_if_closed()
    peer = self.rpc(address)
    response = await peer.shuffle_receive(
        data=to_serialize(shards),
        shuffle_id=self.id,
        run_id=self.run_id,
    )
    return response
async def send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to ``address``, retrying per the ``RETRY_*`` settings.

    When the mean shard size is below 65536 the shards are pickled into a
    single bytes blob before sending; ``receive`` unpickles it on the other
    side. Larger shards are passed through unchanged.
    """
    # Don't send buffers individually over the tcp comms.
    # Instead, merge everything into an opaque bytes blob, send it all at once,
    # and unpickle it on the other side.
    # Performance tests informing the size threshold:
    # https://github.com/dask/distributed/pull/8318
    small_shards = _mean_shard_size(shards) < 65536
    shards_or_bytes: list | bytes = pickle.dumps(shards) if small_shards else shards

    def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
        return self._send(address, shards_or_bytes)

    return await retry(
        _send,
        count=self.RETRY_COUNT,
        delay_min=self.RETRY_DELAY_MIN,
        delay_max=self.RETRY_DELAY_MAX,
    )
async def offload(
    self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

    The wait is metered under the ``"offload"`` label. Raises first if this
    run has been closed.
    """
    self.raise_if_closed()
    with context_meter.meter("offload"):
        outcome = await run_in_executor_with_context(
            self.executor, func, *args, **kwargs
        )
    return outcome
def heartbeat(self) -> dict[str, Any]:
    """Return a snapshot with ``"disk"``, ``"comm"`` and ``"start"`` entries.

    The comm entry is the comm buffer's own heartbeat with ``"read"`` set to
    ``self.total_recvd``.
    """
    comm_stats = self._comm_buffer.heartbeat()
    comm_stats["read"] = self.total_recvd
    disk_stats = self._disk_buffer.heartbeat()
    return {"disk": disk_stats, "comm": comm_stats, "start": self.start_time}
async def _write_to_comm(
    self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
    """Hand ``data`` to the comm buffer; raises first if this run is closed."""
    self.raise_if_closed()
    await self._comm_buffer.write(data)

async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
    """Hand ``data`` to the disk buffer, keyed by ``"_"``-joined indices.

    Raises first if this run is closed.
    """
    self.raise_if_closed()
    keyed = {}
    for index, payload in data.items():
        keyed["_".join(map(str, index))] = payload
    await self._disk_buffer.write(keyed)
def raise_if_closed(self) -> None:
    """Raise if this run is closed; do nothing otherwise.

    A stored ``_exception`` takes precedence over the generic
    ``ShuffleClosedError``.
    """
    if not self.closed:
        return
    if self._exception:
        raise self._exception
    raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
    """Mark the run as transferred and flush the comm buffer.

    Any exception the comm buffer reports afterwards is stored in
    ``_exception`` and re-raised to the caller.
    """
    self.raise_if_closed()
    self.transferred = True
    await self._flush_comm()
    try:
        self._comm_buffer.raise_on_exception()
    except Exception as exc:
        # Remember the failure so raise_if_closed() can surface it later.
        self._exception = exc
        raise
async def _flush_comm(self) -> None:
    """Flush the comm buffer; raises first if this run is closed."""
    self.raise_if_closed()
    await self._comm_buffer.flush()

async def flush_receive(self) -> None:
    """Flush the disk buffer; raises first if this run is closed."""
    self.raise_if_closed()
    await self._disk_buffer.flush()
async def close(self) -> None:
    """Close the comm and disk buffers and mark this run as closed.

    A call made after another ``close`` has already started (``self.closed``
    is true) only waits for ``_closed_event`` and returns.

    The event is set in a ``finally`` clause: if closing a buffer raises, the
    exception still propagates to this caller, but callers waiting in the
    branch above are released instead of waiting forever.
    """
    if self.closed:  # pragma: no cover
        await self._closed_event.wait()
        return

    self.closed = True
    try:
        await self._comm_buffer.close()
        await self._disk_buffer.close()
    finally:
        self._closed_event.set()
def fail(self, exception: Exception) -> None:
    """Store ``exception`` on this run; ignored once the run is closed."""
    if self.closed:
        return
    self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
    """Read the disk-buffer entry stored under the ``"_"``-joined ``id``.

    Raises first if this run is closed.
    """
    self.raise_if_closed()
    name = "_".join(map(str, id))
    return self._disk_buffer.read(name)
async def receive(
    self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Pass received shards to ``_receive`` and report the outcome.

    Returns ``{"status": "OK"}`` on success. A ``P2PConsistencyError`` is
    converted into an error message instead of being raised; other
    exceptions propagate.
    """
    try:
        shards = data
        if isinstance(shards, bytes):
            # Unpack opaque blob. See send()
            # NOTE(review): this unpickles bytes that arrived from a peer;
            # confirm only trusted workers can reach this handler.
            shards = cast(list[tuple[_T_partition_id, Any]], pickle.loads(shards))
        await self._receive(shards)
    except P2PConsistencyError as e:
        return error_message(e)
    return {"status": "OK"}
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
    """Return normally only if partition ``i`` is assigned to this worker.

    Otherwise ask the scheduler to restrict task ``key`` to the assigned
    worker and raise ``Reschedule``; an ``"error"`` reply from the scheduler
    is raised as ``RuntimeError``.
    """
    assigned_worker = self._get_assigned_worker(i)
    if assigned_worker == self.local_address:
        return

    result = await self.scheduler.shuffle_restrict_task(
        id=self.id,
        run_id=self.run_id,
        key=key,
        worker=assigned_worker,
    )
    if result["status"] == "error":
        raise RuntimeError(result["message"])
    assert result["status"] == "OK"
    raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
    """Get the address of the worker assigned to the output partition

    ``_ensure_output_worker`` compares the returned address with
    ``self.local_address``.
    """

@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
    """Receive shards belonging to output partitions of this shuffle run

    Called by ``receive`` after a pickled payload, if any, has been unpacked.
    """
def add_partition(
    self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
    """Shard one input partition and queue the shards on the comm buffer.

    Returns this run's ``run_id``. Raises ``RuntimeError`` if the run has
    already been marked as transferred, and whatever ``raise_if_closed``
    raises if the run is closed.
    """
    self.raise_if_closed()
    if self.transferred:
        raise RuntimeError(f"Cannot add more partitions to {self}")
    # Log metrics both in the "execute" and in the "p2p" contexts
    self.validate_data(data)
    with self._capture_metrics("foreground"):
        # Only the sharding itself is timed by the two meters below.
        with (
            context_meter.meter("p2p-shard-partition-noncpu"),
            context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
        ):
            shards = self._shard_partition(data, partition_id)
        # ``_write_to_comm`` is a coroutine; ``sync`` runs it on ``self._loop``.
        sync(self._loop, self._write_to_comm, shards)
    return self.run_id
@abc.abstractmethod
def _shard_partition(
    self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
    """Shard an input partition by the assigned output workers

    ``add_partition`` passes the returned mapping to ``_write_to_comm``.
    """
def get_output_partition(
    self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
    """Return output partition ``partition_id`` for task ``key``.

    The task is first checked against its assigned worker (which may raise
    ``Reschedule``), then the disk buffer is flushed and the partition is
    produced by ``_get_output_partition``. Raises ``RuntimeError`` if the
    run has not been marked as transferred yet.
    """
    self.raise_if_closed()
    sync(self._loop, self._ensure_output_worker, partition_id, key)
    if not self.transferred:
        raise RuntimeError("`get_output_partition` called before barrier task")
    sync(self._loop, self.flush_receive)
    # Log metrics both in the "execute" and in the "p2p" contexts
    with self._capture_metrics("foreground"):
        with context_meter.meter("p2p-get-output-noncpu"):
            with context_meter.meter("p2p-get-output-cpu", func=thread_time):
                output = self._get_output_partition(partition_id, key, **kwargs)
    return output
@abc.abstractmethod
def _get_output_partition(
    self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
    """Get an output partition to the shuffle run

    Called by ``get_output_partition`` with the same arguments.
    """

@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
    """Read shards from disk

    Passed to ``DiskShardsBuffer`` as its ``read`` callback.
    NOTE(review): the meaning of the ``int`` in the result is not visible
    here (presumably a byte count) -- confirm against the buffer.
    """

@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
    """Deserialize shards

    Passed to ``MemoryShardsBuffer`` as its ``deserialize`` callback.
    """

def validate_data(self, data: Any) -> None:
    """Validate payload data before shuffling

    The base implementation does nothing; ``add_partition`` calls it before
    sharding.
    """
def get_worker_plugin() -> ShuffleWorkerPlugin:
    """Return the ``"shuffle"`` plugin of the worker running this task.

    Raises ``RuntimeError`` when not running on a worker, or when the worker
    has no plugin registered under ``"shuffle"``.
    """
    from distributed import get_worker

    try:
        worker = get_worker()
    except ValueError as exc:
        not_on_worker = (
            "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
            "please confirm that you've created a distributed Client and are submitting this computation through it."
        )
        raise RuntimeError(not_on_worker) from exc

    try:
        plugin = worker.plugins["shuffle"]
    except KeyError as exc:
        raise RuntimeError(
            f"The worker {worker.address} does not have a P2P shuffle plugin."
        ) from exc
    return plugin  # type: ignore
# Prefix shared by every barrier task key.
_BARRIER_PREFIX = "shuffle-barrier-"


def barrier_key(shuffle_id: ShuffleId) -> str:
    """Return the barrier task key for ``shuffle_id``."""
    return "".join((_BARRIER_PREFIX, shuffle_id))


def id_from_key(key: Key) -> ShuffleId | None:
    """Return the shuffle id encoded in a barrier key.

    Returns ``None`` when ``key`` is not a string starting with the barrier
    prefix.
    """
    if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    return None
class ShuffleType(Enum):
    """Enumeration with one member per kind of P2P operation."""

    # Each member's value is a plain string.
    DATAFRAME = "DataFrameShuffle"
    ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
    """Immutable description of one run of a shuffle."""

    # Not a constructor argument: each new instance takes the next value of
    # a counter that starts at 1 and is created once, when the class body is
    # evaluated.
    run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
    spec: ShuffleSpec
    # Maps each output partition id to a worker address.
    worker_for: dict[_T_partition_id, str]
    span_id: str | None

    @property
    def id(self) -> ShuffleId:
        """Id of the underlying ``spec``."""
        return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
    """Abstract, immutable specification of a shuffle.

    Subclasses define the output partitions, how workers are picked for
    them, and how a run is created on a worker.
    """

    id: ShuffleId
    disk: bool

    @property
    @abc.abstractmethod
    def output_partitions(self) -> Generator[_T_partition_id]:
        """Output partitions"""

    @abc.abstractmethod
    def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
        """Pick a worker for a partition"""

    def create_new_run(
        self,
        worker_for: dict[_T_partition_id, str],
        span_id: str | None,
    ) -> SchedulerShuffleState:
        """Create the scheduler-side state for a new run of this spec.

        The participating workers start out as the set of values of
        ``worker_for``.
        """
        run_spec = ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id)
        participants = set(worker_for.values())
        return SchedulerShuffleState(
            run_spec=run_spec,
            participating_workers=participants,
        )

    @abc.abstractmethod
    def create_run_on_worker(
        self,
        run_id: int,
        span_id: str | None,
        worker_for: dict[_T_partition_id, str],
        plugin: ShuffleWorkerPlugin,
    ) -> ShuffleRun:
        """Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
    """Scheduler-side record of one shuffle run and its participants."""

    run_spec: ShuffleRunSpec
    # Addresses of the workers taking part in this run.
    participating_workers: set[str]
    # ``None`` until the run is archived; see the ``archived`` property.
    _archived_by: str | None = field(default=None, init=False)
    _failed: bool = False

    @property
    def id(self) -> ShuffleId:
        """Shuffle id, taken from the run spec."""
        return self.run_spec.id

    @property
    def run_id(self) -> int:
        """Run id, taken from the run spec."""
        return self.run_spec.run_id

    @property
    def archived(self) -> bool:
        """Whether ``_archived_by`` has been set."""
        return self._archived_by is not None

    def __str__(self) -> str:
        return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"

    def __hash__(self) -> int:
        # Hash by run id; ``eq=False`` above leaves equality as identity.
        return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P 0e8a9f04ceb745bcef4ba25a6e0f2f79 failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x2-chunks2] (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P d9513d99181485ac4237b30c4b28d743 failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: 'd9513d99181485ac4237b30c4b28d743'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and…plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:36599', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(50, 10), dtype=float64, chunksize=(25, 10), chunktype=numpy.ndarray>
chunks = (None, (5, 5))
ws = (<Worker 'tcp://127.0.0.1:44699', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:39473', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.12/site-packages/dask/dataframe/__init__.py'>
@pytest.mark.parametrize(
"x, chunks",
[
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
(da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
(da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
],
)
@gen_cluster(client=True)
async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
"""
pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
y = dd.from_array(x).values
result = y.rechunk(chunks, method="p2p")
expected = x.rechunk(chunks, method="p2p")
assert_chunks_match(result.chunks, expected.chunks)
> assert_eq(await c.compute(result), await c.compute(expected))
distributed/shuffle/tests/test_rechunk.py:757:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P d9513d99181485ac4237b30c4b28d743 failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x3-chunks3] (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 871d68ad5c66483062239bd5af22c9b8 failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: '871d68ad5c66483062239bd5af22c9b8'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
    """Forward context_meter metrics to ``digest_metric`` while the context is active.

    Metrics are recorded as
    {('p2p', <span id>, 'foreground|background...', label, unit): value}

    **Note 1:** A metric that is not logged by a background task
    (where='foreground') is additionally recorded under
    {('execute', <span id>, <task prefix>, label, unit): value}
    The duplication is intentional: it gives a holistic view of the whole
    shuffle process.

    **Note 2:** Metrics are written to Worker.digests immediately.
    Storing them temporarily on the ShuffleRun would lose whatever is recorded
    between the heartbeat and the deletion of the ShuffleRun object at the end
    of a run.
    """
    offloadable = "background" in where

    def _forward(label: Hashable, value: float, unit: str) -> None:
        parts = label if isinstance(label, tuple) else (label,)
        head = parts[0]
        # Drop the "p2p-" prefix: the metric name already starts with "p2p"
        if isinstance(head, str) and head.startswith("p2p-"):
            parts = (head[len("p2p-") :], *parts[1:])
        self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

    with context_meter.add_callback(_forward, allow_offload=offloadable):
        yield
async def barrier(self, run_ids: Sequence[int]) -> int:
    """Report the barrier to the scheduler and return this run's id.

    ``consistent`` is sent as True only when every entry of ``run_ids`` equals
    the id of this run.
    """
    self.raise_if_closed()
    own_run_id = self.run_id
    mismatch = any(other != own_run_id for other in run_ids)
    # TODO: Consider broadcast pinging once when the shuffle starts to warm
    # up the comm pool on scheduler side
    await self.scheduler.shuffle_barrier(
        id=self.id, run_id=self.run_id, consistent=not mismatch
    )
    return self.run_id
async def _send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Deliver ``shards`` to ``address`` through its ``shuffle_receive`` handler.

    Raises whatever ``raise_if_closed`` raises when this run is already closed.
    """
    self.raise_if_closed()
    peer = self.rpc(address)
    response = await peer.shuffle_receive(
        data=to_serialize(shards),
        shuffle_id=self.id,
        run_id=self.run_id,
    )
    return response
async def send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to ``address``, retrying according to the RETRY_* settings.

    When the mean shard size is below 65536 the shards are not sent as
    individual buffers over the tcp comms; they are merged into one opaque
    pickled bytes blob, sent all at once and unpickled on the other side.
    Performance tests informing the size threshold:
    https://github.com/dask/distributed/pull/8318
    """
    payload: list | bytes = (
        pickle.dumps(shards) if _mean_shard_size(shards) < 65536 else shards
    )

    def _attempt() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
        # A fresh coroutine is needed for every retry
        return self._send(address, payload)

    return await retry(
        _attempt,
        count=self.RETRY_COUNT,
        delay_min=self.RETRY_DELAY_MIN,
        delay_max=self.RETRY_DELAY_MAX,
    )
async def offload(
    self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    """Run ``func(*args, **kwargs)`` on ``self.executor``, metered as "offload"."""
    self.raise_if_closed()
    with context_meter.meter("offload"):
        outcome = await run_in_executor_with_context(
            self.executor, func, *args, **kwargs
        )
    return outcome
def heartbeat(self) -> dict[str, Any]:
    """Collect buffer statistics, bytes received and the start time of this run."""
    comm_stats = self._comm_buffer.heartbeat()
    comm_stats["read"] = self.total_recvd
    disk_stats = self._disk_buffer.heartbeat()
    return {"disk": disk_stats, "comm": comm_stats, "start": self.start_time}
async def _write_to_comm(
    self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
    """Write ``data`` to the comm buffer; raises if this run is closed."""
    self.raise_if_closed()
    buffer = self._comm_buffer
    await buffer.write(data)

async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
    """Write ``data`` to the disk buffer, keyed by the "_"-joined index."""
    self.raise_if_closed()
    keyed = {}
    for index, shard in data.items():
        keyed["_".join(map(str, index))] = shard
    await self._disk_buffer.write(keyed)
def raise_if_closed(self) -> None:
    """Raise when this run is closed.

    The stored exception is re-raised if there is one; otherwise a
    ShuffleClosedError is raised. Nothing happens while the run is open.
    """
    if not self.closed:
        return
    if self._exception:
        raise self._exception
    raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
# Mark the transfer phase as finished, flush the comm buffer and surface any
# exception the comm buffer reports.
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
# Remember the failure: raise_if_closed() re-raises it once the run is closed.
self._exception = e
raise
async def _flush_comm(self) -> None:
# Flush the comm buffer; raises first if this run is closed.
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
# Flush the disk buffer; raises first if this run is closed.
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
    """Close the comm and disk buffers and mark this run as closed.

    A call made while another close is in progress (or already finished) only
    waits for ``_closed_event``.

    The disk buffer is closed and ``_closed_event`` is set even when closing a
    buffer raises, so that concurrent callers waiting on the event are never
    left hanging. The exception still propagates to the caller.
    """
    if self.closed:  # pragma: no cover
        await self._closed_event.wait()
        return

    self.closed = True
    try:
        try:
            await self._comm_buffer.close()
        finally:
            # Release the disk buffer regardless of how the comm buffer fared
            await self._disk_buffer.close()
    finally:
        self._closed_event.set()
def fail(self, exception: Exception) -> None:
    """Record ``exception`` as the failure of this run unless it is already closed."""
    if self.closed:
        return
    self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
    """Read from the disk buffer the entry stored under the "_"-joined ``id``."""
    self.raise_if_closed()
    shard_key = "_".join(map(str, id))
    return self._disk_buffer.read(shard_key)
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
# Accept shards sent by a peer (see send()) and pass them to _receive().
# Returns {"status": "OK"} on success; a P2PConsistencyError is converted
# into an error message instead of being raised.
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
# NOTE(review): pickle.loads runs on bytes received over the comm; this
# presumably relies on peers being trusted cluster members — confirm.
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
    """Ensure output partition ``i`` is produced on its assigned worker.

    Returns normally when the assigned worker is this worker. Otherwise the
    scheduler is asked to restrict task ``key`` to the assigned worker and
    ``Reschedule`` is raised; an error reply from the scheduler is raised as
    ``RuntimeError``.
    """
    target = self._get_assigned_worker(i)
    if target == self.local_address:
        return

    reply = await self.scheduler.shuffle_restrict_task(
        id=self.id, run_id=self.run_id, key=key, worker=target
    )
    if reply["status"] == "error":
        raise RuntimeError(reply["message"])
    assert reply["status"] == "OK"
    raise Reschedule()
# Abstract hook: used by _ensure_output_worker() to decide whether output
# partition ``i`` belongs on this worker.
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
# Abstract hook: called by receive() with the (already unpickled) shards.
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and…rker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
# Currently active shuffle state per shuffle id
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
# Heartbeat payloads, keyed by shuffle id and then by worker address
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
# Register the shuffle handlers on the scheduler, initialise the bookkeeping
# containers and add this object as the scheduler plugin named "shuffle".
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
    """Register a pickled ShuffleWorkerPlugin on the workers under the name "shuffle"."""
    pickled_plugin = dumps(ShuffleWorkerPlugin())
    await self.scheduler.register_worker_plugin(
        None, pickled_plugin, name="shuffle", idempotent=False
    )
def shuffle_ids(self) -> set[ShuffleId]:
    """Return a new set holding the ids of all active shuffles."""
    return {*self.active_shuffles}
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
    """Handle the ``shuffle_barrier`` request for shuffle ``id``.

    When ``consistent`` is False the shuffle is restarted. Otherwise
    ``shuffle_inputs_done`` is broadcast to all participating workers, and
    re-broadcast to the workers whose answer was an ``OSError`` until every
    worker has answered successfully.

    Raises
    ------
    ValueError
        If ``run_id`` does not match the active run of the shuffle.
    P2PIllegalStateError
        If a worker to retry is unknown to the scheduler while the shuffle
        has not been archived.
    RuntimeError
        If a worker answers with an unexpected error, a worker to retry has
        left, or three retry rounds did not reduce the number of failing
        workers.
    """
    shuffle = self.active_shuffles[id]
    if shuffle.run_id != run_id:
        raise ValueError(f"{run_id=} does not match {shuffle}")
    if not consistent:
        logger.warning(
            "Shuffle %s restarted due to data inconsistency during barrier",
            shuffle.id,
        )
        return self._restart_shuffle(
            shuffle.id,
            self.scheduler,
            stimulus_id=f"p2p-barrier-inconsistent-{time()}",
        )
    msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
    workers = list(shuffle.participating_workers)
    no_progress = 0
    while workers:
        res = await self.scheduler.broadcast(
            msg=msg,
            workers=workers,
            on_error="return",
        )
        before = len(workers)
        # Only workers that answered with an OSError are retried. Any other
        # non-None answer aborts the barrier, so no second filtering pass over
        # ``res`` is needed.
        workers = []
        for w, r in res.items():
            if r is None:
                continue
            if isinstance(r, OSError):
                workers.append(w)
            else:
                raise RuntimeError(
                    f"Unexpected error encountered during P2P barrier: {r!r}"
                )
        if workers:
            logger.warning(
                "Failure during broadcast of %s, retrying.",
                shuffle.id,
            )
            if any(w not in self.scheduler.workers for w in workers):
                if not shuffle.archived:
                    # If the shuffle is not yet archived, this could mean that the barrier task fails
                    # before the P2P restarting mechanism can kick in.
                    raise P2PIllegalStateError(
                        "Expected shuffle to be archived if participating worker is not known by scheduler"
                    )
                raise RuntimeError(
                    f"Worker {workers} left during shuffle {shuffle}"
                )
            await asyncio.sleep(0.1)
            if len(workers) == before:
                no_progress += 1
                if no_progress >= 3:
                    raise RuntimeError(
                        f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
                    )
def restrict_task(
    self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
    """Restrict task ``key`` to ``worker`` on behalf of run ``run_id`` of shuffle ``id``.

    Returns {"status": "OK"} on success. A ``run_id`` that differs from the
    active run is reported as an error message rather than raised.
    """
    try:
        shuffle = self.active_shuffles[id]
        active_run_id = shuffle.run_id
        if active_run_id > run_id:
            raise P2PConsistencyError(
                f"Request stale, expected {run_id=} for {shuffle}"
            )
        if active_run_id < run_id:
            raise P2PConsistencyError(
                f"Request invalid, expected {run_id=} for {shuffle}"
            )
        task_state = self.scheduler.tasks[key]
        self._set_restriction(task_state, worker)
        return {"status": "OK"}
    except P2PConsistencyError as e:
        return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
    """Merge the per-shuffle heartbeat payloads reported by worker ``ws``.

    ``data`` maps shuffle ids to payload dicts. Payloads of shuffles that are
    not active are ignored.
    """
    # Test membership on the dict itself instead of materialising a new set
    # of ids for every entry of ``data``.
    active = self.active_shuffles
    for shuffle_id, d in data.items():
        if shuffle_id in active:
            self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
# Return the run spec of shuffle ``id`` (wrapped in ToPickle) for ``worker``.
# Consistency problems are reported as an error message instead of raised.
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
# A KeyError is translated into a consistency error so that it is
# handled by the outer except clause.
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
    """Add ``worker`` to the participants of shuffle ``id`` and return its run spec.

    Raises ``P2PConsistencyError`` for a worker unknown to the scheduler and
    ``KeyError`` when no shuffle with ``id`` is active.
    """
    if worker in self.scheduler.workers:
        state = self.active_shuffles[id]
        state.participating_workers.add(worker)
        return state.run_spec
    # This should never happen
    raise P2PConsistencyError(
        f"Scheduler is unaware of this worker {worker!r}"
    )  # pragma: nocover
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:33749', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(100, 10), dtype=float64, chunksize=(5, 10), chunktype=numpy.ndarray>
chunks = (None, 5)
ws = (<Worker 'tcp://127.0.0.1:34851', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:33885', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.12/site-packages/dask/dataframe/__init__.py'>
# Each case pairs an input array ``x`` with the ``chunks`` argument passed to
# rechunk(); chunks is given as a tuple, a dict, or a tuple with explicit sizes.
@pytest.mark.parametrize(
"x, chunks",
[
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
(da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
(da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
],
)
@gen_cluster(client=True)
async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
"""
pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
# ``y`` is ``x`` round-tripped through a dask dataframe
y = dd.from_array(x).values
result = y.rechunk(chunks, method="p2p")
expected = x.rechunk(chunks, method="p2p")
# Both rechunks must agree on chunking and on the computed values
assert_chunks_match(result.chunks, expected.chunks)
> assert_eq(await c.compute(result), await c.compute(expected))
distributed/shuffle/tests/test_rechunk.py:757:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
# OK message that additionally carries a run spec, either as the object itself
# or wrapped in ToPickle.
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
# State of one shuffle run as seen by this object: identifiers, the comm and
# disk (or in-memory) shard buffers, and retry settings for sending shards.
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
# Failure recorded by fail() / inputs_done(); re-raised by raise_if_closed()
_exception: Exception | None
# Set at the end of close(); awaited by close() calls made once closed is True
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
# ``disk`` selects between a disk-backed and an in-memory shard buffer
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
# Retry settings used by send(); the delays are parsed with "s" as default unit
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
    """Return a detailed representation with id, run id, address and state flags."""
    cls_name = self.__class__.__name__
    return (
        f"<{cls_name}: id={self.id!r}, run_id={self.run_id!r}, "
        f"local_address={self.local_address!r}, closed={self.closed!r}, "
        f"transferred={self.transferred!r}>"
    )

def __str__(self) -> str:
    """Return a compact ``Class<id[run_id]> on address`` label."""
    label = f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
    return f"{label} on {self.local_address}"

def __hash__(self) -> int:
    """Use the run id as the hash value."""
    run_id = self.run_id
    return run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
    """Forward context_meter metrics to ``digest_metric`` while the context is active.

    Metrics are recorded as
    {('p2p', <span id>, 'foreground|background...', label, unit): value}

    **Note 1:** A metric that is not logged by a background task
    (where='foreground') is additionally recorded under
    {('execute', <span id>, <task prefix>, label, unit): value}
    The duplication is intentional: it gives a holistic view of the whole
    shuffle process.

    **Note 2:** Metrics are written to Worker.digests immediately.
    Storing them temporarily on the ShuffleRun would lose whatever is recorded
    between the heartbeat and the deletion of the ShuffleRun object at the end
    of a run.
    """
    offloadable = "background" in where

    def _forward(label: Hashable, value: float, unit: str) -> None:
        parts = label if isinstance(label, tuple) else (label,)
        head = parts[0]
        # Drop the "p2p-" prefix: the metric name already starts with "p2p"
        if isinstance(head, str) and head.startswith("p2p-"):
            parts = (head[len("p2p-") :], *parts[1:])
        self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

    with context_meter.add_callback(_forward, allow_offload=offloadable):
        yield
async def barrier(self, run_ids: Sequence[int]) -> int:
    """Report the barrier to the scheduler and return this run's id.

    ``consistent`` is sent as True only when every entry of ``run_ids`` equals
    the id of this run.
    """
    self.raise_if_closed()
    own_run_id = self.run_id
    mismatch = any(other != own_run_id for other in run_ids)
    # TODO: Consider broadcast pinging once when the shuffle starts to warm
    # up the comm pool on scheduler side
    await self.scheduler.shuffle_barrier(
        id=self.id, run_id=self.run_id, consistent=not mismatch
    )
    return self.run_id
async def _send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Deliver ``shards`` to ``address`` through its ``shuffle_receive`` handler.

    Raises whatever ``raise_if_closed`` raises when this run is already closed.
    """
    self.raise_if_closed()
    peer = self.rpc(address)
    response = await peer.shuffle_receive(
        data=to_serialize(shards),
        shuffle_id=self.id,
        run_id=self.run_id,
    )
    return response
async def send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to ``address``, retrying according to the RETRY_* settings.

    When the mean shard size is below 65536 the shards are not sent as
    individual buffers over the tcp comms; they are merged into one opaque
    pickled bytes blob, sent all at once and unpickled on the other side.
    Performance tests informing the size threshold:
    https://github.com/dask/distributed/pull/8318
    """
    payload: list | bytes = (
        pickle.dumps(shards) if _mean_shard_size(shards) < 65536 else shards
    )

    def _attempt() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
        # A fresh coroutine is needed for every retry
        return self._send(address, payload)

    return await retry(
        _attempt,
        count=self.RETRY_COUNT,
        delay_min=self.RETRY_DELAY_MIN,
        delay_max=self.RETRY_DELAY_MAX,
    )
async def offload(
    self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    """Run ``func(*args, **kwargs)`` on ``self.executor``, metered as "offload"."""
    self.raise_if_closed()
    with context_meter.meter("offload"):
        outcome = await run_in_executor_with_context(
            self.executor, func, *args, **kwargs
        )
    return outcome
def heartbeat(self) -> dict[str, Any]:
    """Collect buffer statistics, bytes received and the start time of this run."""
    comm_stats = self._comm_buffer.heartbeat()
    comm_stats["read"] = self.total_recvd
    disk_stats = self._disk_buffer.heartbeat()
    return {"disk": disk_stats, "comm": comm_stats, "start": self.start_time}
async def _write_to_comm(
    self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
    """Write ``data`` to the comm buffer; raises if this run is closed."""
    self.raise_if_closed()
    buffer = self._comm_buffer
    await buffer.write(data)

async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
    """Write ``data`` to the disk buffer, keyed by the "_"-joined index."""
    self.raise_if_closed()
    keyed = {}
    for index, shard in data.items():
        keyed["_".join(map(str, index))] = shard
    await self._disk_buffer.write(keyed)
def raise_if_closed(self) -> None:
    """Raise when this run is closed.

    The stored exception is re-raised if there is one; otherwise a
    ShuffleClosedError is raised. Nothing happens while the run is open.
    """
    if not self.closed:
        return
    if self._exception:
        raise self._exception
    raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
# Mark the transfer phase as finished, flush the comm buffer and surface any
# exception the comm buffer reports.
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
# Remember the failure: raise_if_closed() re-raises it once the run is closed.
self._exception = e
raise
async def _flush_comm(self) -> None:
# Flush the comm buffer; raises first if this run is closed.
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
# Flush the disk buffer; raises first if this run is closed.
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
    """Close the comm and disk buffers and mark this run as closed.

    A call made while another close is in progress (or already finished) only
    waits for ``_closed_event``.

    The disk buffer is closed and ``_closed_event`` is set even when closing a
    buffer raises, so that concurrent callers waiting on the event are never
    left hanging. The exception still propagates to the caller.
    """
    if self.closed:  # pragma: no cover
        await self._closed_event.wait()
        return

    self.closed = True
    try:
        try:
            await self._comm_buffer.close()
        finally:
            # Release the disk buffer regardless of how the comm buffer fared
            await self._disk_buffer.close()
    finally:
        self._closed_event.set()
def fail(self, exception: Exception) -> None:
    """Record ``exception`` as the failure of this run unless it is already closed."""
    if self.closed:
        return
    self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
    """Read from the disk buffer the entry stored under the "_"-joined ``id``."""
    self.raise_if_closed()
    shard_key = "_".join(map(str, id))
    return self._disk_buffer.read(shard_key)
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
# Accept shards sent by a peer (see send()) and pass them to _receive().
# Returns {"status": "OK"} on success; a P2PConsistencyError is converted
# into an error message instead of being raised.
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
# NOTE(review): pickle.loads runs on bytes received over the comm; this
# presumably relies on peers being trusted cluster members — confirm.
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
    """Ensure output partition ``i`` is produced on its assigned worker.

    Returns normally when the assigned worker is this worker. Otherwise the
    scheduler is asked to restrict task ``key`` to the assigned worker and
    ``Reschedule`` is raised; an error reply from the scheduler is raised as
    ``RuntimeError``.
    """
    target = self._get_assigned_worker(i)
    if target == self.local_address:
        return

    reply = await self.scheduler.shuffle_restrict_task(
        id=self.id, run_id=self.run_id, key=key, worker=target
    )
    if reply["status"] == "error":
        raise RuntimeError(reply["message"])
    assert reply["status"] == "OK"
    raise Reschedule()
# Abstract hook: used by _ensure_output_worker() to decide whether output
# partition ``i`` belongs on this worker.
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
# Abstract hook: called by receive() with the (already unpickled) shards.
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
# Validate and shard one input partition, then hand the shards to
# _write_to_comm through sync() on self._loop. Returns this run's id.
# Raises RuntimeError once the transfer phase has been marked as finished.
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
# Abstract hook: called by add_partition(); its result is passed on to
# _write_to_comm().
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
# Return the output partition ``partition_id`` via _get_output_partition().
# _ensure_output_worker() runs first and raises Reschedule when the partition
# is assigned to another worker. Raises RuntimeError when called before the
# transfer phase has been marked as finished.
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
# Flush the disk buffer before reading the output
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
# Abstract hook: called by get_output_partition() inside the metrics context.
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
# Abstract hook: passed as ``read`` to DiskShardsBuffer in __init__.
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
# Abstract hook: passed as ``deserialize`` to MemoryShardsBuffer in __init__.
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
# Called by add_partition() before sharding; this implementation has no body
# beyond its docstring.
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
    """Return the plugin registered as "shuffle" on the current worker.

    Raises ``RuntimeError`` when not running on a worker or when the worker
    has no such plugin.
    """
    from distributed import get_worker

    try:
        worker = get_worker()
    except ValueError as e:
        raise RuntimeError(
            "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
            "please confirm that you've created a distributed Client and are submitting this computation through it."
        ) from e

    try:
        plugin = worker.plugins["shuffle"]
    except KeyError as e:
        raise RuntimeError(
            f"The worker {worker.address} does not have a P2P shuffle plugin."
        ) from e
    return plugin  # type: ignore
# Prefix shared by barrier_key() and id_from_key()
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
# Build the barrier key by prepending the prefix to the shuffle id;
# id_from_key() is the inverse.
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
    """Extract the shuffle id from a barrier key.

    Returns None when ``key`` is not a string starting with the barrier prefix.
    """
    if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    return None
# Enumeration with one member per kind of P2P operation named here
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
# Immutable description of one run of a shuffle.
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
# Not an __init__ argument: every new instance takes the next value of a
# single shared counter that starts at 1.
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
# Worker address per output partition id
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
# Shortcut for the id of the wrapped spec
return self.spec.id
# Immutable, abstract description of a shuffle; subclasses supply the
# partitioning and the worker-side run construction.
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
# Passed as ``disk`` to ShuffleRun — presumably selects the disk-backed
# buffer; confirm against create_run_on_worker implementations.
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
# Build the scheduler-side state for a new run; the participating workers
# start out as the set of workers that appear in ``worker_for``.
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
# Scheduler-side state of one shuffle run. ``eq=False`` keeps the default
# identity-based equality; hashing is by run id (see __hash__).
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
# None until the state is archived; see the ``archived`` property
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
# True as soon as ``_archived_by`` has been set to a non-None value
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
# Context manager translating errors raised inside its body:
# - ShuffleClosedError is replaced by Reschedule
# - P2PConsistencyError and P2POutOfDiskError propagate unchanged
# - any other Exception is wrapped in a RuntimeError naming the shuffle id,
#   with the original exception chained as its cause
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P 871d68ad5c66483062239bd5af22c9b8 failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x4-chunks4] (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 871d68ad5c66483062239bd5af22c9b8 failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
# Currently active shuffle state per shuffle id
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
# Heartbeat payloads, keyed by shuffle id and then by worker address
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
# Register the shuffle handlers on the scheduler, initialise the bookkeeping
# containers and add this object as the scheduler plugin named "shuffle".
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: '871d68ad5c66483062239bd5af22c9b8'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and…_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:38033', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(100, 10), dtype=float64, chunksize=(5, 10), chunktype=numpy.ndarray>
chunks = {1: 5}
ws = (<Worker 'tcp://127.0.0.1:44575', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:36277', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.12/site-packages/dask/dataframe/__init__.py'>
@pytest.mark.parametrize(
"x, chunks",
[
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
(da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
(da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
],
)
@gen_cluster(client=True)
async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
"""
pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
y = dd.from_array(x).values
result = y.rechunk(chunks, method="p2p")
expected = x.rechunk(chunks, method="p2p")
assert_chunks_match(result.chunks, expected.chunks)
> assert_eq(await c.compute(result), await c.compute(expected))
distributed/shuffle/tests/test_rechunk.py:757:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P 871d68ad5c66483062239bd5af22c9b8 failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x5-chunks5] (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 871d68ad5c66483062239bd5af22c9b8 failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: '871d68ad5c66483062239bd5af22c9b8'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to ``address``, retrying with the configured RETRY_* settings.

    When the mean shard size is below 65536, the shards are pickled into a
    single opaque bytes blob instead of being sent as individual buffers over
    the tcp comms; the receiving side unpickles it.
    """
    # Performance tests informing the size threshold:
    # https://github.com/dask/distributed/pull/8318
    payload: list | bytes = (
        pickle.dumps(shards) if _mean_shard_size(shards) < 65536 else shards
    )
    return await retry(
        lambda: self._send(address, payload),
        count=self.RETRY_COUNT,
        delay_min=self.RETRY_DELAY_MIN,
        delay_max=self.RETRY_DELAY_MAX,
    )
async def offload(
    self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    """Run ``func(*args, **kwargs)`` on ``self.executor`` under the "offload" meter."""
    self.raise_if_closed()
    with context_meter.meter("offload"):
        outcome = await run_in_executor_with_context(
            self.executor, func, *args, **kwargs
        )
    return outcome
def heartbeat(self) -> dict[str, Any]:
    """Collect heartbeat data from both buffers plus the run's start time.

    The comm buffer's heartbeat is extended with a "read" entry holding
    ``self.total_recvd``.
    """
    comm = self._comm_buffer.heartbeat()
    comm["read"] = self.total_recvd
    disk = self._disk_buffer.heartbeat()
    return {"disk": disk, "comm": comm, "start": self.start_time}
async def _write_to_comm(
    self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
    """Write ``data`` to the comm buffer; raises first if the run is closed."""
    self.raise_if_closed()
    comm_buffer = self._comm_buffer
    await comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
    """Write ``data`` to the disk buffer, keyed by the "_"-joined index."""
    self.raise_if_closed()
    flattened = {"_".join(map(str, index)): shards for index, shards in data.items()}
    await self._disk_buffer.write(flattened)
def raise_if_closed(self) -> None:
    """Raise if this run is closed.

    Re-raises the stored exception when there is one, otherwise raises
    ``ShuffleClosedError``. Does nothing while the run is open.
    """
    if not self.closed:
        return
    if self._exception:
        raise self._exception
    raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
    """Mark the run as transferred and flush the comm buffer.

    An exception reported by the comm buffer afterwards is stored on the run
    and re-raised.
    """
    self.raise_if_closed()
    self.transferred = True
    await self._flush_comm()
    try:
        self._comm_buffer.raise_on_exception()
    except Exception as comm_error:
        self._exception = comm_error
        raise
async def _flush_comm(self) -> None:
    """Flush the comm buffer; raises first if the run is closed."""
    self.raise_if_closed()
    comm_buffer = self._comm_buffer
    await comm_buffer.flush()
async def flush_receive(self) -> None:
    """Flush the disk buffer; raises first if the run is closed."""
    self.raise_if_closed()
    disk_buffer = self._disk_buffer
    await disk_buffer.flush()
async def close(self) -> None:
    """Close both buffers and signal completion through ``_closed_event``.

    A second caller arriving while (or after) the run is being closed only
    waits for ``_closed_event`` and returns.

    The disk buffer is closed and the event is set even if closing the comm
    buffer raises. Otherwise a failure there would leave the disk buffer open
    and every later ``close()`` caller waiting on the event forever. The
    exception still propagates to the first caller.
    """
    if self.closed:  # pragma: no cover
        await self._closed_event.wait()
        return
    self.closed = True
    try:
        try:
            await self._comm_buffer.close()
        finally:
            await self._disk_buffer.close()
    finally:
        self._closed_event.set()
def fail(self, exception: Exception) -> None:
    """Store ``exception`` on the run unless the run is already closed."""
    if self.closed:
        return
    self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
    """Read from the disk buffer the entry named by the "_"-joined index ``id``."""
    self.raise_if_closed()
    name = "_".join(map(str, id))
    return self._disk_buffer.read(name)
async def receive(
    self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Hand incoming shards to ``_receive`` and report the outcome as a message.

    Returns ``{"status": "OK"}`` on success. A ``P2PConsistencyError`` is
    converted into an error message instead of being raised.
    """
    try:
        if isinstance(data, bytes):
            # Unpack opaque blob. See send()
            # NOTE(review): this unpickles bytes received from a peer;
            # presumably peers are trusted cluster members -- confirm.
            unpacked = pickle.loads(data)
            data = cast(list[tuple[_T_partition_id, Any]], unpacked)
        await self._receive(data)
    except P2PConsistencyError as inconsistency:
        return error_message(inconsistency)
    return {"status": "OK"}
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
    """Make sure output partition ``i`` is produced on its assigned worker.

    Returns normally when the assigned worker is this worker. Otherwise asks
    the scheduler to restrict task ``key`` to the assigned worker and raises
    ``Reschedule``; a scheduler error response is raised as ``RuntimeError``.
    """
    target = self._get_assigned_worker(i)
    if target == self.local_address:
        return
    result = await self.scheduler.shuffle_restrict_task(
        id=self.id, run_id=self.run_id, key=key, worker=target
    )
    if result["status"] == "error":
        raise RuntimeError(result["message"])
    assert result["status"] == "OK"
    raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
    """Return the address of the worker assigned to output partition ``i``."""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
    """Accept shards that belong to output partitions of this shuffle run."""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and…plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
# Scheduler this plugin is attached to (set in __init__).
scheduler: Scheduler
# Current state per shuffle id; barrier(), restrict_task() and _get() look
# shuffles up here.
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
# Heartbeat payloads, updated per shuffle id and worker address in heartbeat().
heartbeats: defaultdict[ShuffleId, dict]
# The three attributes below are initialised empty in __init__; their use is
# not visible in this excerpt.
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
# Attach the plugin to ``scheduler``: register its RPC handlers, add it as a
# scheduler plugin named "shuffle", and initialise empty bookkeeping state.
self.scheduler = scheduler
# Expose the plugin's entry points under the "shuffle_*" handler names.
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
# Two-level defaultdict: heartbeat() indexes it by shuffle id, then by
# worker address.
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
# NOTE(review): the plugin is registered before the attributes below are
# assigned -- confirm add_plugin does not call back into code that reads them.
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
    """Register a pickled ``ShuffleWorkerPlugin`` under the name "shuffle".

    The ``scheduler`` argument is not used; registration goes through
    ``self.scheduler``.
    """
    pickled_plugin = dumps(ShuffleWorkerPlugin())
    await self.scheduler.register_worker_plugin(
        None,
        pickled_plugin,
        name="shuffle",
        idempotent=False,
    )
def shuffle_ids(self) -> set[ShuffleId]:
    """Return a new set holding the ids of all active shuffles."""
    return set(self.active_shuffles.keys())
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
    """Broadcast ``shuffle_inputs_done`` to all workers participating in a shuffle.

    Parameters
    ----------
    id
        Id of the shuffle; must be a key of ``active_shuffles``.
    run_id
        Run the caller believes is active. A mismatch raises ``ValueError``.
    consistent
        Whether the caller saw consistent run ids. If false, the shuffle is
        restarted instead of completing the barrier.

    Raises
    ------
    ValueError
        If ``run_id`` does not match the active run.
    P2PIllegalStateError
        If a worker that still needs a retry is unknown to the scheduler
        while the shuffle is not archived.
    RuntimeError
        If a worker returned an error other than ``OSError``, a worker left
        during the shuffle, or the broadcast failed to make progress three
        times.
    """
    shuffle = self.active_shuffles[id]
    if shuffle.run_id != run_id:
        raise ValueError(f"{run_id=} does not match {shuffle}")
    if not consistent:
        logger.warning(
            "Shuffle %s restarted due to data inconsistency during barrier",
            shuffle.id,
        )
        return self._restart_shuffle(
            shuffle.id,
            self.scheduler,
            stimulus_id=f"p2p-barrier-inconsistent-{time()}",
        )
    msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
    workers = list(shuffle.participating_workers)
    no_progress = 0
    while workers:
        res = await self.scheduler.broadcast(
            msg=msg,
            workers=workers,
            on_error="return",
        )
        before = len(workers)
        # Retry only the workers whose broadcast failed with an OSError; any
        # other returned error aborts the barrier. The original code
        # recomputed ``workers`` from ``res`` right after this loop, which
        # produced the same list, so that dead assignment is removed.
        workers = []
        for w, r in res.items():
            if r is None:
                continue
            if isinstance(r, OSError):
                workers.append(w)
            else:
                raise RuntimeError(
                    f"Unexpected error encountered during P2P barrier: {r!r}"
                )
        if workers:
            logger.warning(
                "Failure during broadcast of %s, retrying.",
                shuffle.id,
            )
            if any(w not in self.scheduler.workers for w in workers):
                if not shuffle.archived:
                    # If the shuffle is not yet archived, this could mean that the barrier task fails
                    # before the P2P restarting mechanism can kick in.
                    raise P2PIllegalStateError(
                        "Expected shuffle to be archived if participating worker is not known by scheduler"
                    )
                raise RuntimeError(
                    f"Worker {workers} left during shuffle {shuffle}"
                )
            await asyncio.sleep(0.1)
            if len(workers) == before:
                no_progress += 1
                if no_progress >= 3:
                    raise RuntimeError(
                        f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
                    )
def restrict_task(
    self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
    """Restrict task ``key`` to ``worker`` if ``run_id`` matches the active run.

    Returns ``{"status": "OK"}`` on success. A run id mismatch in either
    direction is reported as an error message built from a
    ``P2PConsistencyError``.
    """
    try:
        shuffle = self.active_shuffles[id]
        active_run_id = shuffle.run_id
        if active_run_id > run_id:
            raise P2PConsistencyError(
                f"Request stale, expected {run_id=} for {shuffle}"
            )
        if active_run_id < run_id:
            raise P2PConsistencyError(
                f"Request invalid, expected {run_id=} for {shuffle}"
            )
        self._set_restriction(self.scheduler.tasks[key], worker)
        return {"status": "OK"}
    except P2PConsistencyError as inconsistency:
        return error_message(inconsistency)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
    """Merge a worker's per-shuffle heartbeat payloads into ``self.heartbeats``.

    Parameters
    ----------
    ws
        State of the reporting worker; only ``ws.address`` is read.
    data
        Mapping of shuffle id to that shuffle's heartbeat payload. Entries
        for shuffles that are not active are ignored.
    """
    # Test membership on the dict directly instead of building a fresh set
    # through ``shuffle_ids()`` for every entry.
    active = self.active_shuffles
    for shuffle_id, d in data.items():
        if shuffle_id in active:
            self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
    """Return the run spec of active shuffle ``id`` wrapped in an OK message.

    A ``KeyError`` during lookup is turned into a ``P2PConsistencyError``;
    any ``P2PConsistencyError`` is returned as an error message.
    """
    try:
        try:
            spec = self._get(id, worker)
            return {"status": "OK", "run_spec": ToPickle(spec)}
        except KeyError as missing:
            raise P2PConsistencyError(
                f"No active shuffle with {id=!r} found"
            ) from missing
    except P2PConsistencyError as inconsistency:
        return error_message(inconsistency)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
    """Record ``worker`` as a participant of shuffle ``id`` and return its run spec.

    Raises ``P2PConsistencyError`` if the scheduler does not know ``worker``
    and ``KeyError`` if ``id`` is not an active shuffle.
    """
    known_workers = self.scheduler.workers
    if worker not in known_workers:
        # This should never happen
        raise P2PConsistencyError(
            f"Scheduler is unaware of this worker {worker!r}"
        )  # pragma: nocover
    shuffle_state = self.active_shuffles[id]
    shuffle_state.participating_workers.add(worker)
    return shuffle_state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:40915', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(100, 10), dtype=float64, chunksize=(5, 10), chunktype=numpy.ndarray>
chunks = (None, (5, 5))
ws = (<Worker 'tcp://127.0.0.1:45271', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:32875', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.12/site-packages/dask/dataframe/__init__.py'>
@pytest.mark.parametrize(
"x, chunks",
[
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
(da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
(da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
],
)
@gen_cluster(client=True)
async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
"""
pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
y = dd.from_array(x).values
result = y.rechunk(chunks, method="p2p")
expected = x.rechunk(chunks, method="p2p")
assert_chunks_match(result.chunks, expected.chunks)
> assert_eq(await c.compute(result), await c.compute(expected))
distributed/shuffle/tests/test_rechunk.py:757:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
# OK message that additionally carries a shuffle run spec, either directly or
# wrapped in ToPickle.
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
# Worker-side state of a single shuffle run, generic over the partition id
# type and the partition payload type.
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
# False after __init__; set to True at the start of close().
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
# Reported by heartbeat() under the comm buffer's "read" key.
total_recvd: int
start_time: float
# Re-raised by raise_if_closed() once the run is closed.
_exception: Exception | None
# Set at the end of close(); a second close() call waits on it.
_closed_event: asyncio.Event
_loop: IOLoop
# Retry settings read from dask config in __init__ and passed to retry() in send().
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
# Store the identifying fields and collaborators, build the disk (or
# in-memory) and comm buffers, then read the retry settings from dask config.
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
# ``disk`` selects between a DiskShardsBuffer under ``directory`` and a
# MemoryShardsBuffer.
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
# Message size limit and concurrency of the comm buffer come from the
# "distributed.p2p.comm.*" config keys.
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
# Retry settings used by send(); the delays are parsed with seconds as the
# default unit.
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
    """Return a debug representation listing the run's key attributes."""
    shown = ("id", "run_id", "local_address", "closed", "transferred")
    details = ", ".join(f"{name}={getattr(self, name)!r}" for name in shown)
    return f"<{self.__class__.__name__}: {details}>"
def __str__(self) -> str:
    """Return a short label: class name, shuffle id, run id and local address."""
    cls_name = self.__class__.__name__
    return f"{cls_name}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
    """Hash a run by its ``run_id``."""
    run_id = self.run_id
    return run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
    """Capture context_meter metrics as
    {('p2p', <span id>, 'foreground|background...', label, unit): value}

    **Note 1:** When the metric is not logged by a background task
    (where='foreground'), the same metric is also recorded under
    {('execute', <span id>, <task prefix>, label, unit): value}
    The duplication is intentional: it gives a holistic view of the whole
    shuffle process.

    **Note 2:** Metrics are written to Worker.digests immediately.
    They are not staged on the ShuffleRun, because anything recorded between
    the last heartbeat and the deletion of the ShuffleRun object at the end
    of a run would otherwise be lost.
    """
    prefix = "p2p-"

    def callback(label: Hashable, value: float, unit: str) -> None:
        # Normalise the label to a tuple and strip a leading "p2p-" marker
        # from its first element before digesting the metric.
        parts = label if isinstance(label, tuple) else (label,)
        head = parts[0]
        if isinstance(head, str) and head.startswith(prefix):
            parts = (head[len(prefix) :], *parts[1:])
        self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

    with context_meter.add_callback(callback, allow_offload="background" in where):
        yield
async def barrier(self, run_ids: Sequence[int]) -> int:
    """Report to the scheduler whether all ``run_ids`` match this run.

    Raises if the run is closed; returns this run's ``run_id``.
    """
    self.raise_if_closed()
    mismatch = any(other != self.run_id for other in run_ids)
    # TODO: Consider broadcast pinging once when the shuffle starts to warm
    # up the comm pool on scheduler side
    await self.scheduler.shuffle_barrier(
        id=self.id,
        run_id=self.run_id,
        consistent=not mismatch,
    )
    return self.run_id
async def _send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to the worker at ``address`` through ``shuffle_receive``."""
    self.raise_if_closed()
    peer = self.rpc(address)
    return await peer.shuffle_receive(
        data=to_serialize(shards),
        shuffle_id=self.id,
        run_id=self.run_id,
    )
async def send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to ``address``, retrying with the configured RETRY_* settings.

    When the mean shard size is below 65536, the shards are pickled into a
    single opaque bytes blob instead of being sent as individual buffers over
    the tcp comms; the receiving side unpickles it.
    """
    # Performance tests informing the size threshold:
    # https://github.com/dask/distributed/pull/8318
    payload: list | bytes = (
        pickle.dumps(shards) if _mean_shard_size(shards) < 65536 else shards
    )
    return await retry(
        lambda: self._send(address, payload),
        count=self.RETRY_COUNT,
        delay_min=self.RETRY_DELAY_MIN,
        delay_max=self.RETRY_DELAY_MAX,
    )
async def offload(
    self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    """Run ``func(*args, **kwargs)`` on ``self.executor`` under the "offload" meter."""
    self.raise_if_closed()
    with context_meter.meter("offload"):
        outcome = await run_in_executor_with_context(
            self.executor, func, *args, **kwargs
        )
    return outcome
def heartbeat(self) -> dict[str, Any]:
    """Collect heartbeat data from both buffers plus the run's start time.

    The comm buffer's heartbeat is extended with a "read" entry holding
    ``self.total_recvd``.
    """
    comm = self._comm_buffer.heartbeat()
    comm["read"] = self.total_recvd
    disk = self._disk_buffer.heartbeat()
    return {"disk": disk, "comm": comm, "start": self.start_time}
async def _write_to_comm(
    self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
    """Write ``data`` to the comm buffer; raises first if the run is closed."""
    self.raise_if_closed()
    comm_buffer = self._comm_buffer
    await comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
    """Write ``data`` to the disk buffer, keyed by the "_"-joined index."""
    self.raise_if_closed()
    flattened = {"_".join(map(str, index)): shards for index, shards in data.items()}
    await self._disk_buffer.write(flattened)
def raise_if_closed(self) -> None:
    """Raise if this run is closed.

    Re-raises the stored exception when there is one, otherwise raises
    ``ShuffleClosedError``. Does nothing while the run is open.
    """
    if not self.closed:
        return
    if self._exception:
        raise self._exception
    raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
    """Mark the run as transferred and flush the comm buffer.

    An exception reported by the comm buffer afterwards is stored on the run
    and re-raised.
    """
    self.raise_if_closed()
    self.transferred = True
    await self._flush_comm()
    try:
        self._comm_buffer.raise_on_exception()
    except Exception as comm_error:
        self._exception = comm_error
        raise
async def _flush_comm(self) -> None:
    """Flush the comm buffer; raises first if the run is closed."""
    self.raise_if_closed()
    comm_buffer = self._comm_buffer
    await comm_buffer.flush()
async def flush_receive(self) -> None:
    """Flush the disk buffer; raises first if the run is closed."""
    self.raise_if_closed()
    disk_buffer = self._disk_buffer
    await disk_buffer.flush()
async def close(self) -> None:
    """Close both buffers and signal completion through ``_closed_event``.

    A second caller arriving while (or after) the run is being closed only
    waits for ``_closed_event`` and returns.

    The disk buffer is closed and the event is set even if closing the comm
    buffer raises. Otherwise a failure there would leave the disk buffer open
    and every later ``close()`` caller waiting on the event forever. The
    exception still propagates to the first caller.
    """
    if self.closed:  # pragma: no cover
        await self._closed_event.wait()
        return
    self.closed = True
    try:
        try:
            await self._comm_buffer.close()
        finally:
            await self._disk_buffer.close()
    finally:
        self._closed_event.set()
def fail(self, exception: Exception) -> None:
    """Store ``exception`` on the run unless the run is already closed."""
    if self.closed:
        return
    self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
    """Read from the disk buffer the entry named by the "_"-joined index ``id``."""
    self.raise_if_closed()
    name = "_".join(map(str, id))
    return self._disk_buffer.read(name)
async def receive(
    self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Hand incoming shards to ``_receive`` and report the outcome as a message.

    Returns ``{"status": "OK"}`` on success. A ``P2PConsistencyError`` is
    converted into an error message instead of being raised.
    """
    try:
        if isinstance(data, bytes):
            # Unpack opaque blob. See send()
            # NOTE(review): this unpickles bytes received from a peer;
            # presumably peers are trusted cluster members -- confirm.
            unpacked = pickle.loads(data)
            data = cast(list[tuple[_T_partition_id, Any]], unpacked)
        await self._receive(data)
    except P2PConsistencyError as inconsistency:
        return error_message(inconsistency)
    return {"status": "OK"}
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
    """Make sure output partition ``i`` is produced on its assigned worker.

    Returns normally when the assigned worker is this worker. Otherwise asks
    the scheduler to restrict task ``key`` to the assigned worker and raises
    ``Reschedule``; a scheduler error response is raised as ``RuntimeError``.
    """
    target = self._get_assigned_worker(i)
    if target == self.local_address:
        return
    result = await self.scheduler.shuffle_restrict_task(
        id=self.id, run_id=self.run_id, key=key, worker=target
    )
    if result["status"] == "error":
        raise RuntimeError(result["message"])
    assert result["status"] == "OK"
    raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
    """Return the address of the worker assigned to output partition ``i``."""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
    """Accept shards that belong to output partitions of this shuffle run."""
# Shard one input partition and hand the shards to the comm buffer.
# Raises if the run is closed, and RuntimeError once ``transferred`` is set.
# Returns this run's ``run_id``.
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
# The write is dispatched through sync() with the run's ``_loop``.
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
    self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
    """Split an input partition into shards keyed by the assigned output worker."""
# Return one output partition of the shuffle.
# Raises if the run is closed, and RuntimeError if called before the barrier
# has marked the run as transferred.
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
# _ensure_output_worker raises Reschedule when the partition is assigned to
# a different worker.
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
# Flush the disk buffer before the partition is read.
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
    self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
    """Return the output partition ``partition_id`` of this shuffle run."""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
    """Read the shards stored on disk at ``path``."""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
    """Turn a serialized ``buffer`` back into shards."""
def validate_data(self, data: Any) -> None:
    """Check payload data before it is shuffled; this implementation accepts anything."""
def get_worker_plugin() -> ShuffleWorkerPlugin:
    """Return the "shuffle" plugin of the worker this code is running on.

    Raises ``RuntimeError`` when not running on a worker, or when the worker
    has no plugin registered under "shuffle".
    """
    from distributed import get_worker

    try:
        worker = get_worker()
    except ValueError as not_on_worker:
        message = (
            "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
            "please confirm that you've created a distributed Client and are submitting this computation through it."
        )
        raise RuntimeError(message) from not_on_worker
    try:
        return worker.plugins["shuffle"]  # type: ignore
    except KeyError as missing_plugin:
        raise RuntimeError(
            f"The worker {worker.address} does not have a P2P shuffle plugin."
        ) from missing_plugin
_BARRIER_PREFIX = "shuffle-barrier-"


def barrier_key(shuffle_id: ShuffleId) -> str:
    """Return the barrier task key for ``shuffle_id``: the prefix plus the id."""
    key = _BARRIER_PREFIX + shuffle_id
    return key
def id_from_key(key: Key) -> ShuffleId | None:
    """Extract the shuffle id from a barrier key.

    Returns ``None`` when ``key`` is not a string starting with the barrier
    prefix.
    """
    if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    return None
# Enumeration with one member per kind of P2P operation named here:
# dataframe shuffle and array rechunk.
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
# Immutable description of one run of a shuffle: the shuffle spec, the
# partition-to-worker mapping, and the span id.
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
# Not an init argument: each new instance takes the next value of a single
# counter that starts at 1 and is shared by all instances.
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
# Maps each partition id to a worker address string.
worker_for: dict[_T_partition_id, str]
span_id: str | None
# The shuffle id is taken from the underlying spec.
@property
def id(self) -> ShuffleId:
return self.spec.id
# Abstract, immutable specification of a shuffle. Subclasses provide the
# output partitions, worker selection, and creation of the worker-side run.
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
# Passed to ShuffleRun, where it selects a disk-backed or in-memory buffer
# -- presumably by create_run_on_worker; confirm in subclasses.
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
# Build the scheduler-side state for a new run. The participating workers
# start out as the set of workers named in ``worker_for``.
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
    """Scheduler-side state of one shuffle run.

    Holds the run spec and the set of participating worker addresses.
    Instances hash by run id; ``eq=False`` keeps identity-based equality.
    """

    run_spec: ShuffleRunSpec
    participating_workers: set[str]
    _archived_by: str | None = field(default=None, init=False)
    _failed: bool = False

    @property
    def id(self) -> ShuffleId:
        """Shuffle id, taken from the run spec."""
        return self.run_spec.id

    @property
    def run_id(self) -> int:
        """Run id, taken from the run spec."""
        return self.run_spec.run_id

    @property
    def archived(self) -> bool:
        """Whether ``_archived_by`` has been set."""
        return self._archived_by is not None

    def __str__(self) -> str:
        """Return ``ClassName<id[run_id]>``."""
        cls_name = self.__class__.__name__
        return f"{cls_name}<{self.id}[{self.run_id}]>"

    def __hash__(self) -> int:
        """Hash by run id."""
        return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P 871d68ad5c66483062239bd5af22c9b8 failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x6-chunks6] (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 01992e7a5ca3de072c0c2984a1548c6e failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
# Scheduler this plugin is attached to (set in __init__).
scheduler: Scheduler
# Current state per shuffle id; barrier(), restrict_task() and _get() look
# shuffles up here.
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
# Heartbeat payloads, updated per shuffle id and worker address in heartbeat().
heartbeats: defaultdict[ShuffleId, dict]
# The three attributes below are initialised empty in __init__; their use is
# not visible in this excerpt.
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
# Attach the plugin to ``scheduler``: register its RPC handlers, add it as a
# scheduler plugin named "shuffle", and initialise empty bookkeeping state.
self.scheduler = scheduler
# Expose the plugin's entry points under the "shuffle_*" handler names.
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
# Two-level defaultdict: heartbeat() indexes it by shuffle id, then by
# worker address.
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
# NOTE(review): the plugin is registered before the attributes below are
# assigned -- confirm add_plugin does not call back into code that reads them.
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
    """Register a pickled ``ShuffleWorkerPlugin`` under the name "shuffle".

    The ``scheduler`` argument is not used; registration goes through
    ``self.scheduler``.
    """
    pickled_plugin = dumps(ShuffleWorkerPlugin())
    await self.scheduler.register_worker_plugin(
        None,
        pickled_plugin,
        name="shuffle",
        idempotent=False,
    )
def shuffle_ids(self) -> set[ShuffleId]:
    """Return a new set holding the ids of all active shuffles."""
    return set(self.active_shuffles.keys())
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
    """Broadcast ``shuffle_inputs_done`` to all workers participating in a shuffle.

    Parameters
    ----------
    id
        Id of the shuffle; must be a key of ``active_shuffles``.
    run_id
        Run the caller believes is active. A mismatch raises ``ValueError``.
    consistent
        Whether the caller saw consistent run ids. If false, the shuffle is
        restarted instead of completing the barrier.

    Raises
    ------
    ValueError
        If ``run_id`` does not match the active run.
    P2PIllegalStateError
        If a worker that still needs a retry is unknown to the scheduler
        while the shuffle is not archived.
    RuntimeError
        If a worker returned an error other than ``OSError``, a worker left
        during the shuffle, or the broadcast failed to make progress three
        times.
    """
    shuffle = self.active_shuffles[id]
    if shuffle.run_id != run_id:
        raise ValueError(f"{run_id=} does not match {shuffle}")
    if not consistent:
        logger.warning(
            "Shuffle %s restarted due to data inconsistency during barrier",
            shuffle.id,
        )
        return self._restart_shuffle(
            shuffle.id,
            self.scheduler,
            stimulus_id=f"p2p-barrier-inconsistent-{time()}",
        )
    msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
    workers = list(shuffle.participating_workers)
    no_progress = 0
    while workers:
        res = await self.scheduler.broadcast(
            msg=msg,
            workers=workers,
            on_error="return",
        )
        before = len(workers)
        # Retry only the workers whose broadcast failed with an OSError; any
        # other returned error aborts the barrier. The original code
        # recomputed ``workers`` from ``res`` right after this loop, which
        # produced the same list, so that dead assignment is removed.
        workers = []
        for w, r in res.items():
            if r is None:
                continue
            if isinstance(r, OSError):
                workers.append(w)
            else:
                raise RuntimeError(
                    f"Unexpected error encountered during P2P barrier: {r!r}"
                )
        if workers:
            logger.warning(
                "Failure during broadcast of %s, retrying.",
                shuffle.id,
            )
            if any(w not in self.scheduler.workers for w in workers):
                if not shuffle.archived:
                    # If the shuffle is not yet archived, this could mean that the barrier task fails
                    # before the P2P restarting mechanism can kick in.
                    raise P2PIllegalStateError(
                        "Expected shuffle to be archived if participating worker is not known by scheduler"
                    )
                raise RuntimeError(
                    f"Worker {workers} left during shuffle {shuffle}"
                )
            await asyncio.sleep(0.1)
            if len(workers) == before:
                no_progress += 1
                if no_progress >= 3:
                    raise RuntimeError(
                        f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
                    )
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: '01992e7a5ca3de072c0c2984a1548c6e'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and…rker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:43175', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(10, 10), dtype=float64, chunksize=(10, 10), chunktype=numpy.ndarray>
chunks = (None, 5)
ws = (<Worker 'tcp://127.0.0.1:36981', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:42401', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.12/site-packages/dask/dataframe/__init__.py'>
@pytest.mark.parametrize(
"x, chunks",
[
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
(da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
(da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
],
)
@gen_cluster(client=True)
async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
"""
pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
y = dd.from_array(x).values
result = y.rechunk(chunks, method="p2p")
expected = x.rechunk(chunks, method="p2p")
assert_chunks_match(result.chunks, expected.chunks)
> assert_eq(await c.compute(result), await c.compute(expected))
distributed/shuffle/tests/test_rechunk.py:757:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P 01992e7a5ca3de072c0c2984a1548c6e failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x7-chunks7] (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 01992e7a5ca3de072c0c2984a1548c6e failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: '01992e7a5ca3de072c0c2984a1548c6e'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and…_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:45757', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(10, 10), dtype=float64, chunksize=(10, 10), chunktype=numpy.ndarray>
chunks = {1: 5}
ws = (<Worker 'tcp://127.0.0.1:40155', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:45267', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.12/site-packages/dask/dataframe/__init__.py'>
@pytest.mark.parametrize(
"x, chunks",
[
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
(da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
(da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
],
)
@gen_cluster(client=True)
async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
"""
pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
y = dd.from_array(x).values
result = y.rechunk(chunks, method="p2p")
expected = x.rechunk(chunks, method="p2p")
assert_chunks_match(result.chunks, expected.chunks)
> assert_eq(await c.compute(result), await c.compute(expected))
distributed/shuffle/tests/test_rechunk.py:757:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)
with (
# Log metrics both in the "execute" and in the "p2p" contexts
self._capture_metrics("foreground"),
context_meter.meter("p2p-get-output-noncpu"),
context_meter.meter("p2p-get-output-cpu", func=thread_time),
):
return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
from distributed import get_worker
try:
worker = get_worker()
except ValueError as e:
raise RuntimeError(
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
"please confirm that you've created a distributed Client and are submitting this computation through it."
) from e
try:
return worker.plugins["shuffle"] # type: ignore
except KeyError as e:
raise RuntimeError(
f"The worker {worker.address} does not have a P2P shuffle plugin."
) from e
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
return None
return ShuffleId(key[len(_BARRIER_PREFIX) :])
class ShuffleType(Enum):
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
spec: ShuffleSpec
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
run_spec: ShuffleRunSpec
participating_workers: set[str]
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
return self.run_spec.id
@property
def run_id(self) -> int:
return self.run_spec.run_id
@property
def archived(self) -> bool:
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P 01992e7a5ca3de072c0c2984a1548c6e failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x8-chunks8] (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 01992e7a5ca3de072c0c2984a1548c6e failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: '01992e7a5ca3de072c0c2984a1548c6e'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and…plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:43977', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(10, 10), dtype=float64, chunksize=(10, 10), chunktype=numpy.ndarray>
chunks = (None, (5, 5))
ws = (<Worker 'tcp://127.0.0.1:44945', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:35293', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.12/site-packages/dask/dataframe/__init__.py'>
@pytest.mark.parametrize(
"x, chunks",
[
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
(da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
(da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
],
)
@gen_cluster(client=True)
async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
"""
pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
y = dd.from_array(x).values
result = y.rechunk(chunks, method="p2p")
expected = x.rechunk(chunks, method="p2p")
assert_chunks_match(result.chunks, expected.chunks)
> assert_eq(await c.compute(result), await c.compute(expected))
distributed/shuffle/tests/test_rechunk.py:757:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to ``address``, retrying according to the RETRY_* settings.

    When the mean shard size is below 65536, the buffers are not sent
    individually over the tcp comms: everything is merged into one opaque
    pickled bytes blob, sent at once, and unpickled on the other side.
    Performance tests informing the size threshold:
    https://github.com/dask/distributed/pull/8318
    """
    payload: list | bytes = (
        pickle.dumps(shards) if _mean_shard_size(shards) < 65536 else shards
    )
    return await retry(
        # A fresh coroutine is created for every attempt.
        lambda: self._send(address, payload),
        count=self.RETRY_COUNT,
        delay_min=self.RETRY_DELAY_MIN,
        delay_max=self.RETRY_DELAY_MAX,
    )
async def offload(
    self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

    The call is timed under the ``"offload"`` context meter. Raises if the run
    is already closed.
    """
    self.raise_if_closed()
    with context_meter.meter("offload"):
        outcome = await run_in_executor_with_context(
            self.executor, func, *args, **kwargs
        )
    return outcome
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
"""Write ``data`` to the comm buffer; raises if the run is already closed."""
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
"""Mark the transfer phase as finished and flush the comm buffer.

Sets ``transferred`` (after which ``add_partition`` refuses further input),
flushes the comm buffer, then re-raises any exception reported by the comm
buffer after storing it in ``_exception``.
"""
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
# Keep the failure so that raise_if_closed() can re-raise it once closed.
self._exception = e
raise
async def _flush_comm(self) -> None:
"""Flush the comm buffer; raises if the run is already closed."""
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
"""Flush the disk buffer; raises if the run is already closed."""
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
"""Close the comm and disk buffers and mark the run as closed.

A call made after ``closed`` has been set only waits for ``_closed_event``
and returns.
"""
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
# Set the flag before the first await so that later calls take the branch above.
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
# NOTE(review): the event is only set when both close() calls succeed; confirm
# that callers waiting on it cannot hang if a buffer fails to close.
self._closed_event.set()
def fail(self, exception: Exception) -> None:
"""Store ``exception`` on the run unless the run is already closed."""
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to output partition ``i``."""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run.

Called by ``receive`` with the shards already unpacked from any bytes blob.
"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
"""Shard an input partition and queue the shards on the comm buffer.

Raises if the run is closed, and ``RuntimeError`` once ``transferred`` has
been set. Returns this run's ``run_id``.
"""
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
# The write coroutine is executed via sync() on self._loop.
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers.

The returned mapping is what ``add_partition`` writes to the comm buffer.
"""
def get_output_partition(
    self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
    """Return the output partition ``partition_id`` for task ``key``.

    Steps, in order: fail if the run is closed, make sure the task runs on the
    assigned worker, fail with ``RuntimeError`` if inputs were not marked as
    transferred yet, flush received data, then delegate to
    ``_get_output_partition`` while metering it.
    """
    self.raise_if_closed()
    sync(self._loop, self._ensure_output_worker, partition_id, key)
    if not self.transferred:
        raise RuntimeError("`get_output_partition` called before barrier task")
    sync(self._loop, self.flush_receive)
    # Log metrics both in the "execute" and in the "p2p" contexts
    with self._capture_metrics("foreground"):
        with context_meter.meter("p2p-get-output-noncpu"):
            with context_meter.meter("p2p-get-output-cpu", func=thread_time):
                output = self._get_output_partition(partition_id, key, **kwargs)
    return output
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition of the shuffle run.

Called by ``get_output_partition`` after its checks and the receive flush.
"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk.

Passed as the ``read`` callback to ``DiskShardsBuffer`` in ``__init__``.
"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards.

Passed as the ``deserialize`` callback to ``MemoryShardsBuffer`` in ``__init__``.
"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling.

The base implementation does nothing; ``add_partition`` calls it before sharding.
"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
    """Return the ``"shuffle"`` plugin of the worker running the current task.

    Raises ``RuntimeError`` when not running on a worker, or when the worker
    has no plugin registered under ``"shuffle"``.
    """
    from distributed import get_worker

    try:
        current_worker = get_worker()
    except ValueError as e:
        not_on_worker = (
            "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
            "please confirm that you've created a distributed Client and are submitting this computation through it."
        )
        raise RuntimeError(not_on_worker) from e

    try:
        plugin = current_worker.plugins["shuffle"]
    except KeyError as e:
        raise RuntimeError(
            f"The worker {current_worker.address} does not have a P2P shuffle plugin."
        ) from e
    return plugin  # type: ignore
# Prefix shared by barrier_key() and id_from_key() to map shuffle ids to task keys.
_BARRIER_PREFIX = "shuffle-barrier-"
def barrier_key(shuffle_id: ShuffleId) -> str:
"""Return the barrier task key for ``shuffle_id`` (prefix + id)."""
return _BARRIER_PREFIX + shuffle_id
def id_from_key(key: Key) -> ShuffleId | None:
    """Extract the shuffle id from a barrier task key.

    Returns ``None`` for keys that are not strings or do not start with the
    barrier prefix.
    """
    if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    return None
class ShuffleType(Enum):
"""Kinds of P2P shuffle; each value is the display name of the operation."""
DATAFRAME = "DataFrameShuffle"
ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
"""Immutable description of a single run of a shuffle."""
# Not an __init__ argument: every new instance draws the next value from a
# counter shared by all instances, starting at 1.
run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
# The shuffle specification this run belongs to.
spec: ShuffleSpec
# Maps each output partition id to a worker address.
worker_for: dict[_T_partition_id, str]
span_id: str | None
@property
def id(self) -> ShuffleId:
"""Shuffle id, taken from the wrapped spec."""
return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
"""Abstract, immutable specification of a shuffle identified by ``id``."""
id: ShuffleId
# NOTE(review): presumably selects disk-backed vs. in-memory buffering on the
# worker; confirm against the concrete create_run_on_worker implementations.
disk: bool
@property
@abc.abstractmethod
def output_partitions(self) -> Generator[_T_partition_id]:
"""Output partitions"""
@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""
def create_new_run(
self,
worker_for: dict[_T_partition_id, str],
span_id: str | None,
) -> SchedulerShuffleState:
"""Create the scheduler-side state for a new run of this shuffle.

Wraps this spec in a new ``ShuffleRunSpec`` and records every worker that
appears as a value of ``worker_for`` as participating.
"""
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
participating_workers=set(worker_for.values()),
)
@abc.abstractmethod
def create_run_on_worker(
self,
run_id: int,
span_id: str | None,
worker_for: dict[_T_partition_id, str],
plugin: ShuffleWorkerPlugin,
) -> ShuffleRun:
"""Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
"""Mutable state for one shuffle run: its spec and the participating workers.

``eq=False`` keeps the default identity comparison; hashing uses ``run_id``.
"""
run_spec: ShuffleRunSpec
# Addresses of the workers taking part in this run.
participating_workers: set[str]
# None until the run is archived; see the ``archived`` property.
_archived_by: str | None = field(default=None, init=False)
_failed: bool = False
@property
def id(self) -> ShuffleId:
"""Shuffle id, taken from the run spec."""
return self.run_spec.id
@property
def run_id(self) -> int:
"""Run id, taken from the run spec."""
return self.run_spec.run_id
@property
def archived(self) -> bool:
"""True once ``_archived_by`` has been set."""
return self._archived_by is not None
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
def __hash__(self) -> int:
return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
    """Translate exceptions raised inside the block during the transfer phase.

    ``ShuffleClosedError`` becomes ``Reschedule``; ``P2PConsistencyError`` and
    ``P2POutOfDiskError`` pass through unchanged; any other exception is
    wrapped in a ``RuntimeError`` that names the shuffle.
    """
    try:
        yield
    # This clause must stay first so it wins over the pass-through clause below
    # should the closed-error type derive from one of those exceptions.
    except ShuffleClosedError:
        raise Reschedule()
    except (P2PConsistencyError, P2POutOfDiskError):
        raise
    except Exception as e:
        raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P 01992e7a5ca3de072c0c2984a1548c6e failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x9-chunks9] (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 9f169a5772161f491fce47676e173277 failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
"""Attach the plugin to ``scheduler`` and initialise empty shuffle state."""
self.scheduler = scheduler
# Expose the plugin's entry points as scheduler handlers under "shuffle_*" names.
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
# Nested mapping; heartbeat() fills it as shuffle id -> worker address -> data.
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
"""Register a pickled ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

Note that the registration goes through ``self.scheduler``; the ``scheduler``
argument is not used.
"""
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
"""Return a new set holding the ids of all active shuffles."""
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
"""Handle the barrier of shuffle ``id``.

Raises ``ValueError`` if ``run_id`` does not match the active run. An
inconsistent barrier restarts the shuffle. Otherwise ``shuffle_inputs_done``
is broadcast to the participating workers and re-sent to those that answered
with an ``OSError`` until none are left; any other non-``None`` answer raises
``RuntimeError``.
"""
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
# NOTE(review): this returns the result of _restart_shuffle although the
# method is annotated ``-> None``; confirm that helper returns None.
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
# Counts rounds in which the number of pending workers did not change.
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
# A None result means the worker handled the message; nothing to retry.
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
# NOTE(review): this rebuilds ``workers`` from ``res`` and replaces the list
# built by the loop above. Since the loop raises on any non-OSError result, both
# appear to select the same workers; confirm, then drop one of the two.
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
# NOTE(review): no_progress is never reset after a round that did make
# progress; confirm that counting non-consecutive stalls is intended.
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
    self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
    """Restrict task ``key`` to ``worker`` on behalf of run ``run_id``.

    A ``run_id`` that differs from the active run of shuffle ``id`` yields an
    error message instead of an exception. A shuffle id or task key that is not
    known raises ``KeyError``.
    """
    try:
        shuffle = self.active_shuffles[id]
        active_run = shuffle.run_id
        if active_run > run_id:
            raise P2PConsistencyError(
                f"Request stale, expected {run_id=} for {shuffle}"
            )
        if active_run < run_id:
            raise P2PConsistencyError(
                f"Request invalid, expected {run_id=} for {shuffle}"
            )
        self._set_restriction(self.scheduler.tasks[key], worker)
    except P2PConsistencyError as e:
        return error_message(e)
    return {"status": "OK"}
def heartbeat(self, ws: WorkerState, data: dict) -> None:
    """Merge a worker's per-shuffle heartbeat payload into ``self.heartbeats``.

    Parameters
    ----------
    ws
        State of the reporting worker; only its ``address`` is used.
    data
        Mapping of shuffle id to that worker's heartbeat dict.

    Entries for shuffles that are not currently active are ignored.
    """
    # Test membership on the dict directly: shuffle_ids() builds a new set on
    # every call, which made this loop copy all active ids once per entry.
    active = self.active_shuffles
    for shuffle_id, payload in data.items():
        if shuffle_id in active:
            self.heartbeats[shuffle_id][ws.address].update(payload)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
"""Return the run spec of the active shuffle ``id`` in an OK message.

A ``KeyError`` from the lookup is converted into a ``P2PConsistencyError``,
and any ``P2PConsistencyError`` is returned as an error message rather than
raised.
"""
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
"""Look up the active shuffle ``id`` and add ``worker`` to its participants.

Raises ``P2PConsistencyError`` if the scheduler does not know ``worker`` and
``KeyError`` if no shuffle with this id is active.
"""
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
"""Return the ``ShuffleSpec`` carried by the barrier task of ``shuffle_id``.

The barrier task is looked up in ``scheduler.tasks`` by its key and must hold
a ``P2PBarrierTask`` as its run spec.
"""
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
assert isinstance(barrier_task_spec, P2PBarrierTask)
return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
"""Start a new run for ``shuffle_id`` and register it as the active one.

``key`` is the task that triggered the creation and ``worker`` the worker it
runs on; that worker is added to the participants. Returns the new run spec.
"""
# FIXME: The current implementation relies on the barrier task to be
# known by its name. If the name has been mangled, we cannot guarantee
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(shuffle_id)
self._raise_if_task_not_processing(key)
spec = self._retrieve_spec(shuffle_id)
worker_for = self._calculate_worker_for(spec)
self._ensure_output_tasks_are_non_rootish(spec)
state = spec.create_new_run(
worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
)
# The new state replaces any previous entry in active_shuffles, while
# _shuffles collects every state created for this id.
self.active_shuffles[shuffle_id] = state
self._shuffles[shuffle_id].add(state)
state.participating_workers.add(worker)
# NOTE(review): this message is logged at WARNING level although it reads as
# informational; confirm the level is intended.
logger.warning(
"Shuffle %s initialized by task %r executed on worker %s",
shuffle_id,
key,
worker,
)
return state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: '9f169a5772161f491fce47676e173277'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
"""OK message that additionally carries a shuffle run spec."""
# Either the spec itself or the spec wrapped in ToPickle.
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
"""Set up the buffers and bookkeeping for one shuffle run on this worker.

``disk`` selects a ``DiskShardsBuffer`` rooted at ``directory`` (True) or a
``MemoryShardsBuffer`` (False). Message size limit, comm concurrency and retry
settings are read from the ``distributed.p2p.comm.*`` dask config keys.
"""
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
# Set by inputs_done(); add_partition() refuses new input afterwards.
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
# Delay values given without a unit are interpreted as seconds.
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and…orker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
    """Register a pickled ``ShuffleWorkerPlugin`` under the name "shuffle".

    Registration goes through ``self.scheduler``; the ``scheduler`` argument
    is not used.
    """
    pickled_plugin = dumps(ShuffleWorkerPlugin())
    await self.scheduler.register_worker_plugin(
        None, pickled_plugin, name="shuffle", idempotent=False
    )
def shuffle_ids(self) -> set[ShuffleId]:
    """Return the ids of all currently active shuffles as a new set."""
    return {*self.active_shuffles}
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
    """Broadcast ``shuffle_inputs_done`` to all workers participating in a shuffle.

    Parameters
    ----------
    id
        Id of the active shuffle.
    run_id
        Run the caller refers to; must match the active run.
    consistent
        If false, the shuffle is restarted instead of broadcasting.

    Raises
    ------
    KeyError
        If there is no active shuffle with this ``id``.
    ValueError
        If ``run_id`` does not match the active run.
    P2PIllegalStateError
        If a worker that must be retried is unknown to the scheduler while
        the shuffle is not archived.
    RuntimeError
        If a worker returned an error other than ``OSError``, if a worker
        that must be retried has left, or if the number of pending workers
        did not shrink in three rounds (the counter is never reset).
    """
    shuffle = self.active_shuffles[id]
    if shuffle.run_id != run_id:
        raise ValueError(f"{run_id=} does not match {shuffle}")
    if not consistent:
        logger.warning(
            "Shuffle %s restarted due to data inconsistency during barrier",
            shuffle.id,
        )
        return self._restart_shuffle(
            shuffle.id,
            self.scheduler,
            stimulus_id=f"p2p-barrier-inconsistent-{time()}",
        )
    msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
    workers = list(shuffle.participating_workers)
    no_progress = 0
    while workers:
        res = await self.scheduler.broadcast(
            msg=msg,
            workers=workers,
            on_error="return",
        )
        before = len(workers)
        # Collect the workers to retry. A ``None`` result means success, an
        # OSError is retried and anything else aborts the barrier, so after
        # this loop ``workers`` holds exactly the workers with a non-None
        # result; no second pass over ``res`` is needed.
        workers = []
        for w, r in res.items():
            if r is None:
                continue
            if isinstance(r, OSError):
                workers.append(w)
            else:
                raise RuntimeError(
                    f"Unexpected error encountered during P2P barrier: {r!r}"
                )
        if workers:
            logger.warning(
                "Failure during broadcast of %s, retrying.",
                shuffle.id,
            )
            if any(w not in self.scheduler.workers for w in workers):
                if not shuffle.archived:
                    # If the shuffle is not yet archived, this could mean that the barrier task fails
                    # before the P2P restarting mechanism can kick in.
                    raise P2PIllegalStateError(
                        "Expected shuffle to be archived if participating worker is not known by scheduler"
                    )
                raise RuntimeError(
                    f"Worker {workers} left during shuffle {shuffle}"
                )
            await asyncio.sleep(0.1)
            if len(workers) == before:
                no_progress += 1
                if no_progress >= 3:
                    raise RuntimeError(
                        f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
                    )
def restrict_task(
    self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
    """Restrict task ``key`` to ``worker`` if ``run_id`` is the active run.

    Returns ``{"status": "OK"}`` on success. A ``run_id`` that is older or
    newer than the active run is answered with an error message built from
    a ``P2PConsistencyError``.
    """
    try:
        shuffle = self.active_shuffles[id]
        active_run_id = shuffle.run_id
        if active_run_id > run_id:
            raise P2PConsistencyError(
                f"Request stale, expected {run_id=} for {shuffle}"
            )
        if active_run_id < run_id:
            raise P2PConsistencyError(
                f"Request invalid, expected {run_id=} for {shuffle}"
            )
        self._set_restriction(self.scheduler.tasks[key], worker)
    except P2PConsistencyError as exc:
        return error_message(exc)
    return {"status": "OK"}
def heartbeat(self, ws: WorkerState, data: dict) -> None:
    """Merge per-shuffle heartbeat data sent by worker ``ws``.

    ``data`` maps shuffle ids to dicts. Entries for shuffles that are not
    active are ignored; the rest are merged into
    ``self.heartbeats[shuffle_id][ws.address]``.
    """
    # shuffle_ids() builds a new set on every call, so take it once per
    # heartbeat instead of once per entry.
    active_ids = self.shuffle_ids()
    for shuffle_id, d in data.items():
        if shuffle_id in active_ids:
            self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
    """Return the run spec of active shuffle ``id``, wrapped in ``ToPickle``.

    A missing shuffle (``KeyError``) is turned into a
    ``P2PConsistencyError``; every ``P2PConsistencyError`` is answered with
    ``error_message`` instead of being raised.
    """
    try:
        try:
            spec = self._get(id, worker)
            return {"status": "OK", "run_spec": ToPickle(spec)}
        except KeyError as missing:
            raise P2PConsistencyError(
                f"No active shuffle with {id=!r} found"
            ) from missing
    except P2PConsistencyError as exc:
        return error_message(exc)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:40785', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(10, 10), dtype=float64, chunksize=(10, 2), chunktype=numpy.ndarray>
chunks = (None, 5)
ws = (<Worker 'tcp://127.0.0.1:39757', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:41009', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.12/site-packages/dask/dataframe/__init__.py'>
@pytest.mark.parametrize(
"x, chunks",
[
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
(da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
(da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
],
)
@gen_cluster(client=True)
async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
"""
pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
y = dd.from_array(x).values
result = y.rechunk(chunks, method="p2p")
expected = x.rechunk(chunks, method="p2p")
assert_chunks_match(result.chunks, expected.chunks)
> assert_eq(await c.compute(result), await c.compute(expected))
distributed/shuffle/tests/test_rechunk.py:757:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
    """``OKMessage`` that additionally carries a shuffle run spec."""

    # Either the spec itself or the spec wrapped in ``ToPickle``.
    run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
# State of one shuffle run. ShuffleSpec.create_run_on_worker is documented
# as creating "the new shuffle run on the worker" and returns this type.
# Shuffle id and run id; __str__ renders them as "<id>[<run_id>]".
id: ShuffleId
run_id: int
# Used as the second element of the metric names built in _capture_metrics.
span_id: str | None
# Compared with the assigned worker address in _ensure_output_worker.
local_address: str
# Executor that offload() runs functions in.
executor: ThreadPoolExecutor
# Called with a peer address in _send to reach that peer's shuffle_receive.
rpc: Callable[[str], PooledRPCCall]
# Receives (name, value) for every metric captured by _capture_metrics.
digest_metric: Callable[[Hashable, float], None]
# RPC to the scheduler (shuffle_barrier, shuffle_restrict_task).
scheduler: PooledRPCCall
# Set to True by close(); checked by raise_if_closed().
closed: bool
# DiskShardsBuffer when constructed with disk=True, else MemoryShardsBuffer.
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
# Outgoing shards; sends through self.send.
_comm_buffer: CommShardsBuffer
# Initialised to an empty set / 0 in __init__; total_recvd is reported as
# "read" in heartbeat().
received: set[_T_partition_id]
total_recvd: int
# time.time() at construction.
start_time: float
# Set by fail() and inputs_done(); re-raised by raise_if_closed() once closed.
_exception: Exception | None
# Set at the end of close(); concurrent close() callers wait on it.
_closed_event: asyncio.Event
# Loop that add_partition/get_output_partition pass to sync().
_loop: IOLoop
# Retry settings for send(), read from distributed.p2p.comm.retry.* config.
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
    self,
    id: ShuffleId,
    run_id: int,
    span_id: str | None,
    local_address: str,
    directory: str,
    executor: ThreadPoolExecutor,
    rpc: Callable[[str], PooledRPCCall],
    digest_metric: Callable[[Hashable, float], None],
    scheduler: PooledRPCCall,
    memory_limiter_disk: ResourceLimiter,
    memory_limiter_comms: ResourceLimiter,
    disk: bool,
    loop: IOLoop,
):
    """Store the run's identity and build its disk and comm buffers.

    With ``disk=True`` shards are buffered in a ``DiskShardsBuffer`` under
    ``directory``, otherwise in a ``MemoryShardsBuffer``. Message size,
    concurrency and retry settings are read from ``distributed.p2p.comm.*``
    configuration.
    """
    self.id = id
    self.run_id = run_id
    self.span_id = span_id
    self.local_address = local_address
    self.executor = executor
    self.rpc = rpc
    self.digest_metric = digest_metric
    self.scheduler = scheduler
    self.closed = False

    # Buffers start background tasks. Their metrics must not be attributed
    # to the dask task that happened to create this object, so the current
    # callbacks are cleared while the buffers are constructed.
    with context_meter.clear_callbacks():
        with self._capture_metrics("background-disk"):
            self._disk_buffer = (
                DiskShardsBuffer(
                    directory=directory,
                    read=self.read,
                    memory_limiter=memory_limiter_disk,
                )
                if disk
                else MemoryShardsBuffer(deserialize=self.deserialize)
            )

        with self._capture_metrics("background-comms"):
            message_bytes_limit = parse_bytes(
                dask.config.get("distributed.p2p.comm.message-bytes-limit")
            )
            comm_concurrency = dask.config.get("distributed.p2p.comm.concurrency")
            self._comm_buffer = CommShardsBuffer(
                send=self.send,
                max_message_size=message_bytes_limit,
                memory_limiter=memory_limiter_comms,
                concurrency_limit=comm_concurrency,
            )

    # TODO: reduce number of connections to number of workers
    # MultiComm.max_connections = min(10, n_workers)

    self.transferred = False
    self.received = set()
    self.total_recvd = 0
    self.start_time = time.time()
    self._exception = None
    self._closed_event = asyncio.Event()
    self._loop = loop

    self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
    min_delay = dask.config.get("distributed.p2p.comm.retry.delay.min")
    self.RETRY_DELAY_MIN = parse_timedelta(min_delay, default="s")
    max_delay = dask.config.get("distributed.p2p.comm.retry.delay.max")
    self.RETRY_DELAY_MAX = parse_timedelta(max_delay, default="s")
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
    """Forward ``context_meter`` metrics to ``self.digest_metric`` as
    {('p2p', <span id>, 'foreground|background...', label, unit): value}

    A leading ``"p2p-"`` is stripped from the first label element.

    **Note 1:** Metrics that are not logged by a background task
    (where='foreground') also show up under
    {('execute', <span id>, <task prefix>, label, unit): value}
    The duplication is intentional: it gives a holistic view of the whole
    shuffle process.

    **Note 2:** Metrics are written to Worker.digests immediately. Keeping
    them on the ShuffleRun first would lose whatever is recorded between the
    last heartbeat and the deletion of the ShuffleRun at the end of a run.
    """
    prefix = "p2p-"

    def _forward(label: Hashable, value: float, unit: str) -> None:
        parts = label if isinstance(label, tuple) else (label,)
        head = parts[0]
        if isinstance(head, str) and head.startswith(prefix):
            parts = (head[len(prefix) :], *parts[1:])
        self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

    offloadable = "background" in where
    with context_meter.add_callback(_forward, allow_offload=offloadable):
        yield
async def barrier(self, run_ids: Sequence[int]) -> int:
    """Report to the scheduler whether all ``run_ids`` equal this run's id.

    Returns this run's ``run_id``.
    """
    self.raise_if_closed()
    consistent = True
    for candidate in run_ids:
        if not (candidate == self.run_id):
            consistent = False
            break
    # TODO: Consider broadcast pinging once when the shuffle starts to warm
    # up the comm pool on scheduler side
    await self.scheduler.shuffle_barrier(
        id=self.id, run_id=self.run_id, consistent=consistent
    )
    return self.run_id
async def _send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Call ``shuffle_receive`` on the worker at ``address`` with ``shards``."""
    self.raise_if_closed()
    peer = self.rpc(address)
    return await peer.shuffle_receive(
        data=to_serialize(shards),
        shuffle_id=self.id,
        run_id=self.run_id,
    )
async def send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to ``address``, retrying per the ``RETRY_*`` settings.

    If the mean shard size is below 65536 the shards are pickled into one
    bytes blob first; the receiving side unpickles it in ``receive``.
    """
    # Small shards: don't send buffers individually over the tcp comms.
    # Merge everything into an opaque bytes blob, send it all at once and
    # unpickle it on the other side.
    # Performance tests informing the size threshold:
    # https://github.com/dask/distributed/pull/8318
    shards_or_bytes: list | bytes = (
        pickle.dumps(shards) if _mean_shard_size(shards) < 65536 else shards
    )

    def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
        return self._send(address, shards_or_bytes)

    return await retry(
        _send,
        count=self.RETRY_COUNT,
        delay_min=self.RETRY_DELAY_MIN,
        delay_max=self.RETRY_DELAY_MAX,
    )
async def offload(
    self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    """Run ``func(*args, **kwargs)`` in ``self.executor`` and return its result.

    The call is metered under the "offload" label.
    """
    self.raise_if_closed()
    with context_meter.meter("offload"):
        outcome = await run_in_executor_with_context(
            self.executor, func, *args, **kwargs
        )
    return outcome
def heartbeat(self) -> dict[str, Any]:
    """Return the disk and comm buffer heartbeats plus the run's start time.

    The comm heartbeat gets an extra ``"read"`` entry holding
    ``self.total_recvd``.
    """
    comm_stats = self._comm_buffer.heartbeat()
    comm_stats["read"] = self.total_recvd
    disk_stats = self._disk_buffer.heartbeat()
    return {"disk": disk_stats, "comm": comm_stats, "start": self.start_time}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
    """Raise if the run has been closed.

    Raises the stored ``self._exception`` if there is one, otherwise a
    ``ShuffleClosedError``. Does nothing while the run is open.
    """
    if not self.closed:
        return
    if self._exception:
        raise self._exception
    raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
    """Mark the transfer phase as finished and flush the comm buffer.

    Raises whatever ``raise_if_closed`` raises when the run is closed. Any
    exception reported by the comm buffer is stored in ``self._exception``
    and then re-raised.
    """
    self.raise_if_closed()
    self.transferred = True
    await self._flush_comm()
    try:
        self._comm_buffer.raise_on_exception()
    except Exception as exc:
        # Remember the failure so that later raise_if_closed() calls
        # surface it once the run has been closed.
        self._exception = exc
        raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
    """Flush the disk buffer after checking that the run is not closed."""
    self.raise_if_closed()
    await self._disk_buffer.flush()
async def close(self) -> None:
    """Close the comm buffer and then the disk buffer.

    The first caller flips ``self.closed`` and performs the shutdown; any
    later caller only waits for ``self._closed_event``.
    """
    already_closing = self.closed
    if already_closing:  # pragma: no cover
        await self._closed_event.wait()
        return

    self.closed = True
    await self._comm_buffer.close()
    await self._disk_buffer.close()
    # Wake up callers that arrived while the buffers were being closed.
    self._closed_event.set()
def fail(self, exception: Exception) -> None:
    """Record ``exception`` on the run unless the run is already closed."""
    if self.closed:
        return
    self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
    self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Hand received shards to ``_receive`` and report the outcome.

    ``bytes`` input is unpickled first. Returns ``{"status": "OK"}`` on
    success; a ``P2PConsistencyError`` is converted with ``error_message``
    instead of being raised.
    """
    try:
        if isinstance(data, bytes):
            # Opaque blob, see send().
            # NOTE(review): pickle.loads runs arbitrary code on untrusted
            # input; presumably the payload only ever comes from a peer
            # worker's send() -- confirm.
            shards = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
        else:
            shards = data
        await self._receive(shards)
    except P2PConsistencyError as exc:
        return error_message(exc)
    return {"status": "OK"}
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
    self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
    """Shard an input partition and write the shards to the comm buffer.

    Returns this run's ``run_id``. Raises ``RuntimeError`` once the transfer
    phase has been marked as done, and whatever ``raise_if_closed`` raises
    when the run is closed.
    """
    self.raise_if_closed()
    if self.transferred:
        raise RuntimeError(f"Cannot add more partitions to {self}")
    self.validate_data(data)
    # Metrics are logged both in the "execute" and in the "p2p" contexts
    with self._capture_metrics("foreground"):
        with (
            context_meter.meter("p2p-shard-partition-noncpu"),
            context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
        ):
            sharded = self._shard_partition(data, partition_id)
        # The write runs on the run's event loop; sync() is called from here.
        sync(self._loop, self._write_to_comm, sharded)
    return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""
def get_output_partition(
    self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
    """Return output partition ``partition_id`` for task ``key``.

    First makes sure this worker is the one assigned to the partition, then
    flushes the disk buffer before delegating to ``_get_output_partition``.
    Raises ``RuntimeError`` if the transfer phase has not been marked done.
    """
    self.raise_if_closed()
    sync(self._loop, self._ensure_output_worker, partition_id, key)
    if not self.transferred:
        raise RuntimeError("`get_output_partition` called before barrier task")
    sync(self._loop, self.flush_receive)
    with (
        # Metrics are logged both in the "execute" and in the "p2p" contexts
        self._capture_metrics("foreground"),
        context_meter.meter("p2p-get-output-noncpu"),
        context_meter.meter("p2p-get-output-cpu", func=thread_time),
    ):
        partition = self._get_output_partition(partition_id, key, **kwargs)
        return partition
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
    """Read the shards stored at ``path`` on disk."""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
    """Turn a serialized buffer back into shards."""
def validate_data(self, data: Any) -> None:
    """Validate payload data before it is shuffled.

    This implementation accepts everything.
    """
    return None
def get_worker_plugin() -> ShuffleWorkerPlugin:
    """Return the "shuffle" plugin of the worker this code is running on.

    Raises ``RuntimeError`` if ``get_worker`` fails with ``ValueError`` or
    if the worker has no plugin registered under "shuffle".
    """
    from distributed import get_worker

    try:
        current_worker = get_worker()
    except ValueError as err:
        raise RuntimeError(
            "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
            "please confirm that you've created a distributed Client and are submitting this computation through it."
        ) from err

    plugins = current_worker.plugins
    try:
        return plugins["shuffle"]  # type: ignore
    except KeyError as err:
        raise RuntimeError(
            f"The worker {current_worker.address} does not have a P2P shuffle plugin."
        ) from err
# Prefix shared by barrier_key() and id_from_key().
_BARRIER_PREFIX = "shuffle-barrier-"


def barrier_key(shuffle_id: ShuffleId) -> str:
    """Return the barrier key for ``shuffle_id``: the prefix followed by the id."""
    key = _BARRIER_PREFIX + shuffle_id
    return key
def id_from_key(key: Key) -> ShuffleId | None:
    """Extract the shuffle id from a barrier key.

    Returns ``None`` when ``key`` is not a string starting with the barrier
    prefix.
    """
    if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    return None
class ShuffleType(Enum):
    """The kinds of shuffle, identified by a string value."""

    DATAFRAME = "DataFrameShuffle"
    ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
    """Immutable description of one run of a shuffle."""

    # Not an __init__ argument: every new instance takes the next value of a
    # counter that starts at 1 and is shared by all instances.
    run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
    spec: ShuffleSpec
    # Maps a partition id to a worker address string.
    worker_for: dict[_T_partition_id, str]
    span_id: str | None

    @property
    def id(self) -> ShuffleId:
        """Id of the shuffle, taken from ``spec``."""
        shuffle_id = self.spec.id
        return shuffle_id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
    """Abstract, immutable specification of a shuffle."""

    id: ShuffleId
    # Passed to the run on the worker, where it selects disk or memory buffering.
    disk: bool

    @property
    @abc.abstractmethod
    def output_partitions(self) -> Generator[_T_partition_id]:
        """Yield the ids of the output partitions."""

    @abc.abstractmethod
    def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
        """Choose the worker from ``workers`` that handles ``partition``."""

    def create_new_run(
        self,
        worker_for: dict[_T_partition_id, str],
        span_id: str | None,
    ) -> SchedulerShuffleState:
        """Create the scheduler-side state for a new run of this spec.

        Every worker that appears as a value of ``worker_for`` starts out as
        a participating worker.
        """
        new_run_spec = ShuffleRunSpec(
            spec=self, worker_for=worker_for, span_id=span_id
        )
        participants = set(worker_for.values())
        return SchedulerShuffleState(
            run_spec=new_run_spec,
            participating_workers=participants,
        )

    @abc.abstractmethod
    def create_run_on_worker(
        self,
        run_id: int,
        span_id: str | None,
        worker_for: dict[_T_partition_id, str],
        plugin: ShuffleWorkerPlugin,
    ) -> ShuffleRun:
        """Build the worker-side ``ShuffleRun`` for the given run."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
    """Scheduler-side state of one shuffle run.

    Instances compare by identity (``eq=False``) and hash by ``run_id``.
    """

    run_spec: ShuffleRunSpec
    participating_workers: set[str]
    # ``None`` until the run is archived; see ``archived``.
    _archived_by: str | None = field(default=None, init=False)
    _failed: bool = False

    @property
    def id(self) -> ShuffleId:
        """Id of the shuffle, taken from ``run_spec``."""
        return self.run_spec.id

    @property
    def run_id(self) -> int:
        """Id of the run, taken from ``run_spec``."""
        return self.run_spec.run_id

    @property
    def archived(self) -> bool:
        """Whether ``_archived_by`` has been set."""
        return not (self._archived_by is None)

    def __str__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}<{self.id}[{self.run_id}]>"

    def __hash__(self) -> int:
        return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P 9f169a5772161f491fce47676e173277 failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x10-chunks10] (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 9f169a5772161f491fce47676e173277 failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
    """Register a pickled ``ShuffleWorkerPlugin`` under the name "shuffle".

    Registration goes through ``self.scheduler``; the ``scheduler`` argument
    is not used.
    """
    pickled_plugin = dumps(ShuffleWorkerPlugin())
    await self.scheduler.register_worker_plugin(
        None, pickled_plugin, name="shuffle", idempotent=False
    )
def shuffle_ids(self) -> set[ShuffleId]:
    """Return the ids of all currently active shuffles as a new set."""
    return {*self.active_shuffles}
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
    """Broadcast ``shuffle_inputs_done`` to all workers participating in a shuffle.

    Parameters
    ----------
    id
        Id of the active shuffle.
    run_id
        Run the caller refers to; must match the active run.
    consistent
        If false, the shuffle is restarted instead of broadcasting.

    Raises
    ------
    KeyError
        If there is no active shuffle with this ``id``.
    ValueError
        If ``run_id`` does not match the active run.
    P2PIllegalStateError
        If a worker that must be retried is unknown to the scheduler while
        the shuffle is not archived.
    RuntimeError
        If a worker returned an error other than ``OSError``, if a worker
        that must be retried has left, or if the number of pending workers
        did not shrink in three rounds (the counter is never reset).
    """
    shuffle = self.active_shuffles[id]
    if shuffle.run_id != run_id:
        raise ValueError(f"{run_id=} does not match {shuffle}")
    if not consistent:
        logger.warning(
            "Shuffle %s restarted due to data inconsistency during barrier",
            shuffle.id,
        )
        return self._restart_shuffle(
            shuffle.id,
            self.scheduler,
            stimulus_id=f"p2p-barrier-inconsistent-{time()}",
        )
    msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
    workers = list(shuffle.participating_workers)
    no_progress = 0
    while workers:
        res = await self.scheduler.broadcast(
            msg=msg,
            workers=workers,
            on_error="return",
        )
        before = len(workers)
        # Collect the workers to retry. A ``None`` result means success, an
        # OSError is retried and anything else aborts the barrier, so after
        # this loop ``workers`` holds exactly the workers with a non-None
        # result; no second pass over ``res`` is needed.
        workers = []
        for w, r in res.items():
            if r is None:
                continue
            if isinstance(r, OSError):
                workers.append(w)
            else:
                raise RuntimeError(
                    f"Unexpected error encountered during P2P barrier: {r!r}"
                )
        if workers:
            logger.warning(
                "Failure during broadcast of %s, retrying.",
                shuffle.id,
            )
            if any(w not in self.scheduler.workers for w in workers):
                if not shuffle.archived:
                    # If the shuffle is not yet archived, this could mean that the barrier task fails
                    # before the P2P restarting mechanism can kick in.
                    raise P2PIllegalStateError(
                        "Expected shuffle to be archived if participating worker is not known by scheduler"
                    )
                raise RuntimeError(
                    f"Worker {workers} left during shuffle {shuffle}"
                )
            await asyncio.sleep(0.1)
            if len(workers) == before:
                no_progress += 1
                if no_progress >= 3:
                    raise RuntimeError(
                        f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
                    )
def restrict_task(
    self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
    """Restrict task ``key`` to ``worker`` if ``run_id`` is the active run.

    Returns ``{"status": "OK"}`` on success. A ``run_id`` that is older or
    newer than the active run is answered with an error message built from
    a ``P2PConsistencyError``.
    """
    try:
        shuffle = self.active_shuffles[id]
        active_run_id = shuffle.run_id
        if active_run_id > run_id:
            raise P2PConsistencyError(
                f"Request stale, expected {run_id=} for {shuffle}"
            )
        if active_run_id < run_id:
            raise P2PConsistencyError(
                f"Request invalid, expected {run_id=} for {shuffle}"
            )
        self._set_restriction(self.scheduler.tasks[key], worker)
    except P2PConsistencyError as exc:
        return error_message(exc)
    return {"status": "OK"}
def heartbeat(self, ws: WorkerState, data: dict) -> None:
    """Merge per-shuffle heartbeat data sent by worker ``ws``.

    ``data`` maps shuffle ids to dicts. Entries for shuffles that are not
    active are ignored; the rest are merged into
    ``self.heartbeats[shuffle_id][ws.address]``.
    """
    # shuffle_ids() builds a new set on every call, so take it once per
    # heartbeat instead of once per entry.
    active_ids = self.shuffle_ids()
    for shuffle_id, d in data.items():
        if shuffle_id in active_ids:
            self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
    """Return the run spec of active shuffle ``id``, wrapped in ``ToPickle``.

    A missing shuffle (``KeyError``) is turned into a
    ``P2PConsistencyError``; every ``P2PConsistencyError`` is answered with
    ``error_message`` instead of being raised.
    """
    try:
        try:
            spec = self._get(id, worker)
            return {"status": "OK", "run_spec": ToPickle(spec)}
        except KeyError as missing:
            raise P2PConsistencyError(
                f"No active shuffle with {id=!r} found"
            ) from missing
    except P2PConsistencyError as exc:
        return error_message(exc)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
    """Return the ``ShuffleSpec`` carried by the barrier task of ``shuffle_id``.

    Raises
    ------
    KeyError
        If the scheduler has no task under the barrier key.
    AssertionError
        If the barrier task's ``run_spec`` is not a ``P2PBarrierTask``.
    """
    key = barrier_key(shuffle_id)
    barrier_task_spec = self.scheduler.tasks[key].run_spec
    # A bare assert surfaces as a message-less AssertionError that is hard
    # to diagnose once it has travelled through the RPC layer; name the key
    # and the type that was actually found.
    assert isinstance(barrier_task_spec, P2PBarrierTask), (
        f"Expected the run_spec of barrier task {key!r} to be a "
        f"P2PBarrierTask, got {type(barrier_task_spec).__name__}"
    )
    return barrier_task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
    """Create a new run for ``shuffle_id``, make it the active one and return its spec.

    ``worker`` is added to the participating workers of the new run. The
    span id is taken from the group of task ``key``.
    """
    # FIXME: The current implementation relies on the barrier task to be
    # known by its name. If the name has been mangled, we cannot guarantee
    # that the shuffle works as intended and should fail instead.
    self._raise_if_barrier_unknown(shuffle_id)
    self._raise_if_task_not_processing(key)

    spec = self._retrieve_spec(shuffle_id)
    worker_for = self._calculate_worker_for(spec)
    self._ensure_output_tasks_are_non_rootish(spec)

    new_state = spec.create_new_run(
        worker_for=worker_for,
        span_id=self.scheduler.tasks[key].group.span_id,
    )
    self.active_shuffles[shuffle_id] = new_state
    self._shuffles[shuffle_id].add(new_state)
    new_state.participating_workers.add(worker)

    logger.warning(
        "Shuffle %s initialized by task %r executed on worker %s",
        shuffle_id,
        key,
        worker,
    )
    return new_state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: '9f169a5772161f491fce47676e173277'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and…/_worker_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
state = self.active_shuffles[id]
state.participating_workers.add(worker)
return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:33929', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(10, 10), dtype=float64, chunksize=(10, 2), chunktype=numpy.ndarray>
chunks = {1: 5}
ws = (<Worker 'tcp://127.0.0.1:35765', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:34181', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.12/site-packages/dask/dataframe/__init__.py'>
@pytest.mark.parametrize(
"x, chunks",
[
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
(da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
(da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
],
)
@gen_cluster(client=True)
async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
"""
See Also
--------
dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
"""
pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
y = dd.from_array(x).values
result = y.rechunk(chunks, method="p2p")
expected = x.rechunk(chunks, method="p2p")
assert_chunks_match(result.chunks, expected.chunks)
> assert_eq(await c.compute(result), await c.compute(expected))
distributed/shuffle/tests/test_rechunk.py:757:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
def __hash__(self) -> int:
return self.run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
"""Capture context_meter metrics as
{('p2p', <span id>, 'foreground|background...', label, unit): value}
**Note 1:** When the metric is not logged by a background task
(where='foreground'), this produces a duplicated metric under
{('execute', <span id>, <task prefix>, label, unit): value}
This is by design so that one can have a holistic view of the whole shuffle
process.
**Note 2:** We're immediately writing to Worker.digests.
We don't temporarily store metrics under ShuffleRun as we would lose those
recorded between the heartbeat and when the ShuffleRun object is deleted at the
end of a run.
"""
def callback(label: Hashable, value: float, unit: str) -> None:
if not isinstance(label, tuple):
label = (label,)
if isinstance(label[0], str) and label[0].startswith("p2p-"):
label = (label[0][len("p2p-") :], *label[1:])
name = ("p2p", self.span_id, where, *label, unit)
self.digest_metric(name, value)
with context_meter.add_callback(callback, allow_offload="background" in where):
yield
async def barrier(self, run_ids: Sequence[int]) -> int:
self.raise_if_closed()
consistent = all(run_id == self.run_id for run_id in run_ids)
# TODO: Consider broadcast pinging once when the shuffle starts to warm
# up the comm pool on scheduler side
await self.scheduler.shuffle_barrier(
id=self.id, run_id=self.run_id, consistent=consistent
)
return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
    """Mark the transfer as done and flush the outgoing comm buffer.

    Sets ``self.transferred`` and awaits ``_flush_comm``. An exception
    reported by the comm buffer afterwards is stored in ``self._exception``
    and re-raised.
    """
    self.raise_if_closed()
    self.transferred = True
    await self._flush_comm()
    try:
        self._comm_buffer.raise_on_exception()
    except Exception as error:
        # Remember the failure so later raise_if_closed() calls can report it.
        self._exception = error
        raise
async def _flush_comm(self) -> None:
    """Await a flush of the comm buffer, after checking the run is open."""
    self.raise_if_closed()
    outgoing = self._comm_buffer
    await outgoing.flush()
async def flush_receive(self) -> None:
    """Await a flush of the disk buffer, after checking the run is open."""
    self.raise_if_closed()
    incoming = self._disk_buffer
    await incoming.flush()
async def close(self) -> None:
    """Close both shard buffers and mark the run as closed.

    The first caller sets ``self.closed`` and closes the comm buffer and
    then the disk buffer. Any caller that finds the run already closed only
    waits for ``self._closed_event``.

    The disk buffer is closed and the event is set even when closing a
    buffer raises, so a failing close neither leaks the disk buffer nor
    leaves concurrent callers waiting forever. The exception still
    propagates to the first caller.
    """
    if self.closed:  # pragma: no cover
        await self._closed_event.wait()
        return

    self.closed = True
    try:
        try:
            await self._comm_buffer.close()
        finally:
            await self._disk_buffer.close()
    finally:
        self._closed_event.set()
def fail(self, exception: Exception) -> None:
    """Record ``exception`` on the run unless the run is already closed."""
    if self.closed:
        return
    self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
    """Read what the disk buffer stores under the ``"_"``-joined index ``id``."""
    self.raise_if_closed()
    buffer_key = "_".join(map(str, id))
    return self._disk_buffer.read(buffer_key)
async def receive(
    self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Accept a batch of shards and pass it to ``_receive``.

    ``data`` is either the list of ``(partition id, shard)`` pairs or the
    pickled blob produced by ``send()``. Returns ``{"status": "OK"}`` on
    success. A ``P2PConsistencyError`` is converted into an error message
    instead of being raised.
    """
    try:
        if isinstance(data, bytes):
            # Unpack opaque blob. See send()
            # NOTE(review): pickle.loads executes arbitrary code; this
            # presumably relies on the sender being a trusted peer - confirm.
            shards = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
        else:
            shards = data
        await self._receive(shards)
    except P2PConsistencyError as inconsistency:
        return error_message(inconsistency)
    return {"status": "OK"}
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
    """Make sure output partition ``i`` is handled by its assigned worker.

    Returns immediately when the assigned worker is this worker. Otherwise
    the scheduler is asked to restrict ``key`` to the assigned worker and
    ``Reschedule`` is raised.

    Raises ``RuntimeError`` with the scheduler's message when the
    restriction request is answered with status ``"error"``.
    """
    target = self._get_assigned_worker(i)
    if target == self.local_address:
        return

    reply = await self.scheduler.shuffle_restrict_task(
        id=self.id, run_id=self.run_id, key=key, worker=target
    )
    if reply["status"] == "error":
        raise RuntimeError(reply["message"])
    assert reply["status"] == "OK"
    raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
    """Return the address of the worker assigned to output partition ``i``.

    Abstract: concrete shuffle runs must implement this.
    """
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
    """Receive shards that belong to output partitions of this shuffle run.

    Abstract: concrete shuffle runs must implement this.
    """
def add_partition(
    self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
    """Shard one input partition and queue the shards on the comm buffer.

    Returns ``self.run_id``.

    Raises ``RuntimeError`` when ``self.transferred`` is already set, and
    whatever ``raise_if_closed`` or ``validate_data`` raise.
    """
    self.raise_if_closed()
    if self.transferred:
        raise RuntimeError(f"Cannot add more partitions to {self}")
    self.validate_data(data)

    # Log metrics both in the "execute" and in the "p2p" contexts
    with self._capture_metrics("foreground"):
        with (
            context_meter.meter("p2p-shard-partition-noncpu"),
            context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
        ):
            sharded = self._shard_partition(data, partition_id)
        # Blocks until the event loop has accepted the shards.
        sync(self._loop, self._write_to_comm, sharded)

    return self.run_id
@abc.abstractmethod
def _shard_partition(
    self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
    """Shard an input partition by the assigned output workers.

    Abstract: concrete shuffle runs must implement this.
    """
def get_output_partition(
    self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
    """Return output partition ``partition_id`` of this run.

    Runs ``_ensure_output_worker`` on the event loop, requires
    ``self.transferred`` to be set, flushes the receiving side and then
    delegates to ``_get_output_partition`` under the foreground meters.

    Raises ``RuntimeError`` when called before ``self.transferred`` is set.
    """
    self.raise_if_closed()
    sync(self._loop, self._ensure_output_worker, partition_id, key)
    if not self.transferred:
        raise RuntimeError("`get_output_partition` called before barrier task")
    sync(self._loop, self.flush_receive)

    # Log metrics both in the "execute" and in the "p2p" contexts
    with (
        self._capture_metrics("foreground"),
        context_meter.meter("p2p-get-output-noncpu"),
        context_meter.meter("p2p-get-output-cpu", func=thread_time),
    ):
        partition = self._get_output_partition(partition_id, key, **kwargs)
    return partition
@abc.abstractmethod
def _get_output_partition(
    self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
    """Get an output partition of the shuffle run.

    Abstract: concrete shuffle runs must implement this.
    """
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
    """Read shards from disk at ``path``.

    Abstract: concrete shuffle runs must implement this.
    """
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
    """Deserialize shards from ``buffer``.

    Abstract: concrete shuffle runs must implement this.
    """
def validate_data(self, data: Any) -> None:
    """Validate payload data before shuffling.

    This implementation accepts everything and returns ``None``.
    """
def get_worker_plugin() -> ShuffleWorkerPlugin:
    """Return the plugin registered as ``"shuffle"`` on the current worker.

    Raises ``RuntimeError`` when ``get_worker()`` fails with ``ValueError``,
    or when the worker has no plugin under the ``"shuffle"`` key.
    """
    from distributed import get_worker

    try:
        worker = get_worker()
    except ValueError as not_on_worker:
        raise RuntimeError(
            "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
            "please confirm that you've created a distributed Client and are submitting this computation through it."
        ) from not_on_worker

    try:
        return worker.plugins["shuffle"]  # type: ignore
    except KeyError as missing_plugin:
        raise RuntimeError(
            f"The worker {worker.address} does not have a P2P shuffle plugin."
        ) from missing_plugin
# Prefix shared by every barrier key built by barrier_key().
_BARRIER_PREFIX = "shuffle-barrier-"


def barrier_key(shuffle_id: ShuffleId) -> str:
    """Return the barrier key for ``shuffle_id`` (prefix + id)."""
    return _BARRIER_PREFIX + shuffle_id


def id_from_key(key: Key) -> ShuffleId | None:
    """Extract the shuffle id from a barrier key.

    Returns ``None`` when ``key`` is not a string or does not start with
    the barrier prefix.
    """
    if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
        return ShuffleId(key.removeprefix(_BARRIER_PREFIX))
    return None
class ShuffleType(Enum):
    """The kinds of P2P operation named in this module."""

    # Shuffle of a dataframe
    DATAFRAME = "DataFrameShuffle"
    # Rechunk of an array
    ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
    """Frozen record pairing a ``ShuffleSpec`` with a run id, a
    partition-to-worker mapping and a span id.
    """

    # Not a constructor argument: each new instance takes the next value of
    # one counter shared by all instances, starting at 1.
    run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
    spec: ShuffleSpec
    # Keys are partition ids, values are strings (presumably worker
    # addresses - confirm against the callers).
    worker_for: dict[_T_partition_id, str]
    span_id: str | None

    @property
    def id(self) -> ShuffleId:
        """The id of the underlying ``spec``."""
        return self.spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
    """Abstract, frozen specification of a shuffle, identified by ``id``."""

    id: ShuffleId
    # NOTE(review): presumably selects disk-backed vs. in-memory buffering
    # on the workers - confirm against the concrete specs.
    disk: bool

    @property
    @abc.abstractmethod
    def output_partitions(self) -> Generator[_T_partition_id]:
        """Output partitions"""

    @abc.abstractmethod
    def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
        """Pick a worker for a partition"""

    def create_new_run(
        self,
        worker_for: dict[_T_partition_id, str],
        span_id: str | None,
    ) -> SchedulerShuffleState:
        """Build the scheduler-side state for a new run of this spec.

        The participating workers start out as the set of values of
        ``worker_for``.
        """
        run_spec = ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id)
        participants = set(worker_for.values())
        return SchedulerShuffleState(
            run_spec=run_spec, participating_workers=participants
        )

    @abc.abstractmethod
    def create_run_on_worker(
        self,
        run_id: int,
        span_id: str | None,
        worker_for: dict[_T_partition_id, str],
        plugin: ShuffleWorkerPlugin,
    ) -> ShuffleRun:
        """Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
    """Mutable bookkeeping for one shuffle run.

    Instances compare by identity (``eq=False``) and hash by ``run_id``.
    """

    run_spec: ShuffleRunSpec
    participating_workers: set[str]
    # None until set; ``archived`` is True once this holds a string.
    _archived_by: str | None = field(default=None, init=False)
    _failed: bool = False

    @property
    def id(self) -> ShuffleId:
        """Id of the shuffle, taken from ``run_spec``."""
        return self.run_spec.id

    @property
    def run_id(self) -> int:
        """Run id, taken from ``run_spec``."""
        return self.run_spec.run_id

    @property
    def archived(self) -> bool:
        """Whether ``_archived_by`` has been set."""
        return self._archived_by is not None

    def __str__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}<{self.id}[{self.run_id}]>"

    def __hash__(self) -> int:
        return hash(self.run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P 9f169a5772161f491fce47676e173277 failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError
Check warning on line 0 in distributed.shuffle.tests.test_rechunk
github-actions / Unit Test Results
10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x11-chunks11] (distributed.shuffle.tests.test_rechunk)
artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 9f169a5772161f491fce47676e173277 failed during transfer phase
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
    """Attach the plugin to ``scheduler``.

    Registers the four ``shuffle_*`` handlers on the scheduler, adds this
    object as the scheduler plugin named ``"shuffle"`` and creates the
    empty bookkeeping containers.
    """
    self.scheduler = scheduler
    self.scheduler.handlers.update(
        shuffle_barrier=self.barrier,
        shuffle_get=self.get,
        shuffle_get_or_create=self.get_or_create,
        shuffle_restrict_task=self.restrict_task,
    )
    # shuffle id -> second-level key -> dict of values
    self.heartbeats = defaultdict(lambda: defaultdict(dict))
    self.active_shuffles = {}
    self.scheduler.add_plugin(self, name="shuffle")
    self._shuffles = defaultdict(set)
    self._archived_by_stimulus = defaultdict(set)
    self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
    """Register a new ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

    The plugin is serialized with ``dumps`` and registered through
    ``self.scheduler``; the ``scheduler`` argument itself is not used.
    """
    serialized_plugin = dumps(ShuffleWorkerPlugin())
    await self.scheduler.register_worker_plugin(
        None, serialized_plugin, name="shuffle", idempotent=False
    )
def shuffle_ids(self) -> set[ShuffleId]:
    """Return a new set holding the ids of all active shuffles."""
    return set(self.active_shuffles.keys())
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
# Handler for the "shuffle_barrier" operation of shuffle ``id``.
# - Raises ValueError when ``run_id`` is not the run id of the active shuffle
#   (an ``id`` that is not active fails on the dict lookup below).
# - When ``consistent`` is false, logs a warning and restarts the shuffle.
# - Otherwise broadcasts "shuffle_inputs_done" to the participating workers and
#   retries the workers whose reply was an OSError.
# Raises RuntimeError for any other non-None reply, when a worker that still
# needs a retry is unknown to the scheduler, or when three retries made no
# progress; raises P2PIllegalStateError when such a worker is unknown while the
# shuffle is not archived.
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
# NOTE(review): returns the result of _restart_shuffle although the signature
# says None - presumably that result is None; confirm.
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
# Workers that still have to acknowledge the message.
workers = list(shuffle.participating_workers)
# Number of rounds in which the retry list did not shrink.
no_progress = 0
while workers:
# on_error="return": failures come back as values in ``res`` instead of raising.
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
# A None reply counts as success; an OSError reply is retried; anything else aborts.
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
# NOTE(review): looks redundant - the loop above already collected every
# non-None reply (all of them OSError) or raised.
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
# A worker that needs a retry but is no longer known to the scheduler cannot
# be retried.
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
# Short pause before the next broadcast round.
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
    self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
    """Restrict the task ``key`` to ``worker`` for the active run of ``id``.

    Returns ``{"status": "OK"}`` on success. When ``run_id`` differs from
    the active run id, the resulting ``P2PConsistencyError`` is returned as
    an error message instead of being raised. A ``KeyError`` for an unknown
    shuffle id or task key is not caught here.
    """
    try:
        state = self.active_shuffles[id]
        if state.run_id > run_id:
            raise P2PConsistencyError(
                f"Request stale, expected {run_id=} for {state}"
            )
        if state.run_id < run_id:
            raise P2PConsistencyError(
                f"Request invalid, expected {run_id=} for {state}"
            )
        task_state = self.scheduler.tasks[key]
        self._set_restriction(task_state, worker)
    except P2PConsistencyError as inconsistency:
        return error_message(inconsistency)
    return {"status": "OK"}
def heartbeat(self, ws: WorkerState, data: dict) -> None:
    """Merge one worker's heartbeat payload into ``self.heartbeats``.

    ``data`` maps shuffle ids to dicts. Entries for shuffles that are not
    active are ignored; the others update
    ``self.heartbeats[shuffle_id][ws.address]``.
    """
    # shuffle_ids() builds a new set on every call, so take one snapshot
    # instead of rebuilding it for each entry of ``data``.
    active = self.shuffle_ids()
    address = ws.address
    for shuffle_id, payload in data.items():
        if shuffle_id in active:
            self.heartbeats[shuffle_id][address].update(payload)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
    """Return the run spec of the active shuffle ``id``, wrapped in ``ToPickle``.

    A ``KeyError`` raised while doing so is turned into a
    ``P2PConsistencyError``, and every ``P2PConsistencyError`` is returned
    as an error message instead of being raised.
    """
    try:
        try:
            spec = self._get(id, worker)
            return {"status": "OK", "run_spec": ToPickle(spec)}
        except KeyError as missing:
            raise P2PConsistencyError(
                f"No active shuffle with {id=!r} found"
            ) from missing
    except P2PConsistencyError as inconsistency:
        return error_message(inconsistency)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
    """Return the run spec of the active shuffle ``id``.

    Also records ``worker`` as a participating worker of that shuffle.

    Raises ``P2PConsistencyError`` when the scheduler does not know
    ``worker`` and ``KeyError`` when no shuffle with this id is active.
    """
    known_workers = self.scheduler.workers
    if worker not in known_workers:
        # This should never happen
        raise P2PConsistencyError(
            f"Scheduler is unaware of this worker {worker!r}"
        )  # pragma: nocover
    state = self.active_shuffles[id]
    state.participating_workers.add(worker)
    return state.run_spec
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
    """Return the ``ShuffleSpec`` carried by the barrier task of ``shuffle_id``.

    The barrier task is looked up in ``self.scheduler.tasks`` by its
    barrier key; its ``run_spec`` must be a ``P2PBarrierTask``.
    """
    barrier_task = self.scheduler.tasks[barrier_key(shuffle_id)]
    task_spec = barrier_task.run_spec
    assert isinstance(task_spec, P2PBarrierTask)
    return task_spec.spec
def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
    """Create a new run for ``shuffle_id`` and make it the active one.

    ``key`` is the task that triggered the creation and ``worker`` the
    worker it ran on; the worker is recorded as a participant. Returns the
    run spec of the new state.
    """
    # FIXME: The current implementation relies on the barrier task to be
    # known by its name. If the name has been mangled, we cannot guarantee
    # that the shuffle works as intended and should fail instead.
    self._raise_if_barrier_unknown(shuffle_id)
    self._raise_if_task_not_processing(key)

    spec = self._retrieve_spec(shuffle_id)
    assignment = self._calculate_worker_for(spec)
    self._ensure_output_tasks_are_non_rootish(spec)

    new_state = spec.create_new_run(
        worker_for=assignment, span_id=self.scheduler.tasks[key].group.span_id
    )
    self.active_shuffles[shuffle_id] = new_state
    self._shuffles[shuffle_id].add(new_state)
    new_state.participating_workers.add(worker)
    logger.warning(
        "Shuffle %s initialized by task %r executed on worker %s",
        shuffle_id,
        key,
        worker,
    )
    return new_state.run_spec
def get_or_create(
self,
shuffle_id: ShuffleId,
key: Key,
worker: str,
) -> RunSpecMessage | ErrorMessage:
try:
> run_spec = self._get(shuffle_id, worker)
distributed/shuffle/_scheduler_plugin.py:229:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.scheduler.handlers.update(
{
"shuffle_barrier": self.barrier,
"shuffle_get": self.get,
"shuffle_get_or_create": self.get_or_create,
"shuffle_restrict_task": self.restrict_task,
}
)
self.heartbeats = defaultdict(lambda: defaultdict(dict))
self.active_shuffles = {}
self.scheduler.add_plugin(self, name="shuffle")
self._shuffles = defaultdict(set)
self._archived_by_stimulus = defaultdict(set)
self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
worker_plugin = ShuffleWorkerPlugin()
await self.scheduler.register_worker_plugin(
None, dumps(worker_plugin), name="shuffle", idempotent=False
)
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
no_progress = 0
while workers:
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
try:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
raise P2PConsistencyError(
f"Request stale, expected {run_id=} for {shuffle}"
)
elif shuffle.run_id < run_id:
raise P2PConsistencyError(
f"Request invalid, expected {run_id=} for {shuffle}"
)
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
def heartbeat(self, ws: WorkerState, data: dict) -> None:
for shuffle_id, d in data.items():
if shuffle_id in self.shuffle_ids():
self.heartbeats[shuffle_id][ws.address].update(d)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
try:
try:
run_spec = self._get(id, worker)
return {"status": "OK", "run_spec": ToPickle(run_spec)}
except KeyError as e:
raise P2PConsistencyError(
f"No active shuffle with {id=!r} found"
) from e
except P2PConsistencyError as e:
return error_message(e)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
if worker not in self.scheduler.workers:
# This should never happen
raise P2PConsistencyError(
f"Scheduler is unaware of this worker {worker!r}"
) # pragma: nocover
> state = self.active_shuffles[id]
E KeyError: '9f169a5772161f491fce47676e173277'
distributed/shuffle/_scheduler_plugin.py:190: KeyError
During handling of the above exception, another exception occurred:
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
    """``OKMessage`` that additionally carries a run spec."""

    # Either the spec itself or the spec wrapped in ToPickle.
    run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
# Stores the identifying fields and collaborators, builds the receiving
# ("disk") buffer and the sending ("comm") buffer, and reads the retry
# settings from the dask config.
# ``disk`` selects a DiskShardsBuffer (rooted at ``directory``) over a
# MemoryShardsBuffer for the receiving side.
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
# Message size limit and concurrency of the comm buffer come from the dask config.
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
# Set to True by inputs_done().
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
# Exception recorded by fail() / inputs_done(); re-raised by raise_if_closed().
self._exception = None
# Set once close() has finished; later close() callers wait on it.
self._closed_event = asyncio.Event()
self._loop = loop
# Retry settings handed to retry() in send(); the delays are parsed with
# seconds as the default unit.
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
    """Debug representation listing id, run id, address and state flags."""
    shown = ("id", "run_id", "local_address", "closed", "transferred")
    details = ", ".join(f"{name}={getattr(self, name)!r}" for name in shown)
    return f"<{self.__class__.__name__}: {details}>"
def __str__(self) -> str:
    """Short form: ``<class name><<id>[<run id>]> on <local address>``."""
    return "{}<{}[{}]> on {}".format(
        self.__class__.__name__, self.id, self.run_id, self.local_address
    )
def __hash__(self) -> int:
    """Hash a run by its ``run_id``."""
    run_id = self.run_id
    return run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
    """Capture context_meter metrics as

        {('p2p', <span id>, 'foreground|background...', label, unit): value}

    **Note 1:** When the metric is not logged by a background task
    (where='foreground'), this produces a duplicated metric under

        {('execute', <span id>, <task prefix>, label, unit): value}

    This is by design so that one can have a holistic view of the whole shuffle
    process.

    **Note 2:** We're immediately writing to Worker.digests.
    We don't temporarily store metrics under ShuffleRun as we would lose those
    recorded between the heartbeat and when the ShuffleRun object is deleted at the
    end of a run.
    """

    def callback(label: Hashable, value: float, unit: str) -> None:
        # Normalise the label to a tuple and drop a leading "p2p-" from its
        # first element, since the metric name already starts with "p2p".
        parts = label if isinstance(label, tuple) else (label,)
        first = parts[0]
        if isinstance(first, str) and first.startswith("p2p-"):
            parts = (first[len("p2p-") :], *parts[1:])
        self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

    offload_allowed = "background" in where
    with context_meter.add_callback(callback, allow_offload=offload_allowed):
        yield
async def barrier(self, run_ids: Sequence[int]) -> int:
    """Report to the scheduler that the barrier of this run was reached.

    ``consistent`` is true when every entry of ``run_ids`` equals this
    run's id (also for an empty sequence). Returns ``self.run_id``.
    """
    self.raise_if_closed()
    own_run_id = self.run_id
    consistent = not any(other != own_run_id for other in run_ids)
    # TODO: Consider broadcast pinging once when the shuffle starts to warm
    # up the comm pool on scheduler side
    await self.scheduler.shuffle_barrier(
        id=self.id, run_id=self.run_id, consistent=consistent
    )
    return self.run_id
async def _send(
self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
self.raise_if_closed()
return await self.rpc(address).shuffle_receive(
data=to_serialize(shards),
shuffle_id=self.id,
run_id=self.run_id,
)
async def send(
self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
if _mean_shard_size(shards) < 65536:
# Don't send buffers individually over the tcp comms.
# Instead, merge everything into an opaque bytes blob, send it all at once,
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
else:
shards_or_bytes = shards
def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
return self._send(address, shards_or_bytes)
return await retry(
_send,
count=self.RETRY_COUNT,
delay_min=self.RETRY_DELAY_MIN,
delay_max=self.RETRY_DELAY_MAX,
)
async def offload(
self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
self.raise_if_closed()
with context_meter.meter("offload"):
return await run_in_executor_with_context(
self.executor, func, *args, **kwargs
)
def heartbeat(self) -> dict[str, Any]:
comm_heartbeat = self._comm_buffer.heartbeat()
comm_heartbeat["read"] = self.total_recvd
return {
"disk": self._disk_buffer.heartbeat(),
"comm": comm_heartbeat,
"start": self.start_time,
}
async def _write_to_comm(
self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
self.raise_if_closed()
await self._comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
self.raise_if_closed()
await self._disk_buffer.write(
{"_".join(str(i) for i in k): v for k, v in data.items()}
)
def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
self.raise_if_closed()
self.transferred = True
await self._flush_comm()
try:
self._comm_buffer.raise_on_exception()
except Exception as e:
self._exception = e
raise
async def _flush_comm(self) -> None:
self.raise_if_closed()
await self._comm_buffer.flush()
async def flush_receive(self) -> None:
self.raise_if_closed()
await self._disk_buffer.flush()
async def close(self) -> None:
if self.closed: # pragma: no cover
await self._closed_event.wait()
return
self.closed = True
await self._comm_buffer.close()
await self._disk_buffer.close()
self._closed_event.set()
def fail(self, exception: Exception) -> None:
if not self.closed:
self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing
self.raise_if_closed()
return self._disk_buffer.read("_".join(str(i) for i in id))
async def receive(
self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
try:
if isinstance(data, bytes):
# Unpack opaque blob. See send()
data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
await self._receive(data)
return {"status": "OK"}
except P2PConsistencyError as e:
return error_message(e)
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
assigned_worker = self._get_assigned_worker(i)
if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
assert result["status"] == "OK"
raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Get the address of the worker assigned to the output partition"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and…_plugin.py:222: in _refresh
result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import asyncio
import contextlib
import itertools
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from dask.typing import Key
from distributed.core import ErrorMessage, OKMessage, error_message
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._core import (
P2PBarrierTask,
RunSpecMessage,
SchedulerShuffleState,
ShuffleId,
ShuffleRunSpec,
ShuffleSpec,
barrier_key,
id_from_key,
)
from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.utils import log_errors
if TYPE_CHECKING:
from distributed.scheduler import (
Recs,
Scheduler,
TaskState,
TaskStateState,
WorkerState,
)
logger = logging.getLogger(__name__)
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""
scheduler: Scheduler
active_shuffles: dict[ShuffleId, SchedulerShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
_shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
_archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
_shift_counter: itertools.count[int]
def __init__(self, scheduler: Scheduler):
    """Attach the plugin to ``scheduler`` and create its empty bookkeeping.

    The four ``shuffle_*`` handlers are merged into ``scheduler.handlers`` and
    the plugin adds itself to the scheduler under the name ``"shuffle"``.
    """
    self.scheduler = scheduler
    rpc_handlers = {
        "shuffle_barrier": self.barrier,
        "shuffle_get": self.get,
        "shuffle_get_or_create": self.get_or_create,
        "shuffle_restrict_task": self.restrict_task,
    }
    self.scheduler.handlers.update(rpc_handlers)
    # shuffle id -> key (a worker address in ``heartbeat``) -> merged payload
    self.heartbeats = defaultdict(lambda: defaultdict(dict))
    self.active_shuffles = {}
    self.scheduler.add_plugin(self, name="shuffle")
    self._shuffles = defaultdict(set)
    self._archived_by_stimulus = defaultdict(set)
    self._shift_counter = itertools.count()
async def start(self, scheduler: Scheduler) -> None:
    """Register a pickled ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

    The ``scheduler`` argument is not used; registration goes through
    ``self.scheduler.register_worker_plugin``.
    """
    pickled_plugin = dumps(ShuffleWorkerPlugin())
    await self.scheduler.register_worker_plugin(
        None, pickled_plugin, name="shuffle", idempotent=False
    )
def shuffle_ids(self) -> set[ShuffleId]:
    """Return the ids of all currently active shuffles as a new set."""
    return {*self.active_shuffles}
async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
"""Handle a ``shuffle_barrier`` request for the active shuffle ``id``.

If ``consistent`` is false the shuffle is restarted through
``self._restart_shuffle`` and nothing is broadcast. Otherwise a
``shuffle_inputs_done`` message is broadcast to the participating workers,
and re-sent to every worker whose reply was an ``OSError`` until none
remain.

Raises
------
ValueError
    ``run_id`` differs from the run id of the active shuffle.
P2PIllegalStateError
    A failing worker is unknown to the scheduler and the shuffle has not
    been archived.
RuntimeError
    A worker replied with a non-``OSError`` error, a failing worker left,
    or the set of failing workers did not shrink in three rounds.
"""
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
if not consistent:
logger.warning(
"Shuffle %s restarted due to data inconsistency during barrier",
shuffle.id,
)
return self._restart_shuffle(
shuffle.id,
self.scheduler,
stimulus_id=f"p2p-barrier-inconsistent-{time()}",
)
msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
workers = list(shuffle.participating_workers)
# Counts rounds in which the number of failing workers stayed the same;
# it is never reset once a later round does make progress.
no_progress = 0
while workers:
# on_error="return": failures come back as values in ``res`` instead of raising
res = await self.scheduler.broadcast(
msg=msg,
workers=workers,
on_error="return",
)
before = len(workers)
# Rebuild the retry list from this round's replies
workers = []
for w, r in res.items():
if r is None:
continue
if isinstance(r, OSError):
workers.append(w)
else:
raise RuntimeError(
f"Unexpected error encountered during P2P barrier: {r!r}"
)
# NOTE(review): this rebuilds the list the loop above just produced (every
# non-None reply was either appended or has already raised); looks redundant.
workers = [w for w, r in res.items() if r is not None]
if workers:
logger.warning(
"Failure during broadcast of %s, retrying.",
shuffle.id,
)
if any(w not in self.scheduler.workers for w in workers):
if not shuffle.archived:
# If the shuffle is not yet archived, this could mean that the barrier task fails
# before the P2P restarting mechanism can kick in.
raise P2PIllegalStateError(
"Expected shuffle to be archived if participating worker is not known by scheduler"
)
raise RuntimeError(
f"Worker {workers} left during shuffle {shuffle}"
)
# Short pause before the next broadcast attempt
await asyncio.sleep(0.1)
if len(workers) == before:
no_progress += 1
if no_progress >= 3:
raise RuntimeError(
f"""Broadcast not making progress for {shuffle}.
Aborting. This is possibly due to overloaded
workers. Increasing config
`distributed.comm.timeouts.connect` timeout may
help."""
)
def restrict_task(
    self, id: ShuffleId, run_id: int, key: Key, worker: str
) -> OKMessage | ErrorMessage:
    """Restrict task ``key`` to ``worker`` for the given run of shuffle ``id``.

    Returns ``{"status": "OK"}`` on success. A run id that does not match the
    active shuffle is reported as an error message instead of being raised.
    A missing shuffle or task key still raises ``KeyError``.
    """
    try:
        state = self.active_shuffles[id]
        active_run_id = state.run_id
        if active_run_id > run_id:
            raise P2PConsistencyError(
                f"Request stale, expected {run_id=} for {state}"
            )
        if active_run_id < run_id:
            raise P2PConsistencyError(
                f"Request invalid, expected {run_id=} for {state}"
            )
        self._set_restriction(self.scheduler.tasks[key], worker)
    except P2PConsistencyError as exc:
        return error_message(exc)
    return {"status": "OK"}
def heartbeat(self, ws: WorkerState, data: dict) -> None:
    """Merge a worker's per-shuffle heartbeat payload into ``self.heartbeats``.

    Parameters
    ----------
    ws
        State of the reporting worker; only ``ws.address`` is read.
    data
        Mapping of shuffle id to the payload dict reported for that shuffle.
        Entries for shuffles that are not currently active are ignored.
    """
    active = self.active_shuffles
    for shuffle_id, payload in data.items():
        # Test membership on the dict itself; ``self.shuffle_ids()`` would
        # build a fresh set for every entry.
        if shuffle_id in active:
            self.heartbeats[shuffle_id][ws.address].update(payload)
def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
    """Return the run spec of active shuffle ``id`` wrapped for pickling.

    A ``KeyError`` from the lookup is converted into a ``P2PConsistencyError``;
    any ``P2PConsistencyError`` is returned as an error message, not raised.
    """
    try:
        try:
            spec = self._get(id, worker)
            reply = {"status": "OK", "run_spec": ToPickle(spec)}
        except KeyError as missing:
            raise P2PConsistencyError(
                f"No active shuffle with {id=!r} found"
            ) from missing
        return reply
    except P2PConsistencyError as inconsistent:
        return error_message(inconsistent)
def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
    """Record ``worker`` as a participant of shuffle ``id`` and return its run spec.

    Raises ``P2PConsistencyError`` if the scheduler does not know ``worker``
    and ``KeyError`` if ``id`` is not an active shuffle.
    """
    if worker in self.scheduler.workers:
        state = self.active_shuffles[id]
        state.participating_workers.add(worker)
        return state.run_spec
    # This should never happen
    raise P2PConsistencyError(
        f"Scheduler is unaware of this worker {worker!r}"
    )  # pragma: nocover
def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
> assert isinstance(barrier_task_spec, P2PBarrierTask)
E AssertionError
distributed/shuffle/_scheduler_plugin.py:196: AssertionError
The above exception was the direct cause of the following exception:
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:39469', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(10, 10), dtype=float64, chunksize=(10, 2), chunktype=numpy.ndarray>
chunks = (None, (5, 5))
ws = (<Worker 'tcp://127.0.0.1:43215', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:41579', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.12/site-packages/dask/dataframe/__init__.py'>
@pytest.mark.parametrize(
"x, chunks",
[
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
(da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
(da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
(da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
(da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
(da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
(da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
],
)
@gen_cluster(client=True)
async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
"""
P2P-rechunk an array round-tripped through ``dd.from_array(x).values`` and
check that its chunks and computed values match a P2P rechunk of ``x``.

Skipped when pandas or ``dask.dataframe`` cannot be imported.

See Also
--------
dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
"""
pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
# Round-trip through a dataframe; presumably this is what makes a dimension
# "unknown" (per the test name) -- confirm against dask.dataframe
y = dd.from_array(x).values
result = y.rechunk(chunks, method="p2p")
expected = x.rechunk(chunks, method="p2p")
assert_chunks_match(result.chunks, expected.chunks)
> assert_eq(await c.compute(result), await c.compute(expected))
distributed/shuffle/tests/test_rechunk.py:757:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
distributed/client.py:410: in _result
raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
from __future__ import annotations
import abc
import asyncio
import contextlib
import itertools
import pickle
import time
from collections.abc import (
Callable,
Coroutine,
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
from tornado.ioloop import IOLoop
import dask.config
from dask._task_spec import Task
from dask.core import flatten
from dask.typing import Key
from dask.utils import parse_bytes, parse_timedelta
from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter, thread_time
from distributed.protocol import to_serialize
from distributed.protocol.serialize import ToPickle
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import (
P2PConsistencyError,
P2POutOfDiskError,
ShuffleClosedError,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import ParamSpec, TypeAlias
_P = ParamSpec("_P")
# circular dependencies
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
ShuffleId = NewType("ShuffleId", str)
NDIndex: TypeAlias = tuple[int, ...]
_T_partition_id = TypeVar("_T_partition_id")
_T_partition_type = TypeVar("_T_partition_type")
_T = TypeVar("_T")
class RunSpecMessage(OKMessage):
    """``OKMessage`` that additionally carries a shuffle run spec."""

    # Either the spec itself or the spec wrapped in ``ToPickle``
    run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
id: ShuffleId
run_id: int
span_id: str | None
local_address: str
executor: ThreadPoolExecutor
rpc: Callable[[str], PooledRPCCall]
digest_metric: Callable[[Hashable, float], None]
scheduler: PooledRPCCall
closed: bool
_disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
_comm_buffer: CommShardsBuffer
received: set[_T_partition_id]
total_recvd: int
start_time: float
_exception: Exception | None
_closed_event: asyncio.Event
_loop: IOLoop
RETRY_COUNT: int
RETRY_DELAY_MIN: float
RETRY_DELAY_MAX: float
def __init__(
self,
id: ShuffleId,
run_id: int,
span_id: str | None,
local_address: str,
directory: str,
executor: ThreadPoolExecutor,
rpc: Callable[[str], PooledRPCCall],
digest_metric: Callable[[Hashable, float], None],
scheduler: PooledRPCCall,
memory_limiter_disk: ResourceLimiter,
memory_limiter_comms: ResourceLimiter,
disk: bool,
loop: IOLoop,
):
"""Store the run's identity and collaborators and create its buffers.

``disk`` selects the receive-side buffer: a ``DiskShardsBuffer`` built with
``directory`` and ``memory_limiter_disk`` when true, otherwise a
``MemoryShardsBuffer``. Outgoing shards use a ``CommShardsBuffer`` built
with ``memory_limiter_comms`` and the ``distributed.p2p.comm.*`` config
values read below. Retry count and delays are also read from config.
"""
self.id = id
self.run_id = run_id
self.span_id = span_id
self.local_address = local_address
self.executor = executor
self.rpc = rpc
self.digest_metric = digest_metric
self.scheduler = scheduler
self.closed = False
# Initialize buffers and start background tasks
# Don't log metrics issued by the background tasks onto the dask task that
# spawned this object
with context_meter.clear_callbacks():
with self._capture_metrics("background-disk"):
if disk:
self._disk_buffer = DiskShardsBuffer(
directory=directory,
read=self.read,
memory_limiter=memory_limiter_disk,
)
else:
self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
with self._capture_metrics("background-comms"):
max_message_size = parse_bytes(
dask.config.get("distributed.p2p.comm.message-bytes-limit")
)
concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
self._comm_buffer = CommShardsBuffer(
send=self.send,
max_message_size=max_message_size,
memory_limiter=memory_limiter_comms,
concurrency_limit=concurrency_limit,
)
# TODO: reduce number of connections to number of workers
# MultiComm.max_connections = min(10, n_workers)
# Set to True by ``inputs_done``; ``add_partition`` refuses input afterwards
self.transferred = False
self.received = set()
self.total_recvd = 0
self.start_time = time.time()
# Exception re-raised by ``raise_if_closed`` once the run is closed
self._exception = None
self._closed_event = asyncio.Event()
self._loop = loop
# Retry settings passed to ``retry`` in ``send``
self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
self.RETRY_DELAY_MIN = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
)
self.RETRY_DELAY_MAX = parse_timedelta(
dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
)
def __repr__(self) -> str:
    """Return a debug representation with id, run id, address and state flags."""
    text = f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    return text
def __str__(self) -> str:
    """Return the short form ``ClassName<id[run_id]> on local_address``."""
    text = f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    return text
def __hash__(self) -> int:
    """Use the run id as the hash of this run."""
    run_id = self.run_id
    return run_id
@contextlib.contextmanager
def _capture_metrics(self, where: str) -> Iterator[None]:
    """Forward ``context_meter`` metrics to ``digest_metric`` while active.

    Metrics are recorded as
        {('p2p', <span id>, 'foreground|background...', label, unit): value}

    **Note 1:** When the metric is not logged by a background task
    (where='foreground'), this produces a duplicated metric under
        {('execute', <span id>, <task prefix>, label, unit): value}
    This is by design so that one can have a holistic view of the whole shuffle
    process.

    **Note 2:** We're immediately writing to Worker.digests.
    We don't temporarily store metrics under ShuffleRun as we would lose those
    recorded between the heartbeat and when the ShuffleRun object is deleted at the
    end of a run.
    """
    prefix = "p2p-"

    def callback(label: Hashable, value: float, unit: str) -> None:
        parts = label if isinstance(label, tuple) else (label,)
        head = parts[0]
        # Drop a leading "p2p-" from the first label component
        if isinstance(head, str) and head.startswith(prefix):
            parts = (head[len(prefix) :], *parts[1:])
        self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

    offloadable = "background" in where
    with context_meter.add_callback(callback, allow_offload=offloadable):
        yield
async def barrier(self, run_ids: Sequence[int]) -> int:
    """Report the end of the input phase to the scheduler and return this run id.

    ``consistent`` is sent as true only if every id in ``run_ids`` equals
    this run's ``run_id``.
    """
    self.raise_if_closed()
    own_run_id = self.run_id
    consistent = not any(other != own_run_id for other in run_ids)
    # TODO: Consider broadcast pinging once when the shuffle starts to warm
    # up the comm pool on scheduler side
    await self.scheduler.shuffle_barrier(
        id=self.id, run_id=self.run_id, consistent=consistent
    )
    return self.run_id
async def _send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Call ``shuffle_receive`` on the worker at ``address`` with ``shards``.

    Raises if this run has been closed.
    """
    self.raise_if_closed()
    peer = self.rpc(address)
    return await peer.shuffle_receive(
        data=to_serialize(shards),
        shuffle_id=self.id,
        run_id=self.run_id,
    )
async def send(
    self, address: str, shards: list[tuple[_T_partition_id, Any]]
) -> OKMessage | ErrorMessage:
    """Send ``shards`` to ``address``, retrying per the configured retry settings.

    When the mean shard size is below 65536 the whole list is pickled into a
    single bytes blob before sending; otherwise the list is sent as is.
    """
    # Don't send buffers individually over the tcp comms.
    # Instead, merge everything into an opaque bytes blob, send it all at once,
    # and unpickle it on the other side.
    # Performance tests informing the size threshold:
    # https://github.com/dask/distributed/pull/8318
    payload: list | bytes = (
        pickle.dumps(shards) if _mean_shard_size(shards) < 65536 else shards
    )
    # Each retry attempt needs a fresh coroutine, hence a factory
    attempt = partial(self._send, address, payload)
    return await retry(
        attempt,
        count=self.RETRY_COUNT,
        delay_min=self.RETRY_DELAY_MIN,
        delay_max=self.RETRY_DELAY_MAX,
    )
async def offload(
    self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    """Run ``func(*args, **kwargs)`` on ``self.executor`` under the ``"offload"`` meter.

    Raises if this run has been closed.
    """
    self.raise_if_closed()
    with context_meter.meter("offload"):
        outcome = await run_in_executor_with_context(
            self.executor, func, *args, **kwargs
        )
    return outcome
def heartbeat(self) -> dict[str, Any]:
    """Return the heartbeat payload of this run.

    The result has the keys ``"disk"`` and ``"comm"`` (the buffers' own
    heartbeats, the latter extended with ``"read"`` = ``total_recvd``) and
    ``"start"`` (``start_time``).
    """
    comm_state = self._comm_buffer.heartbeat()
    comm_state["read"] = self.total_recvd
    disk_state = self._disk_buffer.heartbeat()
    return {"disk": disk_state, "comm": comm_state, "start": self.start_time}
async def _write_to_comm(
    self, data: dict[str, tuple[_T_partition_id, Any]]
) -> None:
    """Write ``data`` to the comm buffer; raises if this run has been closed."""
    self.raise_if_closed()
    comm_buffer = self._comm_buffer
    await comm_buffer.write(data)
async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
    """Write ``data`` to the disk buffer, keyed by the index joined with ``"_"``.

    Raises if this run has been closed.
    """
    self.raise_if_closed()
    keyed = {"_".join(map(str, index)): shard for index, shard in data.items()}
    await self._disk_buffer.write(keyed)
def raise_if_closed(self) -> None:
    """Raise if this run is closed; do nothing otherwise.

    The stored ``_exception`` is re-raised when there is one, else a
    ``ShuffleClosedError`` is raised.
    """
    if not self.closed:
        return
    if self._exception:
        raise self._exception
    raise ShuffleClosedError(f"{self} has already been closed")
async def inputs_done(self) -> None:
    """Mark the input phase as finished and flush the comm buffer.

    An exception surfaced by the comm buffer after flushing is stored in
    ``_exception`` and re-raised.
    """
    self.raise_if_closed()
    self.transferred = True
    await self._flush_comm()
    try:
        self._comm_buffer.raise_on_exception()
    except Exception as exc:
        self._exception = exc
        raise
async def _flush_comm(self) -> None:
    """Flush the comm buffer; raises if this run has been closed."""
    self.raise_if_closed()
    comm_buffer = self._comm_buffer
    await comm_buffer.flush()
async def flush_receive(self) -> None:
    """Flush the disk buffer; raises if this run has been closed."""
    self.raise_if_closed()
    disk_buffer = self._disk_buffer
    await disk_buffer.flush()
async def close(self) -> None:
    """Close both buffers and signal completion through ``_closed_event``.

    A second caller waits for the first close to finish and then returns.
    If closing a buffer raises, the remaining buffer is still closed and the
    event is still set, so waiting callers are released. The exception then
    propagates to the caller that started the close.
    """
    if self.closed:  # pragma: no cover
        await self._closed_event.wait()
        return
    self.closed = True
    try:
        await self._comm_buffer.close()
    finally:
        try:
            await self._disk_buffer.close()
        finally:
            # Always release waiters, even when a buffer failed to close
            self._closed_event.set()
def fail(self, exception: Exception) -> None:
    """Store ``exception`` on the run unless the run is already closed."""
    if self.closed:
        return
    self._exception = exception
def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
    """Read from the disk buffer the entry keyed by ``id`` joined with ``"_"``.

    Raises if this run has been closed.
    """
    self.raise_if_closed()
    buffer_key = "_".join(map(str, id))
    return self._disk_buffer.read(buffer_key)
async def receive(
    self, data: list[tuple[_T_partition_id, Any]] | bytes
) -> OKMessage | ErrorMessage:
    """Hand incoming shards to ``_receive`` and report the outcome as a message.

    A ``bytes`` payload is unpickled first. A ``P2PConsistencyError`` is
    returned as an error message; other exceptions propagate.
    """
    try:
        shards = data
        if isinstance(shards, bytes):
            # Unpack opaque blob. See send()
            # NOTE(review): pickle.loads executes whatever the sender packed;
            # this presumes peers are trusted -- confirm.
            shards = cast(list[tuple[_T_partition_id, Any]], pickle.loads(shards))
        await self._receive(shards)
    except P2PConsistencyError as exc:
        return error_message(exc)
    return {"status": "OK"}
async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
    """Return only if this worker is the one assigned to output partition ``i``.

    Otherwise the scheduler is asked to restrict task ``key`` to the assigned
    worker and ``Reschedule`` is raised. An error reply from the scheduler
    is raised as ``RuntimeError``.
    """
    assigned_worker = self._get_assigned_worker(i)
    if assigned_worker == self.local_address:
        return
    reply = await self.scheduler.shuffle_restrict_task(
        id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
    )
    if reply["status"] == "error":
        raise RuntimeError(reply["message"])
    assert reply["status"] == "OK"
    raise Reschedule()
@abc.abstractmethod
def _get_assigned_worker(self, i: _T_partition_id) -> str:
"""Return the address of the worker assigned to output partition ``i``.

Abstract; implemented by subclasses.
"""
@abc.abstractmethod
async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
"""Receive shards belonging to output partitions of this shuffle run.

Abstract; called by ``receive`` with the already-unpacked list of shards.
"""
def add_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> int:
"""Validate and shard one input partition, then write the shards to the comm buffer.

Returns this run's ``run_id``. Raises if the run has been closed, and
``RuntimeError`` if ``inputs_done`` has already marked the run as
transferred.
"""
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
self.validate_data(data)
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
):
shards = self._shard_partition(data, partition_id)
# Run the async buffer write through ``self._loop`` via ``sync``
sync(self._loop, self._write_to_comm, shards)
return self.run_id
@abc.abstractmethod
def _shard_partition(
self, data: _T_partition_type, partition_id: _T_partition_id
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers.

Abstract; the returned mapping is passed to ``_write_to_comm`` by
``add_partition``.
"""
def get_output_partition(
    self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
    """Return output partition ``partition_id`` after flushing received data.

    Raises if the run has been closed, and ``RuntimeError`` if the run has
    not been marked as transferred yet. ``_ensure_output_worker`` may raise
    ``Reschedule`` when another worker is assigned to the partition.
    """
    self.raise_if_closed()
    sync(self._loop, self._ensure_output_worker, partition_id, key)
    if not self.transferred:
        raise RuntimeError("`get_output_partition` called before barrier task")
    sync(self._loop, self.flush_receive)
    # Log metrics both in the "execute" and in the "p2p" contexts
    with self._capture_metrics("foreground"):
        with context_meter.meter("p2p-get-output-noncpu"):
            with context_meter.meter("p2p-get-output-cpu", func=thread_time):
                return self._get_output_partition(partition_id, key, **kwargs)
@abc.abstractmethod
def _get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
"""Get an output partition to the shuffle run.

Abstract; called by ``get_output_partition`` with the same arguments.
"""
@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk.

Abstract; passed as the ``read`` callback of ``DiskShardsBuffer``.
NOTE(review): the ``int`` in the return type is presumably a byte count --
confirm against ``DiskShardsBuffer``.
"""
@abc.abstractmethod
def deserialize(self, buffer: Any) -> Any:
"""Deserialize shards.

Abstract; passed as the ``deserialize`` callback of ``MemoryShardsBuffer``.
"""
def validate_data(self, data: Any) -> None:
"""Validate payload data before shuffling.

The base implementation does nothing; ``add_partition`` calls this hook
before sharding.
"""
def get_worker_plugin() -> ShuffleWorkerPlugin:
    """Return the ``"shuffle"`` plugin of the worker running the current task.

    Raises ``RuntimeError`` when ``get_worker()`` fails with ``ValueError`` or
    when the worker has no plugin registered under ``"shuffle"``.
    """
    from distributed import get_worker

    try:
        current_worker = get_worker()
    except ValueError as e:
        raise RuntimeError(
            "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
            "please confirm that you've created a distributed Client and are submitting this computation through it."
        ) from e
    try:
        plugin = current_worker.plugins["shuffle"]
    except KeyError as e:
        raise RuntimeError(
            f"The worker {current_worker.address} does not have a P2P shuffle plugin."
        ) from e
    return plugin  # type: ignore
# Prefix shared by all barrier task keys
_BARRIER_PREFIX = "shuffle-barrier-"


def barrier_key(shuffle_id: ShuffleId) -> str:
    """Return the barrier task key for ``shuffle_id``."""
    return "".join((_BARRIER_PREFIX, shuffle_id))


def id_from_key(key: Key) -> ShuffleId | None:
    """Return the shuffle id encoded in a barrier key.

    Returns ``None`` when ``key`` is not a string starting with the barrier
    prefix.
    """
    if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
        return ShuffleId(key.removeprefix(_BARRIER_PREFIX))
    return None
class ShuffleType(Enum):
    """The kinds of P2P shuffle named in this module."""

    DATAFRAME = "DataFrameShuffle"
    ARRAY_RECHUNK = "ArrayRechunk"
@dataclass(frozen=True)
class ShuffleRunSpec(Generic[_T_partition_id]):
    """Frozen description of one run of a shuffle.

    ``run_id`` cannot be passed to the constructor; each new instance takes
    the next value of a shared counter that starts at 1.
    """

    run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
    spec: ShuffleSpec
    # Output partition id -> address of the worker assigned to it
    worker_for: dict[_T_partition_id, str]
    span_id: str | None

    @property
    def id(self) -> ShuffleId:
        """Id of the shuffle this run belongs to, taken from ``spec``."""
        shuffle_spec = self.spec
        return shuffle_spec.id
@dataclass(frozen=True)
class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
    """Abstract, frozen specification of a shuffle.

    Concrete subclasses supply the output partitions, worker selection and
    construction of the worker-side run.
    """

    id: ShuffleId
    # Presumably selects disk-backed buffering on the workers -- confirm
    disk: bool

    @property
    @abc.abstractmethod
    def output_partitions(self) -> Generator[_T_partition_id]:
        """Output partitions"""

    @abc.abstractmethod
    def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
        """Pick a worker for a partition"""

    def create_new_run(
        self,
        worker_for: dict[_T_partition_id, str],
        span_id: str | None,
    ) -> SchedulerShuffleState:
        """Create scheduler-side state for a new run of this shuffle.

        The participating workers start out as the set of values of
        ``worker_for``.
        """
        new_run_spec = ShuffleRunSpec(
            spec=self, worker_for=worker_for, span_id=span_id
        )
        assigned_workers = set(worker_for.values())
        return SchedulerShuffleState(
            run_spec=new_run_spec,
            participating_workers=assigned_workers,
        )

    @abc.abstractmethod
    def create_run_on_worker(
        self,
        run_id: int,
        span_id: str | None,
        worker_for: dict[_T_partition_id, str],
        plugin: ShuffleWorkerPlugin,
    ) -> ShuffleRun:
        """Create the new shuffle run on the worker."""
@dataclass(eq=False)
class SchedulerShuffleState(Generic[_T_partition_id]):
    """Scheduler-side state of one shuffle run.

    Instances compare by identity (``eq=False``) and hash by ``run_id``.
    """

    run_spec: ShuffleRunSpec
    participating_workers: set[str]
    # ``None`` until the run is archived; not settable through the constructor
    _archived_by: str | None = field(default=None, init=False)
    _failed: bool = False

    @property
    def id(self) -> ShuffleId:
        """Shuffle id, taken from ``run_spec``."""
        return self.run_spec.id

    @property
    def run_id(self) -> int:
        """Run id, taken from ``run_spec``."""
        return self.run_spec.run_id

    @property
    def archived(self) -> bool:
        """Whether ``_archived_by`` has been set."""
        return not (self._archived_by is None)

    def __str__(self) -> str:
        text = f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
        return text

    def __hash__(self) -> int:
        run_id = self.run_id
        return hash(run_id)
@contextlib.contextmanager
def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
"""Translate exceptions raised inside the ``with`` body.

``ShuffleClosedError`` becomes ``Reschedule``; ``P2PConsistencyError`` and
``P2POutOfDiskError`` propagate unchanged; any other ``Exception`` is wrapped
in a ``RuntimeError`` that names shuffle ``id`` and chains the original.
"""
try:
yield
except ShuffleClosedError:
raise Reschedule()
except P2PConsistencyError:
raise
except P2POutOfDiskError:
raise
except Exception as e:
> raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E RuntimeError: P2P 9f169a5772161f491fce47676e173277 failed during transfer phase
distributed/shuffle/_core.py:531: RuntimeError