Skip to content
GitHub Actions / Unit Test Results failed Nov 19, 2024 in 0s

45 fail, 110 skipped, 3 975 pass in 10h 10m 18s

    25 files      25 suites   10h 10m 18s ⏱️
 4 130 tests  3 975 ✅   110 💤  45 ❌
47 692 runs  45 131 ✅ 2 121 💤 440 ❌

Results for commit bb25546.

Annotations

Check warning on line 0 in distributed.comm.tests.test_comms

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

3 out of 12 runs failed: test_tls_comm_closed_implicit[tornado] (distributed.comm.tests.test_comms)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
ssl.SSLError: [SYS] unknown error (_ssl.c:2570)
tcp = <module 'distributed.comm.tcp' from '/home/runner/work/distributed/distributed/distributed/comm/tcp.py'>

    @gen_test()
    async def test_tls_comm_closed_implicit(tcp):
>       await check_comm_closed_implicit("tls://127.0.0.1", **tls_kwargs)

distributed/comm/tests/test_comms.py:777: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/comm/tests/test_comms.py:763: in check_comm_closed_implicit
    await comm.read()
distributed/comm/tcp.py:225: in read
    frames_nosplit_nbytes_bin = await stream.read_bytes(fmt_size)
../../../miniconda3/envs/dask-distributed/lib/python3.12/site-packages/tornado/iostream.py:422: in read_bytes
    self._try_inline_read()
../../../miniconda3/envs/dask-distributed/lib/python3.12/site-packages/tornado/iostream.py:836: in _try_inline_read
    pos = self._read_to_buffer_loop()
../../../miniconda3/envs/dask-distributed/lib/python3.12/site-packages/tornado/iostream.py:750: in _read_to_buffer_loop
    if self._read_to_buffer() == 0:
../../../miniconda3/envs/dask-distributed/lib/python3.12/site-packages/tornado/iostream.py:861: in _read_to_buffer
    bytes_read = self.read_from_fd(buf)
../../../miniconda3/envs/dask-distributed/lib/python3.12/site-packages/tornado/iostream.py:1552: in read_from_fd
    return self.socket.recv_into(buf, len(buf))
../../../miniconda3/envs/dask-distributed/lib/python3.12/ssl.py:1251: in recv_into
    return self.read(nbytes, buffer)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <ssl.SSLSocket [closed] fd=-1, family=2, type=1, proto=0>, len = 65536
buffer = bytearray(b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x...0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')

    def read(self, len=1024, buffer=None):
        """Read up to LEN bytes and return them.
        Return zero-length string on EOF."""
    
        self._checkClosed()
        if self._sslobj is None:
            raise ValueError("Read on closed or unwrapped SSL socket.")
        try:
            if buffer is not None:
>               return self._sslobj.read(len, buffer)
E               ssl.SSLError: [SYS] unknown error (_ssl.c:2570)

../../../miniconda3/envs/dask-distributed/lib/python3.12/ssl.py:1103: SSLError

Check warning on line 0 in distributed.diagnostics.tests.test_task_stream

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

1 out of 12 runs failed: test_get_task_stream_save (distributed.diagnostics.tests.test_task_stream)

artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 4s]
Raw output
assert 'inc' in '<!DOCTYPE html>\n<html lang="en">\n  <head>\n    <meta charset="utf-8">\n    <title>Dask Task Stream</title>\n    <style>\n      html, body {\n        box-sizing: border-box;\n        display: flow-root;\n        height: 100%;\n        margin: 0;\n        padding: 0;\n      }\n    </style>\n    <script type="text/javascript" src="https://cdn.bokeh.org/bokeh/release/bokeh-3.6.1.min.js"></script>\n    <script type="text/javascript">\n        Bokeh.set_log_level("info");\n    </script>\n  </head>\n  <body>\n    <div id="f0a86e3b-74df-4f2c-a395-e9669ab90e03" data-root-id="p23085" style="display: contents;"></div>\n  \n    <script type="application/json" id="ec151569-8f8d-4c98-ac24-bcdf0ce11ac5">\n      {"52e17bdd-6795-4092-8917-677256f3eddd":{"version":"3.6.1","title":"Bokeh Application","roots":[{"type":"object","name":"Figure","id":"p23085","attributes":{"name":"task_stream","sizing_mode":"stretch_both","x_range":{"type":"object","name":"DataRange1d","id":"p23083","attributes":{"range_padding":0}},"y_range":{"type":"object","name":"DataRange1d","id":"p23084","attributes":{"range_padding":0}},"x_scale":{"type":"object","name":"LinearScale","id":"p23095"},"y_scale":{"type":"object",...bcdf0ce11ac5\').textContent;\n              const render_items = [{"docid":"52e17bdd-6795-4092-8917-677256f3eddd","roots":{"p23085":"f0a86e3b-74df-4f2c-a395-e9669ab90e03"},"root_ids":["p23085"]}];\n              root.Bokeh.embed.embed_items(docs_json, render_items);\n              }\n              if (root.Bokeh !== undefined) {\n                embed_document(root);\n              } else {\n                let attempts = 0;\n                const timer = setInterval(function(root) {\n                  if (root.Bokeh !== undefined) {\n                    clearInterval(timer);\n                    embed_document(root);\n                  } else {\n                    attempts++;\n                    if (attempts > 100) {\n                      clearInterval(timer);\n       
               console.log("Bokeh: ERROR: Unable to run BokehJS code because BokehJS library is missing");\n                    }\n                  }\n                }, 10, root)\n              }\n            })(window);\n          });\n        };\n        if (document.readyState != "loading") fn();\n        else document.addEventListener("DOMContentLoaded", fn);\n      })();\n    </script>\n  </body>\n</html>'
client = <Client: 'tcp://127.0.0.1:53059' processes=2 threads=2, memory=32.00 GiB>
tmp_path = WindowsPath('C:/Users/runneradmin/AppData/Local/Temp/pytest-of-runneradmin/pytest-0/test_get_task_stream_save0')

    def test_get_task_stream_save(client, tmp_path):
        bkm = pytest.importorskip("bokeh.models")
        tmpdir = str(tmp_path)
        fn = os.path.join(tmpdir, "foo.html")
    
        with get_task_stream(plot="save", filename=fn) as ts:
            wait(client.map(inc, range(10)))
        with open(fn) as f:
            data = f.read()
>       assert "inc" in data
E       assert 'inc' in '<!DOCTYPE html>\n<html lang="en">\n  <head>\n    <meta charset="utf-8">\n    <title>Dask Task Stream</title>\n    <style>\n      html, body {\n        box-sizing: border-box;\n        display: flow-root;\n        height: 100%;\n        margin: 0;\n        padding: 0;\n      }\n    </style>\n    <script type="text/javascript" src="https://cdn.bokeh.org/bokeh/release/bokeh-3.6.1.min.js"></script>\n    <script type="text/javascript">\n        Bokeh.set_log_level("info");\n    </script>\n  </head>\n  <body>\n    <div id="f0a86e3b-74df-4f2c-a395-e9669ab90e03" data-root-id="p23085" style="display: contents;"></div>\n  \n    <script type="application/json" id="ec151569-8f8d-4c98-ac24-bcdf0ce11ac5">\n      {"52e17bdd-6795-4092-8917-677256f3eddd":{"version":"3.6.1","title":"Bokeh Application","roots":[{"type":"object","name":"Figure","id":"p23085","attributes":{"name":"task_stream","sizing_mode":"stretch_both","x_range":{"type":"object","name":"DataRange1d","id":"p23083","attributes":{"range_padding":0}},"y_range":{"type":"object","name":"DataRange1d","id":"p23084","attributes":{"range_padding":0}},"x_scale":{"type":"object","name":"LinearScale","id":"p23095"},"y_scale":{"type":"object",...bcdf0ce11ac5\').textContent;\n              const render_items = [{"docid":"52e17bdd-6795-4092-8917-677256f3eddd","roots":{"p23085":"f0a86e3b-74df-4f2c-a395-e9669ab90e03"},"root_ids":["p23085"]}];\n              root.Bokeh.embed.embed_items(docs_json, render_items);\n              }\n              if (root.Bokeh !== undefined) {\n                embed_document(root);\n              } else {\n                let attempts = 0;\n                const timer = setInterval(function(root) {\n                  if (root.Bokeh !== undefined) {\n                    clearInterval(timer);\n                    embed_document(root);\n                  } else {\n                    attempts++;\n                    if (attempts > 100) {\n                      
clearInterval(timer);\n                      console.log("Bokeh: ERROR: Unable to run BokehJS code because BokehJS library is missing");\n                    }\n                  }\n                }, 10, root)\n              }\n            })(window);\n          });\n        };\n        if (document.readyState != "loading") fn();\n        else document.addEventListener("DOMContentLoaded", fn);\n      })();\n    </script>\n  </body>\n</html>'

distributed\diagnostics\tests\test_task_stream.py:159: AssertionError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

All 11 runs failed: test_rechunk_configuration[p2p-tasks] (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 103353343764f68a051e1ccb8959794b failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Attach the plugin to ``scheduler`` and create empty bookkeeping.

            Registers the four ``shuffle_*`` handlers on ``scheduler.handlers``
            and adds this instance as the scheduler plugin named ``"shuffle"``.
            """
            self.scheduler = scheduler
            # Operations this plugin serves through the scheduler's handler table.
            shuffle_handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler.handlers.update(shuffle_handlers)
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            # NOTE: the plugin is registered before the remaining private
            # attributes below are assigned.
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a fresh ``ShuffleWorkerPlugin`` on the workers.

            The ``scheduler`` argument is not used; registration goes through
            ``self.scheduler``.
            """
            serialized_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None,
                serialized_plugin,
                name="shuffle",
                idempotent=False,
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Broadcast ``shuffle_inputs_done`` to every participating worker.

            Parameters
            ----------
            id
                Id of the shuffle; looked up in ``self.active_shuffles``.
            run_id
                Run the caller refers to; must equal the active shuffle's
                ``run_id``.
            consistent
                When false, the shuffle is restarted instead of broadcasting.

            Raises
            ------
            KeyError
                If ``id`` is not an active shuffle.
            ValueError
                If ``run_id`` does not match the active shuffle.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler while the
                shuffle is not archived.
            RuntimeError
                On a non-``OSError`` broadcast result, when a failing worker
                left, or after three retries without progress.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Collect the workers to retry: a ``None`` result is skipped, an
                # ``OSError`` result is retried, anything else aborts the barrier.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    # A retry round that fails for as many workers as were
                    # contacted counts as "no progress"; the counter never resets.
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` for the active run of shuffle ``id``.

            Returns ``{"status": "OK"}`` on success.  When ``run_id`` differs
            from the active shuffle's run id, the resulting
            ``P2PConsistencyError`` is converted with ``error_message`` and
            returned rather than raised.
            """
            try:
                active = self.active_shuffles[id]
                # The scheduler already moved on to a newer run.
                if active.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {active}"
                    )
                # The caller refers to a run the scheduler has not created.
                if active.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {active}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
                return {"status": "OK"}
            except P2PConsistencyError as exc:
                return error_message(exc)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of the active shuffle ``id`` as a message.

            On success the result is ``{"status": "OK", "run_spec": ...}`` with
            the spec wrapped in ``ToPickle``.  A ``P2PConsistencyError`` is never
            propagated: it is converted with ``error_message`` and returned.
            """
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    # Report a KeyError as a consistency error, chaining the
                    # KeyError as its cause.
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            """Return the ``spec`` carried by the run spec of the shuffle's barrier task."""
            key = barrier_key(shuffle_id)
            barrier_task = self.scheduler.tasks[key]
            task_spec = barrier_task.run_spec
            assert isinstance(task_spec, P2PBarrierTask)
            return task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            """Start a new run for ``shuffle_id`` and return its run spec.

            The new state becomes the active shuffle for ``shuffle_id``, is added
            to ``self._shuffles[shuffle_id]``, and ``worker`` is recorded as its
            first participant.
            """
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            shuffle_spec = self._retrieve_spec(shuffle_id)
            assignment = self._calculate_worker_for(shuffle_spec)
            self._ensure_output_tasks_are_non_rootish(shuffle_spec)
            # The span id comes from the group of the task that triggered creation.
            span_id = self.scheduler.tasks[key].group.span_id
            new_state = shuffle_spec.create_new_run(
                worker_for=assignment,
                span_id=span_id,
            )
            self.active_shuffles[shuffle_id] = new_state
            self._shuffles[shuffle_id].add(new_state)
            new_state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return new_state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Attach the plugin to ``scheduler`` and create empty bookkeeping.

            Registers the four ``shuffle_*`` handlers on ``scheduler.handlers``
            and adds this instance as the scheduler plugin named ``"shuffle"``.
            """
            self.scheduler = scheduler
            # Operations this plugin serves through the scheduler's handler table.
            shuffle_handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler.handlers.update(shuffle_handlers)
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            # NOTE: the plugin is registered before the remaining private
            # attributes below are assigned.
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a fresh ``ShuffleWorkerPlugin`` on the workers.

            The ``scheduler`` argument is not used; registration goes through
            ``self.scheduler``.
            """
            serialized_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None,
                serialized_plugin,
                name="shuffle",
                idempotent=False,
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Broadcast ``shuffle_inputs_done`` to every participating worker.

            Parameters
            ----------
            id
                Id of the shuffle; looked up in ``self.active_shuffles``.
            run_id
                Run the caller refers to; must equal the active shuffle's
                ``run_id``.
            consistent
                When false, the shuffle is restarted instead of broadcasting.

            Raises
            ------
            KeyError
                If ``id`` is not an active shuffle.
            ValueError
                If ``run_id`` does not match the active shuffle.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler while the
                shuffle is not archived.
            RuntimeError
                On a non-``OSError`` broadcast result, when a failing worker
                left, or after three retries without progress.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Collect the workers to retry: a ``None`` result is skipped, an
                # ``OSError`` result is retried, anything else aborts the barrier.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    # A retry round that fails for as many workers as were
                    # contacted counts as "no progress"; the counter never resets.
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` for the active run of shuffle ``id``.

            Returns ``{"status": "OK"}`` on success.  When ``run_id`` differs
            from the active shuffle's run id, the resulting
            ``P2PConsistencyError`` is converted with ``error_message`` and
            returned rather than raised.
            """
            try:
                active = self.active_shuffles[id]
                # The scheduler already moved on to a newer run.
                if active.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {active}"
                    )
                # The caller refers to a run the scheduler has not created.
                if active.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {active}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
                return {"status": "OK"}
            except P2PConsistencyError as exc:
                return error_message(exc)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of the active shuffle ``id`` as a message.

            On success the result is ``{"status": "OK", "run_spec": ...}`` with
            the spec wrapped in ``ToPickle``.  A ``P2PConsistencyError`` is never
            propagated: it is converted with ``error_message`` and returned.
            """
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    # Report a KeyError as a consistency error, chaining the
                    # KeyError as its cause.
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: '103353343764f68a051e1ccb8959794b'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        """``OKMessage`` extended with the run spec of a shuffle."""

        # Either the bare spec or the spec wrapped in ``ToPickle``.
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Store the run's identity and collaborators and create its buffers.

            Parameters
            ----------
            id, run_id, span_id
                Identifiers stored unchanged on the instance.
            local_address
                Stored unchanged.  NOTE(review): presumably the address of the
                worker hosting this run -- confirm against the caller.
            directory
                Passed to ``DiskShardsBuffer``; only used when ``disk`` is true.
            executor, rpc, digest_metric, scheduler, loop
                Collaborators stored unchanged on the instance.
            memory_limiter_disk
                Passed to ``DiskShardsBuffer``; only used when ``disk`` is true.
            memory_limiter_comms
                Passed to ``CommShardsBuffer``.
            disk
                Selects ``DiskShardsBuffer`` when true, otherwise
                ``MemoryShardsBuffer``.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    # Message size limit and concurrency come from the dask config.
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            # NOTE: these attributes are assigned only after both buffers above
            # have been created.
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            # Retry settings are read from the dask config once per run; the
            # delays are parsed with seconds as the default unit.
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Forward context_meter metrics to ``digest_metric`` while active.

            Each metric is recorded under the key

                ('p2p', <span id>, where, *label, unit)

            where a leading ``"p2p-"`` on the first label element is stripped.

            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This is by design so that one can have a holistic view of the whole
            shuffle process.

            **Note 2:** Metrics are written to Worker.digests immediately.
            Storing them temporarily on the ShuffleRun would lose whatever is
            recorded between the last heartbeat and the deletion of the
            ShuffleRun object at the end of a run.
            """
            prefix = "p2p-"

            def callback(label: Hashable, value: float, unit: str) -> None:
                parts = label if isinstance(label, tuple) else (label,)
                head = parts[0]
                if isinstance(head, str) and head.startswith(prefix):
                    parts = (head[len(prefix) :], *parts[1:])
                self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

            # Only background captures may be offloaded.
            offloadable = "background" in where
            with context_meter.add_callback(callback, allow_offload=offloadable):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Report to the scheduler that this run reached the barrier.

            ``consistent`` is true only when every entry of ``run_ids`` equals
            this run's own ``run_id``.  Returns ``self.run_id``.
            """
            self.raise_if_closed()
            own_run_id = self.run_id
            consistent = all(other == own_run_id for other in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id,
                run_id=self.run_id,
                consistent=consistent,
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Make one ``shuffle_receive`` RPC call to ``address`` with ``shards``.

            Raises via ``raise_if_closed`` if the run is already closed.
            """
            self.raise_if_closed()
            peer = self.rpc(address)
            return await peer.shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address``, retrying per the RETRY_* settings.

            When the mean shard size is below 64 KiB the shards are not sent as
            individual buffers over the comms; they are pickled into one opaque
            bytes blob that the receiving side unpickles.
            Performance tests informing the size threshold:
            https://github.com/dask/distributed/pull/8318
            """
            payload: list | bytes
            if _mean_shard_size(shards) < 65536:
                payload = pickle.dumps(shards)
            else:
                payload = shards

            def attempt() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                # Each retry gets a fresh coroutine for the same payload.
                return self._send(address, payload)

            return await retry(
                attempt,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor`` and await it.

            The call is metered under the ``"offload"`` label.  Raises via
            ``raise_if_closed`` if the run is already closed.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                result = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
            return result
    
        def heartbeat(self) -> dict[str, Any]:
            """Return the disk and comm buffer heartbeats plus the start time.

            The comm heartbeat gets an extra ``"read"`` entry holding
            ``self.total_recvd``.
            """
            comm_state = self._comm_buffer.heartbeat()
            comm_state["read"] = self.total_recvd
            disk_state = self._disk_buffer.heartbeat()
            return {"disk": disk_state, "comm": comm_state, "start": self.start_time}
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            """Raise if the run is closed; otherwise do nothing.

            A stored ``_exception`` takes precedence; without one a
            ``ShuffleClosedError`` is raised.
            """
            if not self.closed:
                return
            stored = self._exception
            if stored:
                raise stored
            raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark the transfer as finished and flush the comm buffer.

            Any exception the comm buffer re-raises afterwards is stored on
            ``self._exception`` before it propagates.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as exc:
                # Keep the failure so raise_if_closed can re-raise it later.
                self._exception = exc
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush the disk buffer unless the run is closed."""
            self.raise_if_closed()
            disk_buffer = self._disk_buffer
            await disk_buffer.flush()
    
        async def close(self) -> None:
            """Close the comm and disk buffers; safe to call more than once.

            The first call marks the run closed and closes both buffers.  Later
            calls only wait until that first call has finished.

            The disk buffer is closed even if closing the comm buffer raises,
            and ``_closed_event`` is always set, so callers waiting on an
            earlier ``close`` cannot hang when a buffer fails to close.  The
            original exception still propagates to the first caller.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            self.closed = True
            try:
                try:
                    await self._comm_buffer.close()
                finally:
                    await self._disk_buffer.close()
            finally:
                # Always release waiters, even on failure.
                self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            """Record ``exception`` on the run; ignored once the run is closed."""
            if self.closed:
                return
            self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Accept shards from a peer and pass them to ``_receive``.

            ``bytes`` input is the opaque pickled blob produced by ``send`` and
            is unpickled first.  Returns an OK message, or an error message
            when a ``P2PConsistencyError`` is raised.
            """
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    # NOTE(review): pickle.loads executes arbitrary code; this
                    # is only safe if peers are trusted -- confirm.
                    unpacked = pickle.loads(data)
                    data = cast(list[tuple[_T_partition_id, Any]], unpacked)
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as exc:
                return error_message(exc)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition

            The result is compared against ``local_address`` by
            ``_ensure_output_worker``.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run

            Called by ``receive`` after any pickled blob has been unpacked.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and… shuffle_run = await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Register the shuffle RPC handlers and this plugin on ``scheduler``.

            Also initialises the bookkeeping containers for active shuffles,
            heartbeats and archived shuffle states.
            """
            self.scheduler = scheduler
            # Expose the shuffle operations as scheduler RPC handlers.
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            # heartbeats[shuffle_id][worker_address] -> dict, filled by heartbeat().
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            # NOTE(review): the plugin is registered before the attributes
            # below are assigned -- confirm add_plugin does not use them.
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a ``ShuffleWorkerPlugin`` on the workers as "shuffle".

            The plugin is registered through ``self.scheduler``; the
            ``scheduler`` argument itself is not used.
            """
            serialized_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, serialized_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the ids of all active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Handle a worker's barrier request for shuffle ``id``.

            Raises ``ValueError`` when ``run_id`` does not match the active
            shuffle.  When the worker reports ``consistent=False`` the shuffle
            is restarted.  Otherwise ``shuffle_inputs_done`` is broadcast to
            all participating workers; workers that answered with an
            ``OSError`` are retried after a short sleep.

            Raises
            ------
            RuntimeError
                If a worker returns a non-``OSError`` error, if a failing
                worker is no longer known to the scheduler, or once three
                retry rounds have failed to reduce the number of failing
                workers.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler while the
                shuffle has not been archived.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Keep only the workers whose broadcast failed with an OSError;
                # a None response is a success, anything else is fatal.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    # The counter is cumulative: it is never reset when a
                    # later round does make progress.
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` for the active run of ``id``.

            A request whose ``run_id`` differs from the active shuffle's is
            answered with an error message instead of raising.
            """
            try:
                shuffle = self.active_shuffles[id]
                active_run_id = shuffle.run_id
                if active_run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if active_run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
                return {"status": "OK"}
            except P2PConsistencyError as exc:
                return error_message(exc)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge a worker's per-shuffle heartbeat payload into ``heartbeats``.

            Entries for shuffles that are not currently active are ignored.
            Membership is tested on ``active_shuffles`` directly, which avoids
            building a fresh set of ids for every entry in ``data``.
            """
            active = self.active_shuffles
            address = ws.address
            for shuffle_id, payload in data.items():
                if shuffle_id in active:
                    self.heartbeats[shuffle_id][address].update(payload)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of shuffle ``id`` wrapped in an OK message.

            A ``KeyError`` is converted into a ``P2PConsistencyError``, and any
            ``P2PConsistencyError`` is returned as an error message instead of
            being raised.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(spec)}
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as exc:
                return error_message(exc)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:43009', workers: 0, cores: 0, tasks: 0>
config_value = 'tasks', keyword = 'p2p'
ws = (<Worker 'tcp://127.0.0.1:35513', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:42035', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[0.98768471, 0.17582926, 0.92605057, 0.16072492, 0.16546501,
        0.28120577, 0.43820435, 0.02956598, 0.3916..., 0.98722643, 0.53573731, 0.70246876, 0.39640673,
        0.76411575, 0.65548619, 0.35407073, 0.19018565, 0.25111724]])

    @pytest.mark.parametrize("config_value", ["tasks", "p2p", None])
    @pytest.mark.parametrize("keyword", ["tasks", "p2p", None])
    @gen_cluster(client=True)
    async def test_rechunk_configuration(c, s, *ws, config_value, keyword):
        """Try rechunking a random 1d matrix
    
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_1d
        """
        a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
        x = da.from_array(a, chunks=(10, 1))
        new = ((1,) * 10, (10,))
        config = {"array.rechunk.method": config_value} if config_value is not None else {}
        with dask.config.set(config):
            x2 = rechunk(x, chunks=new, method=keyword)
        expected_algorithm = keyword if keyword is not None else config_value
        if expected_algorithm == "p2p":
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
        elif expected_algorithm == "tasks":
            assert not any(
                key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
            )
        # Neither is specified, so we choose the best one (see test_rechunk_heuristic for a full test of the heuristic)
        else:
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
    
        assert x2.chunks == new
>       assert np.all(await c.compute(x2) == a)

distributed/shuffle/tests/test_rechunk.py:210: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
    self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        """``OKMessage`` extended with the ``run_spec`` of a shuffle run."""

        # Either the spec itself or the spec wrapped in ``ToPickle``.
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Store the run's identity and collaborators and create its buffers.

            ``disk`` selects the receive-side buffer: a ``DiskShardsBuffer``
            rooted at ``directory`` and bounded by ``memory_limiter_disk`` when
            true, otherwise a ``MemoryShardsBuffer``.  Outgoing shards go through
            a ``CommShardsBuffer`` bounded by ``memory_limiter_comms``; its
            message size limit and concurrency are read from the
            ``distributed.p2p.comm.*`` dask config keys, as are the retry
            settings stored on the instance.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    # NOTE(review): ``self.read`` / ``self.deserialize`` are not
                    # defined in the visible part of this class; presumably they
                    # are supplied by subclasses -- confirm.
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            # ``transferred`` flips to True in ``inputs_done``.
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            # Error re-raised by ``raise_if_closed`` once the run is closed.
            self._exception = None
            # Set at the end of ``close``; later ``close`` calls wait on it.
            self._closed_event = asyncio.Event()
            self._loop = loop

            # Retry policy passed to ``retry`` in ``send``; delays are parsed
            # with seconds as the default unit.
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Forward context_meter metrics to ``digest_metric`` while active.

            Each metric is recorded under the key

                ('p2p', <span id>, where, *label, unit)

            where a leading ``"p2p-"`` on the first label element is stripped.

            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This is by design so that one can have a holistic view of the whole
            shuffle process.

            **Note 2:** Metrics are written to Worker.digests immediately.
            Storing them temporarily on the ShuffleRun would lose whatever is
            recorded between the last heartbeat and the deletion of the
            ShuffleRun object at the end of a run.
            """
            prefix = "p2p-"

            def callback(label: Hashable, value: float, unit: str) -> None:
                parts = label if isinstance(label, tuple) else (label,)
                head = parts[0]
                if isinstance(head, str) and head.startswith(prefix):
                    parts = (head[len(prefix) :], *parts[1:])
                self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

            # Only background captures may be offloaded.
            offloadable = "background" in where
            with context_meter.add_callback(callback, allow_offload=offloadable):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Report to the scheduler that this run reached the barrier.

            ``consistent`` is true only when every entry of ``run_ids`` equals
            this run's own ``run_id``.  Returns ``self.run_id``.
            """
            self.raise_if_closed()
            own_run_id = self.run_id
            consistent = all(other == own_run_id for other in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id,
                run_id=self.run_id,
                consistent=consistent,
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Make one ``shuffle_receive`` RPC call to ``address`` with ``shards``.

            Raises via ``raise_if_closed`` if the run is already closed.
            """
            self.raise_if_closed()
            peer = self.rpc(address)
            return await peer.shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address``, retrying per the RETRY_* settings.

            When the mean shard size is below 64 KiB the shards are not sent as
            individual buffers over the comms; they are pickled into one opaque
            bytes blob that the receiving side unpickles.
            Performance tests informing the size threshold:
            https://github.com/dask/distributed/pull/8318
            """
            payload: list | bytes
            if _mean_shard_size(shards) < 65536:
                payload = pickle.dumps(shards)
            else:
                payload = shards

            def attempt() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                # Each retry gets a fresh coroutine for the same payload.
                return self._send(address, payload)

            return await retry(
                attempt,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor`` and await it.

            The call is metered under the ``"offload"`` label.  Raises via
            ``raise_if_closed`` if the run is already closed.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                result = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
            return result
    
        def heartbeat(self) -> dict[str, Any]:
            """Return the disk and comm buffer heartbeats plus the start time.

            The comm heartbeat gets an extra ``"read"`` entry holding
            ``self.total_recvd``.
            """
            comm_state = self._comm_buffer.heartbeat()
            comm_state["read"] = self.total_recvd
            disk_state = self._disk_buffer.heartbeat()
            return {"disk": disk_state, "comm": comm_state, "start": self.start_time}
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Write ``data`` to the comm buffer unless the run is closed."""
            self.raise_if_closed()
            comm_buffer = self._comm_buffer
            await comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            """Write ``data`` to the disk buffer unless the run is closed.

            Each index tuple is turned into a string key by joining its
            elements with ``"_"``.
            """
            self.raise_if_closed()
            keyed = {"_".join(map(str, index)): shard for index, shard in data.items()}
            await self._disk_buffer.write(keyed)
    
        def raise_if_closed(self) -> None:
            """Raise if this run is closed; do nothing otherwise.

            A stored ``_exception`` takes precedence over the generic
            ``ShuffleClosedError``.
            """
            if not self.closed:
                return
            if self._exception:
                raise self._exception
            raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark the transfer as finished and flush the comm buffer.

            Any exception reported by the comm buffer after the flush is stored
            in ``self._exception`` and re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            comm_buffer = self._comm_buffer
            try:
                comm_buffer.raise_on_exception()
            except Exception as exc:
                # Remember the failure so that raise_if_closed can re-raise it
                self._exception = exc
                raise
    
        async def _flush_comm(self) -> None:
            """Flush the comm buffer unless the run is closed."""
            self.raise_if_closed()
            comm_buffer = self._comm_buffer
            await comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush the disk buffer unless the run is closed."""
            self.raise_if_closed()
            disk_buffer = self._disk_buffer
            await disk_buffer.flush()
    
        async def close(self) -> None:
            """Close both buffers once; later callers wait for the first close.

            The first caller marks the run closed, closes the comm buffer and
            then the disk buffer, and finally sets ``_closed_event``. Any other
            caller only waits on that event.
            """
            if not self.closed:
                self.closed = True
                await self._comm_buffer.close()
                await self._disk_buffer.close()
                self._closed_event.set()
                return

            await self._closed_event.wait()  # pragma: no cover
    
        def fail(self, exception: Exception) -> None:
            """Record ``exception`` on this run; ignored once the run is closed."""
            if self.closed:
                return
            self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            """Read the entry for index ``id`` from the disk buffer.

            The buffer key is the elements of ``id`` joined with ``"_"``.
            """
            self.raise_if_closed()
            buffer_key = "_".join(map(str, id))
            return self._disk_buffer.read(buffer_key)
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Accept shards sent by a peer and pass them on to ``_receive``.

            Returns ``{"status": "OK"}`` on success. A ``P2PConsistencyError``
            is converted into an error message instead of being raised.
            """
            try:
                # NOTE(review): pickle.loads runs on bytes handed to this
                # handler; this is only safe if the sending side is trusted.
                shards = (
                    # Unpack opaque blob. See send()
                    cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                    if isinstance(data, bytes)
                    else data
                )
                await self._receive(shards)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Make sure task ``key`` runs on the worker assigned to partition ``i``.

            Returns normally when that worker is this one. Otherwise the
            scheduler is asked to restrict the task to the assigned worker and
            ``Reschedule`` is raised; an error reply from the scheduler becomes a
            ``RuntimeError``.
            """
            assigned_worker = self._get_assigned_worker(i)
            if assigned_worker == self.local_address:
                return

            result = await self.scheduler.shuffle_restrict_task(
                id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
            )
            status = result["status"]
            if status == "error":
                raise RuntimeError(result["message"])
            assert status == "OK"
            raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition

            Abstract; ``_ensure_output_worker`` compares the returned address
            with ``self.local_address``.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run

            Abstract; ``receive`` calls this after unpacking a pickled blob, and
            turns a ``P2PConsistencyError`` raised here into an error message.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            """Shard one input partition and pass the shards to the comm buffer.

            Returns ``self.run_id``.

            Raises ``RuntimeError`` if ``self.transferred`` is already set, and
            whatever ``raise_if_closed`` raises if the run is closed.
            """
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                # Only the sharding itself is measured by the two meters below
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                # Run the coroutine ``_write_to_comm`` through ``sync`` on ``self._loop``
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers

            Abstract; ``add_partition`` passes the returned mapping on to
            ``_write_to_comm``.
            """
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Return the output partition ``partition_id`` for task ``key``.

            Raises ``RuntimeError`` if ``self.transferred`` is not yet set.
            ``_ensure_output_worker`` may raise (including ``Reschedule``) when
            this worker is not the one assigned to the partition.
            """
            self.raise_if_closed()
            # Checked before ``transferred``: the task may have to move to another worker
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            # Flush the disk buffer before reading the partition
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run

            Abstract; ``get_output_partition`` calls this once the disk buffer
            has been flushed and forwards its ``**kwargs`` unchanged.
            """
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk

            Abstract. NOTE(review): the ``int`` in the returned tuple is
            presumably the number of bytes read -- confirm against the
            implementations.
            """
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards

            Abstract; the expected type of ``buffer`` is defined by the
            implementations.
            """
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling

            Called by ``add_partition`` before the partition is sharded. This
            default implementation accepts everything.
            """
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        """Return the ``"shuffle"`` plugin of the worker running this code.

        Raises ``RuntimeError`` when ``get_worker`` reports that there is no
        worker (``ValueError``), or when the worker has no ``"shuffle"`` plugin.
        """
        from distributed import get_worker

        try:
            worker = get_worker()
        except ValueError as exc:
            not_on_worker = (
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            )
            raise RuntimeError(not_on_worker) from exc

        try:
            plugin = worker.plugins["shuffle"]
        except KeyError as exc:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from exc
        return plugin  # type: ignore
    
    
    # Prefix shared by the keys of all shuffle barrier tasks
    _BARRIER_PREFIX = "shuffle-barrier-"


    def barrier_key(shuffle_id: ShuffleId) -> str:
        """Return the barrier task key for ``shuffle_id`` (prefix + id)."""
        return "".join((_BARRIER_PREFIX, shuffle_id))
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        """Extract the shuffle id from a barrier task key.

        Returns ``None`` when ``key`` is not a string starting with the
        barrier prefix.
        """
        if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
            return ShuffleId(key.removeprefix(_BARRIER_PREFIX))
        return None
    
    
    class ShuffleType(Enum):
        """Enumeration of P2P operation kinds, each identified by a string value."""

        # NOTE(review): presumably dataframe shuffles vs. array rechunking -- confirm
        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        """Immutable description of one run of a shuffle.

        ``run_id`` is not a constructor argument: each new instance takes the
        next value from a single counter starting at 1, created once when the
        class is defined.
        """

        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        # The shuffle specification this run belongs to
        spec: ShuffleSpec
        # Maps a partition id to a worker address
        worker_for: dict[_T_partition_id, str]
        span_id: str | None

        @property
        def id(self) -> ShuffleId:
            """Shuffle id, taken from ``spec``."""
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        """Abstract, immutable specification of a shuffle.

        Subclasses provide the output partitions, the worker assignment policy
        and the factory for the worker-side run object.
        """

        id: ShuffleId
        # NOTE(review): presumably selects disk-backed vs. in-memory buffering -- confirm
        disk: bool

        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""

        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""

        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            """Create the scheduler-side state for a new run of this shuffle.

            The participating workers are initialised to the set of worker
            addresses appearing as values of ``worker_for``.
            """
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )

        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        """Scheduler-side state of one shuffle run.

        ``eq=False`` keeps identity-based equality; instances are hashed by
        their ``run_id``.
        """

        run_spec: ShuffleRunSpec
        # Addresses of the workers taking part in this run
        participating_workers: set[str]
        # ``None`` until the run is archived; see ``archived``
        _archived_by: str | None = field(default=None, init=False)
        # NOTE(review): not read or written in this class -- confirm where it is used
        _failed: bool = False

        @property
        def id(self) -> ShuffleId:
            """Shuffle id, taken from ``run_spec``."""
            return self.run_spec.id

        @property
        def run_id(self) -> int:
            """Run id, taken from ``run_spec``."""
            return self.run_spec.run_id

        @property
        def archived(self) -> bool:
            """Whether ``_archived_by`` has been set."""
            return self._archived_by is not None

        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"

        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P 103353343764f68a051e1ccb8959794b failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

All 11 runs failed: test_rechunk_configuration[p2p-p2p] (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 7139bca66b75640feaa6b48f75426176 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Attach the plugin to ``scheduler`` and set up its bookkeeping.

            Registers the four ``shuffle_*`` handlers on the scheduler and adds
            the plugin itself under the name ``"shuffle"``.
            """
            self.scheduler = scheduler
            # Operation name -> bound method serving it
            shuffle_handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler.handlers.update(shuffle_handlers)
            # shuffle id -> worker address -> heartbeat data
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` as ``"shuffle"``.

            The registration goes through ``self.scheduler``; the ``scheduler``
            argument is not used.
            """
            pickled_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, pickled_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the ids of all active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Tell every participating worker that all inputs of a run are done.

            If ``consistent`` is false the shuffle is restarted instead. Otherwise
            ``shuffle_inputs_done`` is broadcast to the participating workers and
            re-broadcast to those workers whose reply was an ``OSError``.

            Raises
            ------
            KeyError
                If there is no active shuffle with this ``id``.
            ValueError
                If ``run_id`` does not match the active run.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler and the shuffle
                has not been archived.
            RuntimeError
                If a worker replies with a non-``OSError`` error, if a failing
                worker has left, or if retries stop making progress.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Keep only workers that failed with an OSError for the next
                # round; any other non-None reply is fatal. This loop alone
                # determines the retry list, so no second pass over ``res`` is
                # needed.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    # Give up after three rounds in which no worker recovered
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge the heartbeat data sent by worker ``ws`` into ``self.heartbeats``.

            ``data`` maps shuffle ids to heartbeat dicts. Entries for shuffles
            that are not active are ignored.
            """
            # Test membership on the dict directly rather than building a new
            # set of ids for every entry of ``data``.
            active = self.active_shuffles
            for shuffle_id, d in data.items():
                if shuffle_id in active:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of the active shuffle ``id`` for ``worker``.

            On success the spec is wrapped in ``ToPickle``. A missing shuffle
            (``KeyError``) or any other ``P2PConsistencyError`` is returned as
            an error message rather than raised.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(spec)}
                except KeyError as missing:
                    # Report unknown shuffles as a consistency error
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Return the run spec of shuffle ``id`` and record ``worker`` as participant.

            Raises ``KeyError`` when the shuffle is not active and
            ``P2PConsistencyError`` when the scheduler does not know ``worker``.
            """
            if worker in self.scheduler.workers:
                state = self.active_shuffles[id]
                state.participating_workers.add(worker)
                return state.run_spec
            # This should never happen
            raise P2PConsistencyError(
                f"Scheduler is unaware of this worker {worker!r}"
            )  # pragma: nocover
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            """Return the shuffle spec stored on the barrier task of ``shuffle_id``.

            The barrier task is looked up by its key in ``self.scheduler.tasks``;
            its ``run_spec`` must be a ``P2PBarrierTask``.
            """
            barrier_ts = self.scheduler.tasks[barrier_key(shuffle_id)]
            barrier_task_spec = barrier_ts.run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            """Create a new run for ``shuffle_id`` and make it the active one.

            ``key`` is the task that triggered the creation and ``worker`` the
            worker it runs on; the worker is recorded as a participant. Returns
            the run spec of the new state.
            """
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            # The span id is taken from the task group of the triggering task
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            # Replaces any previous entry for this shuffle id
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Attach the plugin to ``scheduler`` and set up its bookkeeping.

            Registers the four ``shuffle_*`` handlers on the scheduler and adds
            the plugin itself under the name ``"shuffle"``.
            """
            self.scheduler = scheduler
            # Operation name -> bound method serving it
            shuffle_handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler.handlers.update(shuffle_handlers)
            # shuffle id -> worker address -> heartbeat data
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` as ``"shuffle"``.

            The registration goes through ``self.scheduler``; the ``scheduler``
            argument is not used.
            """
            pickled_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, pickled_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the ids of all active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Tell every participating worker that all inputs of a run are done.

            If ``consistent`` is false the shuffle is restarted instead. Otherwise
            ``shuffle_inputs_done`` is broadcast to the participating workers and
            re-broadcast to those workers whose reply was an ``OSError``.

            Raises
            ------
            KeyError
                If there is no active shuffle with this ``id``.
            ValueError
                If ``run_id`` does not match the active run.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler and the shuffle
                has not been archived.
            RuntimeError
                If a worker replies with a non-``OSError`` error, if a failing
                worker has left, or if retries stop making progress.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Keep only workers that failed with an OSError for the next
                # round; any other non-None reply is fatal. This loop alone
                # determines the retry list, so no second pass over ``res`` is
                # needed.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    # Give up after three rounds in which no worker recovered
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge the heartbeat data sent by worker ``ws`` into ``self.heartbeats``.

            ``data`` maps shuffle ids to heartbeat dicts. Entries for shuffles
            that are not active are ignored.
            """
            # Test membership on the dict directly rather than building a new
            # set of ids for every entry of ``data``.
            active = self.active_shuffles
            for shuffle_id, d in data.items():
                if shuffle_id in active:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of the active shuffle ``id`` for ``worker``.

            On success the spec is wrapped in ``ToPickle``. A missing shuffle
            (``KeyError``) or any other ``P2PConsistencyError`` is returned as
            an error message rather than raised.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(spec)}
                except KeyError as missing:
                    # Report unknown shuffles as a consistency error
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: '7139bca66b75640feaa6b48f75426176'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    # Distinct static type for shuffle identifiers, based on ``str``
    ShuffleId = NewType("ShuffleId", str)
    # Index made of any number of integers
    NDIndex: TypeAlias = tuple[int, ...]


    # Type variables for partition ids, partition payloads, and generic results
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        """``OKMessage`` that additionally carries a shuffle run spec."""

        # Either the spec itself or the spec wrapped in ``ToPickle``
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Store the run's identity and collaborators and create its buffers.

            ``directory`` and ``memory_limiter_disk`` are only used when ``disk`` is
            true, in which case a ``DiskShardsBuffer`` is created; otherwise a
            ``MemoryShardsBuffer`` is used. ``memory_limiter_comms`` is handed to the
            ``CommShardsBuffer``. The message size limit, comm concurrency and the
            retry count/delays are read from the ``distributed.p2p.comm.*`` dask
            config keys.
            """
            # These attributes must be set first: _capture_metrics() below reads
            # span_id and digest_metric.
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        # NOTE(review): self.read / self.deserialize are not defined in
                        # the visible part of this class; presumably provided by
                        # subclasses - confirm.
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            # Progress / lifecycle bookkeeping
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            # Retry settings used by send(); delays are parsed with seconds as default unit
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            """Use the ``run_id`` as the hash value of this run."""
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Forward context_meter metrics to ``digest_metric`` while the context is active.

            Metrics are recorded under

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            A leading ``"p2p-"`` on the first label element is stripped.

            **Note 1:** For foreground metrics (where='foreground') the same
            measurement is also recorded under

                {('execute', <span id>, <task prefix>, label, unit): value}

            The duplication is intentional; it gives a holistic view of the whole
            shuffle process.

            **Note 2:** Values go straight to Worker.digests. Buffering them on the
            ShuffleRun would lose whatever is recorded between the last heartbeat
            and the deletion of the ShuffleRun object at the end of a run.
            """
            prefix = "p2p-"

            def record(label: Hashable, value: float, unit: str) -> None:
                parts = label if isinstance(label, tuple) else (label,)
                head = parts[0]
                if isinstance(head, str) and head.startswith(prefix):
                    parts = (head[len(prefix) :], *parts[1:])
                self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

            # Only background metrics may be recorded from offloaded code
            with context_meter.add_callback(record, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Make a single ``shuffle_receive`` RPC call to ``address`` with ``shards``.

            Raises if the run has been closed.
            """
            self.raise_if_closed()
            peer = self.rpc(address)
            response = await peer.shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
            return response
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address``, retrying according to the RETRY_* settings.

            When the mean shard size is below 65536, the shards are not sent as
            individual buffers over the tcp comms; they are pickled into one opaque
            bytes blob, sent at once and unpickled on the other side.
            Performance tests informing the size threshold:
            https://github.com/dask/distributed/pull/8318
            """
            payload: list | bytes = (
                pickle.dumps(shards) if _mean_shard_size(shards) < 65536 else shards
            )
            return await retry(
                lambda: self._send(address, payload),
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

            The call is metered under the ``"offload"`` label. Raises if the run has
            been closed.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                outcome = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
            return outcome
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Write ``data`` to the comm buffer; raises if the run has been closed."""
            # presumably the dict keys are destination worker addresses - TODO confirm
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark the run as transferred and flush the comm buffer.

            An exception reported by the comm buffer after the flush is stored in
            ``_exception`` and re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                # Keep the failure so raise_if_closed() can re-raise it once closed
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            """Flush the comm buffer; raises if the run has been closed."""
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush the disk buffer; raises if the run has been closed."""
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition

            The result is compared against ``local_address`` by
            ``_ensure_output_worker``.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run

            Called by ``receive`` after an incoming bytes blob has been unpickled.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…   shuffle_run = await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Register the shuffle handlers and this plugin on ``scheduler``."""
            self.scheduler = scheduler
            # Make the plugin's methods available as scheduler handlers
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            # Indexed as heartbeats[shuffle_id][worker address] by heartbeat()
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` under the name "shuffle".

            NOTE(review): the ``scheduler`` argument is not used; the registration
            goes through ``self.scheduler``.
            """
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of shuffle ``id`` wrapped in an OK message.

            A ``KeyError`` raised during the lookup is converted into a
            ``P2PConsistencyError``; every ``P2PConsistencyError`` is returned as an
            error message rather than raised.
            """
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    # Re-raised so the outer handler turns it into an error message
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:45201', workers: 0, cores: 0, tasks: 0>
config_value = 'p2p', keyword = 'p2p'
ws = (<Worker 'tcp://127.0.0.1:34133', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:39907', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[0.67263164, 0.99099891, 0.23391163, 0.0527319 , 0.37175087,
        0.10399512, 0.30617677, 0.6728578 , 0.3489..., 0.36582142, 0.29896918, 0.57915701, 0.61579285,
        0.63909114, 0.30842371, 0.44911196, 0.22216101, 0.92145394]])

    @pytest.mark.parametrize("config_value", ["tasks", "p2p", None])
    @pytest.mark.parametrize("keyword", ["tasks", "p2p", None])
    @gen_cluster(client=True)
    async def test_rechunk_configuration(c, s, *ws, config_value, keyword):
        """Try rechunking a random 1d matrix
    
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_1d
        """
        a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
        x = da.from_array(a, chunks=(10, 1))
        new = ((1,) * 10, (10,))
        config = {"array.rechunk.method": config_value} if config_value is not None else {}
        with dask.config.set(config):
            x2 = rechunk(x, chunks=new, method=keyword)
        expected_algorithm = keyword if keyword is not None else config_value
        if expected_algorithm == "p2p":
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
        elif expected_algorithm == "tasks":
            assert not any(
                key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
            )
        # Neither is specified, so we choose the best one (see test_rechunk_heuristic for a full test of the heuristic)
        else:
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
    
        assert x2.chunks == new
>       assert np.all(await c.compute(x2) == a)

distributed/shuffle/tests/test_rechunk.py:210: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
    self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        """``OKMessage`` extended with a ``run_spec`` entry describing a shuffle run."""
        # Either the spec itself or the spec wrapped in ``ToPickle``
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Store the run's identity and collaborators and create its buffers.

            ``directory`` and ``memory_limiter_disk`` are only used when ``disk`` is
            true, in which case a ``DiskShardsBuffer`` is created; otherwise a
            ``MemoryShardsBuffer`` is used. ``memory_limiter_comms`` is handed to the
            ``CommShardsBuffer``. The message size limit, comm concurrency and the
            retry count/delays are read from the ``distributed.p2p.comm.*`` dask
            config keys.
            """
            # These attributes must be set first: _capture_metrics() below reads
            # span_id and digest_metric.
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        # NOTE(review): self.read / self.deserialize are not defined in
                        # the visible part of this class; presumably provided by
                        # subclasses - confirm.
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            # Progress / lifecycle bookkeeping
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            # Retry settings used by send(); delays are parsed with seconds as default unit
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            """Use the ``run_id`` as the hash value of this run."""
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Forward context_meter metrics to ``digest_metric`` while the context is active.

            Metrics are recorded under

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            A leading ``"p2p-"`` on the first label element is stripped.

            **Note 1:** For foreground metrics (where='foreground') the same
            measurement is also recorded under

                {('execute', <span id>, <task prefix>, label, unit): value}

            The duplication is intentional; it gives a holistic view of the whole
            shuffle process.

            **Note 2:** Values go straight to Worker.digests. Buffering them on the
            ShuffleRun would lose whatever is recorded between the last heartbeat
            and the deletion of the ShuffleRun object at the end of a run.
            """
            prefix = "p2p-"

            def record(label: Hashable, value: float, unit: str) -> None:
                parts = label if isinstance(label, tuple) else (label,)
                head = parts[0]
                if isinstance(head, str) and head.startswith(prefix):
                    parts = (head[len(prefix) :], *parts[1:])
                self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

            # Only background metrics may be recorded from offloaded code
            with context_meter.add_callback(record, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Make a single ``shuffle_receive`` RPC call to ``address`` with ``shards``.

            Raises if the run has been closed.
            """
            self.raise_if_closed()
            peer = self.rpc(address)
            response = await peer.shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
            return response
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address``, retrying according to the RETRY_* settings.

            When the mean shard size is below 65536, the shards are not sent as
            individual buffers over the tcp comms; they are pickled into one opaque
            bytes blob, sent at once and unpickled on the other side.
            Performance tests informing the size threshold:
            https://github.com/dask/distributed/pull/8318
            """
            payload: list | bytes = (
                pickle.dumps(shards) if _mean_shard_size(shards) < 65536 else shards
            )
            return await retry(
                lambda: self._send(address, payload),
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on this run's executor and return its result.

            The time spent is recorded under the ``"offload"`` context meter.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                result = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
            return result
    
        def heartbeat(self) -> dict[str, Any]:
            """Collect the heartbeat payload of this shuffle run.

            Returns a dict with the disk buffer heartbeat (``"disk"``), the comm
            buffer heartbeat extended by a ``"read"`` entry holding
            ``total_recvd`` (``"comm"``), and the start time (``"start"``).
            """
            comm_stats = self._comm_buffer.heartbeat()
            comm_stats["read"] = self.total_recvd
            payload: dict[str, Any] = {}
            payload["disk"] = self._disk_buffer.heartbeat()
            payload["comm"] = comm_stats
            payload["start"] = self.start_time
            return payload
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            """Raise if this run has been closed; otherwise do nothing.

            Raises
            ------
            Exception
                The stored ``_exception`` when the run is closed and one is set.
            ShuffleClosedError
                When the run is closed and no exception is stored.
            """
            if not self.closed:
                return
            stored = self._exception
            if stored:
                raise stored
            raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark the transfer as complete and flush outstanding comm writes.

            After flushing, any exception recorded by the comm buffer is stored
            on ``_exception`` and re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as err:
                # Remember the failure so that later calls can report it as well.
                self._exception = err
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush the disk buffer; fails if the run is already closed."""
            self.raise_if_closed()
            disk_buffer = self._disk_buffer
            await disk_buffer.flush()
    
        async def close(self) -> None:
            """Close the comm and disk buffers of this run.

            The first caller performs the shutdown; any later caller waits
            until that shutdown has finished and then returns.

            Both buffers are closed even if closing the comm buffer raises, and
            ``_closed_event`` is always set afterwards. Without this, a failure
            while closing would leave every concurrent ``close()`` caller
            waiting on the event forever. Exceptions from the buffers still
            propagate to the first caller.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            self.closed = True
            try:
                try:
                    await self._comm_buffer.close()
                finally:
                    # Release disk resources regardless of comm shutdown errors.
                    await self._disk_buffer.close()
            finally:
                # Wake up waiters whether or not the shutdown succeeded.
                self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            """Record ``exception`` on this run unless the run is already closed."""
            if self.closed:
                return
            self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Accept shards sent by a peer and pass them on to ``_receive``.

            ``data`` is either the list of shards itself or a pickled blob of
            that list (see ``send()``), which is unpickled first. Returns an OK
            message, or an error message if a ``P2PConsistencyError`` is raised
            while processing.
            """
            try:
                shards = data
                if isinstance(shards, bytes):
                    # Unpack opaque blob. See send()
                    shards = cast(
                        list[tuple[_T_partition_id, Any]], pickle.loads(shards)
                    )
                await self._receive(shards)
            except P2PConsistencyError as exc:
                return error_message(exc)
            return {"status": "OK"}
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition

            Abstract; implemented by subclasses. The returned address is
            compared against ``local_address`` in ``_ensure_output_worker``.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run

            Abstract; implemented by subclasses. Called by ``receive`` after an
            opaque bytes payload, if any, has been unpickled.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            """Shard an input partition and queue the shards for sending.

            Returns the run ID of this shuffle run.

            Raises
            ------
            RuntimeError
                If the transfer has already been marked as complete.
            """
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            self.validate_data(data)
            # Log metrics both in the "execute" and in the "p2p" contexts
            with self._capture_metrics("foreground"):
                noncpu_meter = context_meter.meter("p2p-shard-partition-noncpu")
                cpu_meter = context_meter.meter(
                    "p2p-shard-partition-cpu", func=thread_time
                )
                with noncpu_meter, cpu_meter:
                    sharded = self._shard_partition(data, partition_id)
                # The comm buffer lives on the event loop; block until written.
                sync(self._loop, self._write_to_comm, sharded)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers

            Abstract; implemented by subclasses. ``add_partition`` passes the
            returned mapping on to ``_write_to_comm``.
            """
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Return the output partition ``partition_id`` of this shuffle run.

            The call first makes sure the task runs on the assigned worker,
            then flushes received data before delegating to
            ``_get_output_partition``.

            Raises
            ------
            RuntimeError
                If the transfer has not been marked as complete yet.
            """
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            # Log metrics both in the "execute" and in the "p2p" contexts
            with self._capture_metrics("foreground"):
                with context_meter.meter("p2p-get-output-noncpu"):
                    with context_meter.meter("p2p-get-output-cpu", func=thread_time):
                        return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run

            Abstract; implemented by subclasses. Invoked by
            ``get_output_partition`` once received data has been flushed.
            """
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk

            Abstract; implemented by subclasses. Per the annotation the result
            is a 2-tuple whose second element is an ``int`` -- presumably the
            number of bytes read; TODO confirm against the caller.
            """
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards

            Abstract; implemented by subclasses. The buffer format is defined
            by the concrete shuffle implementation.
            """
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling

            The base implementation accepts everything. ``add_partition`` calls
            this hook before sharding, so subclasses can override it to reject
            unsupported input.
            """
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        """Return the shuffle plugin registered on the current worker.

        Raises
        ------
        RuntimeError
            If the caller is not running on a worker, or if that worker has no
            plugin registered under the name ``"shuffle"``.
        """
        from distributed import get_worker

        try:
            worker = get_worker()
        except ValueError as err:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from err

        try:
            plugin = worker.plugins["shuffle"]
        except KeyError as err:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from err
        return plugin  # type: ignore
    
    
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        """Return the key of the barrier task that belongs to ``shuffle_id``."""
        return "".join((_BARRIER_PREFIX, shuffle_id))
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        """Extract the shuffle ID from a barrier task key.

        Returns ``None`` when ``key`` is not a string or does not start with
        the barrier prefix.
        """
        if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
            suffix = key[len(_BARRIER_PREFIX) :]
            return ShuffleId(suffix)
        return None
    
    
    class ShuffleType(Enum):
        """Kinds of P2P operations, identified by a string value."""

        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        """Immutable description of a single run of a shuffle."""

        # Not an ``__init__`` argument: every new instance draws the next value
        # from one shared counter (starting at 1) created when the class body runs.
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        # Maps each output partition ID to a worker address.
        worker_for: dict[_T_partition_id, str]
        span_id: str | None

        @property
        def id(self) -> ShuffleId:
            """ID of the shuffle, taken from ``spec``."""
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        """Abstract, immutable specification of a P2P shuffle.

        Concrete subclasses define the output partitions, how workers are
        picked for them, and how a run is instantiated on a worker.
        """

        id: ShuffleId
        # NOTE(review): presumably selects disk-backed vs. in-memory buffering -- confirm
        disk: bool

        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""

        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""

        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            """Create the scheduler-side state for a new run of this shuffle.

            A fresh ``ShuffleRunSpec`` wraps this spec together with
            ``worker_for`` and ``span_id``. The participating workers start out
            as the distinct values of ``worker_for``.
            """
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )

        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        """Mutable state tracked for one shuffle run.

        ``eq=False`` keeps the default identity-based equality; hashing is
        based on the run ID (see ``__hash__``).
        """

        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        # ``None`` until the run is archived; not settable through ``__init__``.
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False

        @property
        def id(self) -> ShuffleId:
            """ID of the shuffle, taken from ``run_spec``."""
            return self.run_spec.id

        @property
        def run_id(self) -> int:
            """ID of this particular run, taken from ``run_spec``."""
            return self.run_spec.run_id

        @property
        def archived(self) -> bool:
            """Whether ``_archived_by`` has been set."""
            return self._archived_by is not None

        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"

        def __hash__(self) -> int:
            # Hash by run ID so instances can be stored in sets.
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P 7139bca66b75640feaa6b48f75426176 failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

All 11 runs failed: test_rechunk_configuration[p2p-None] (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 7ab19aa6fa0aa500648d26d50ba3cc34 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Register the shuffle handlers and this plugin on ``scheduler``."""
            self.scheduler = scheduler
            # Expose the plugin's operations as scheduler handlers.
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            # Nested mapping: shuffle ID -> worker address -> heartbeat data
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            # NOTE(review): the plugin is registered before the remaining
            # attributes below are initialised.
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

            The registration goes through ``self.scheduler``; the ``scheduler``
            argument itself is not used.
            """
            plugin = ShuffleWorkerPlugin()
            register = self.scheduler.register_worker_plugin
            await register(None, dumps(plugin), name="shuffle", idempotent=False)
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the IDs of all active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Broadcast ``shuffle_inputs_done`` to all participating workers.

            If ``consistent`` is false the shuffle is restarted instead.
            Workers that answer with an ``OSError`` are retried until all have
            answered or no progress is made.

            Raises
            ------
            ValueError
                If ``run_id`` does not match the active run of the shuffle.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler while the
                shuffle has not been archived.
            RuntimeError
                If a worker answers with an unexpected error, a failing worker
                has left, or the broadcast does not make progress.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Collect the workers to retry. Only an OSError is retryable,
                # any other error aborts the barrier, so after this loop
                # ``workers`` holds exactly the workers with a failed response.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge heartbeat data sent by worker ``ws`` into ``heartbeats``.

            ``data`` maps shuffle IDs to heartbeat payloads. Entries for
            shuffles that are not active are ignored.
            """
            # Build the set of active IDs once instead of once per entry.
            active_ids = self.shuffle_ids()
            for shuffle_id, d in data.items():
                if shuffle_id in active_ids:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of the active shuffle ``id`` to ``worker``.

            On success the run spec is wrapped in ``ToPickle`` inside an OK
            message. An unknown shuffle ID or any other
            ``P2PConsistencyError`` results in an error message.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    response = {"status": "OK", "run_spec": ToPickle(spec)}
                    return response
                except KeyError as missing:
                    # Report an unknown shuffle as a consistency error.
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as exc:
                return error_message(exc)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            """Look up the shuffle spec stored on the barrier task of ``shuffle_id``."""
            barrier_task = self.scheduler.tasks[barrier_key(shuffle_id)]
            run_spec = barrier_task.run_spec
            assert isinstance(run_spec, P2PBarrierTask)
            return run_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            """Create a new run of shuffle ``shuffle_id`` and make it the active one.

            ``key`` is the task that triggered the creation and ``worker`` the
            worker executing it; the worker is added as a participant.
            Returns the run spec of the new run.
            """
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            # The span ID is taken from the task group of the triggering task.
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            # Replace any previous active run; every run is also kept in ``_shuffles``.
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Register the shuffle handlers and this plugin on ``scheduler``."""
            self.scheduler = scheduler
            # Expose the plugin's operations as scheduler handlers.
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            # Nested mapping: shuffle ID -> worker address -> heartbeat data
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            # NOTE(review): the plugin is registered before the remaining
            # attributes below are initialised.
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

            The registration goes through ``self.scheduler``; the ``scheduler``
            argument itself is not used.
            """
            plugin = ShuffleWorkerPlugin()
            register = self.scheduler.register_worker_plugin
            await register(None, dumps(plugin), name="shuffle", idempotent=False)
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the IDs of all active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Broadcast ``shuffle_inputs_done`` to all participating workers.

            If ``consistent`` is false the shuffle is restarted instead.
            Workers that answer with an ``OSError`` are retried until all have
            answered or no progress is made.

            Raises
            ------
            ValueError
                If ``run_id`` does not match the active run of the shuffle.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler while the
                shuffle has not been archived.
            RuntimeError
                If a worker answers with an unexpected error, a failing worker
                has left, or the broadcast does not make progress.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Collect the workers to retry. Only an OSError is retryable,
                # any other error aborts the barrier, so after this loop
                # ``workers`` holds exactly the workers with a failed response.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge heartbeat data sent by worker ``ws`` into ``heartbeats``.

            ``data`` maps shuffle IDs to heartbeat payloads. Entries for
            shuffles that are not active are ignored.
            """
            # Build the set of active IDs once instead of once per entry.
            active_ids = self.shuffle_ids()
            for shuffle_id, d in data.items():
                if shuffle_id in active_ids:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of the active shuffle ``id`` to ``worker``.

            On success the run spec is wrapped in ``ToPickle`` inside an OK
            message. An unknown shuffle ID or any other
            ``P2PConsistencyError`` results in an error message.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    response = {"status": "OK", "run_spec": ToPickle(spec)}
                    return response
                except KeyError as missing:
                    # Report an unknown shuffle as a consistency error.
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as exc:
                return error_message(exc)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: '7ab19aa6fa0aa500648d26d50ba3cc34'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    # Distinct static type for shuffle identifiers; a plain ``str`` at runtime.
    ShuffleId = NewType("ShuffleId", str)
    # Variable-length tuple of ints (presumably an N-dimensional index, per the name).
    NDIndex: TypeAlias = tuple[int, ...]


    # Type parameters of ``ShuffleRun``: the partition identifier and partition payload.
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    # Generic return type, used by ``ShuffleRun.offload``.
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        """``OKMessage`` extended with the run spec of a shuffle.

        The spec is carried either as-is or wrapped in ``ToPickle``.
        """

        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Store the run's identity and collaborators and create its buffers.

            ``disk`` selects a ``DiskShardsBuffer`` rooted at ``directory`` over a
            ``MemoryShardsBuffer``. The comm buffer limits and the retry settings
            are read from the ``distributed.p2p.comm.*`` configuration.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks.
            # Metrics issued by those background tasks must not be logged onto the
            # dask task that spawned this object, hence the cleared callbacks.
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    self._disk_buffer = (
                        DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                        if disk
                        else MemoryShardsBuffer(deserialize=self.deserialize)
                    )

                with self._capture_metrics("background-comms"):
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=parse_bytes(
                            dask.config.get("distributed.p2p.comm.message-bytes-limit")
                        ),
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=dask.config.get(
                            "distributed.p2p.comm.concurrency"
                        ),
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            # Progress bookkeeping
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            # Retry policy used by ``send``; delays are parsed with seconds as default unit
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            min_delay = dask.config.get("distributed.p2p.comm.retry.delay.min")
            self.RETRY_DELAY_MIN = parse_timedelta(min_delay, default="s")
            max_delay = dask.config.get("distributed.p2p.comm.retry.delay.max")
            self.RETRY_DELAY_MAX = parse_timedelta(max_delay, default="s")
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            """Hash a run by its ``run_id`` alone."""
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Forward context_meter metrics to ``digest_metric`` while active.

            Metrics are recorded as

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            A leading ``"p2p-"`` is stripped from the first label element when that
            element is a string. Offloading is allowed only when ``where`` contains
            ``"background"``.

            **Note 1:** For metrics not logged by a background task
            (where='foreground'), the same metric also shows up under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This duplication is intentional; it gives a holistic view of the whole
            shuffle process.

            **Note 2:** Metrics are written to Worker.digests immediately rather than
            being stored on the ShuffleRun, as those recorded between the heartbeat
            and the deletion of the ShuffleRun at the end of a run would be lost.
            """
            prefix = "p2p-"

            def callback(label: Hashable, value: float, unit: str) -> None:
                parts = label if isinstance(label, tuple) else (label,)
                head = parts[0]
                if isinstance(head, str) and head.startswith(prefix):
                    parts = (head[len(prefix) :], *parts[1:])
                self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

            offloadable = "background" in where
            with context_meter.add_callback(callback, allow_offload=offloadable):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Make one ``shuffle_receive`` call to the peer at ``address``.

            ``shards`` is wrapped with ``to_serialize`` and tagged with this run's
            shuffle id and run id. Raises if the run is already closed.
            """
            self.raise_if_closed()
            peer = self.rpc(address)
            response = await peer.shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
            return response
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address``, retrying per the configured policy.

            When the mean shard size is below 64 KiB the whole list is pickled into
            a single bytes blob; otherwise the list is sent as-is.
            """
            payload: list | bytes = shards
            if _mean_shard_size(shards) < 65536:
                # Small buffers are not worth sending individually over the tcp
                # comms: merge them into one opaque bytes blob, send that in one go
                # and unpickle it on the receiving side (see ``receive``).
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                payload = pickle.dumps(shards)

            # ``retry`` needs a zero-argument factory yielding a fresh coroutine
            # for every attempt.
            return await retry(
                lambda: self._send(address, payload),
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

            The call is metered under the ``"offload"`` label. Raises if the run is
            already closed.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                outcome = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
            return outcome
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Hand ``data`` to the comm buffer.

            Raises the stored exception or ``ShuffleClosedError`` if the run is
            already closed.
            """
            # NOTE(review): the ``str`` keys are presumably destination worker
            # addresses -- confirm against CommShardsBuffer.
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark the transfer phase as finished and flush the comm buffer.

            Sets ``transferred``, flushes pending comm writes and then surfaces any
            exception the comm buffer recorded. Such an exception is remembered in
            ``_exception`` before being re-raised. Raises if the run is already
            closed.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                # Keep the failure so later raise_if_closed() calls can re-raise it
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            """Flush the comm buffer; raises if the run is already closed."""
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush the disk buffer; raises if the run is already closed."""
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition

            The returned address is compared against ``local_address`` by
            ``_ensure_output_worker``.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run

            Called by ``receive`` with the already unpacked list of
            ``(partition id, shard)`` pairs.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…    shuffle_run = await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Register the shuffle RPC handlers and this plugin on ``scheduler``."""
            self.scheduler = scheduler
            # Expose the plugin's entry points as scheduler RPC operations
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            # Nested mapping: shuffle id -> worker address -> heartbeat data
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            # NOTE(review): the plugin is registered before the attributes below
            # are set -- confirm add_plugin does not call back into them.
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a serialized ``ShuffleWorkerPlugin`` under the name "shuffle".

            The ``scheduler`` argument is not used; registration goes through
            ``self.scheduler``.
            """
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the ids of all active shuffles."""
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of active shuffle ``id`` for ``worker``.

            On success the spec is wrapped in ``ToPickle`` inside an OK message.
            A ``KeyError`` raised in the inner block is turned into a
            ``P2PConsistencyError``; any ``P2PConsistencyError`` is returned as an
            error message instead of being raised.
            """
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    # Re-raise as a consistency error, chained to the KeyError,
                    # so the outer handler converts it into an error message
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Look up active shuffle ``id`` and record ``worker`` as a participant.

            Raises ``P2PConsistencyError`` if the scheduler does not know
            ``worker`` and ``KeyError`` if no shuffle with ``id`` is active.
            """
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            # Side effect: asking for the spec makes the worker a participant
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:45609', workers: 0, cores: 0, tasks: 0>
config_value = None, keyword = 'p2p'
ws = (<Worker 'tcp://127.0.0.1:37985', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:39669', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[0.67946287, 0.49174119, 0.57030695, 0.28143769, 0.93976304,
        0.30630708, 0.070039  , 0.95549053, 0.1707..., 0.27496663, 0.06050454, 0.44717105, 0.51656972,
        0.10041995, 0.26876562, 0.52153445, 0.59627235, 0.26186335]])

    @pytest.mark.parametrize("config_value", ["tasks", "p2p", None])
    @pytest.mark.parametrize("keyword", ["tasks", "p2p", None])
    @gen_cluster(client=True)
    async def test_rechunk_configuration(c, s, *ws, config_value, keyword):
        """Try rechunking a random 1d matrix
    
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_1d
        """
        a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
        x = da.from_array(a, chunks=(10, 1))
        new = ((1,) * 10, (10,))
        config = {"array.rechunk.method": config_value} if config_value is not None else {}
        with dask.config.set(config):
            x2 = rechunk(x, chunks=new, method=keyword)
        expected_algorithm = keyword if keyword is not None else config_value
        if expected_algorithm == "p2p":
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
        elif expected_algorithm == "tasks":
            assert not any(
                key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
            )
        # Neither is specified, so we choose the best one (see test_rechunk_heuristic for a full test of the heuristic)
        else:
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
    
        assert x2.chunks == new
>       assert np.all(await c.compute(x2) == a)

distributed/shuffle/tests/test_rechunk.py:210: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
    self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    # Distinct static type for shuffle identifiers; a plain ``str`` at runtime.
    ShuffleId = NewType("ShuffleId", str)
    # Variable-length tuple of ints (presumably an N-dimensional index, per the name).
    NDIndex: TypeAlias = tuple[int, ...]


    # Type parameters of ``ShuffleRun``: the partition identifier and partition payload.
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    # Generic return type, used by ``ShuffleRun.offload``.
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        """``OKMessage`` extended with the run spec of a shuffle.

        The spec is carried either as-is or wrapped in ``ToPickle``.
        """

        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Store the run's identity and collaborators and create its buffers.

            ``disk`` selects a ``DiskShardsBuffer`` rooted at ``directory`` over a
            ``MemoryShardsBuffer``. The comm buffer limits and the retry settings
            are read from the ``distributed.p2p.comm.*`` configuration.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks.
            # Metrics issued by those background tasks must not be logged onto the
            # dask task that spawned this object, hence the cleared callbacks.
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    self._disk_buffer = (
                        DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                        if disk
                        else MemoryShardsBuffer(deserialize=self.deserialize)
                    )

                with self._capture_metrics("background-comms"):
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=parse_bytes(
                            dask.config.get("distributed.p2p.comm.message-bytes-limit")
                        ),
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=dask.config.get(
                            "distributed.p2p.comm.concurrency"
                        ),
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            # Progress bookkeeping
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            # Retry policy used by ``send``; delays are parsed with seconds as default unit
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            min_delay = dask.config.get("distributed.p2p.comm.retry.delay.min")
            self.RETRY_DELAY_MIN = parse_timedelta(min_delay, default="s")
            max_delay = dask.config.get("distributed.p2p.comm.retry.delay.max")
            self.RETRY_DELAY_MAX = parse_timedelta(max_delay, default="s")
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            """Hash a run by its ``run_id`` alone."""
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Forward context_meter metrics to ``digest_metric`` while active.

            Metrics are recorded as

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            A leading ``"p2p-"`` is stripped from the first label element when that
            element is a string. Offloading is allowed only when ``where`` contains
            ``"background"``.

            **Note 1:** For metrics not logged by a background task
            (where='foreground'), the same metric also shows up under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This duplication is intentional; it gives a holistic view of the whole
            shuffle process.

            **Note 2:** Metrics are written to Worker.digests immediately rather than
            being stored on the ShuffleRun, as those recorded between the heartbeat
            and the deletion of the ShuffleRun at the end of a run would be lost.
            """
            prefix = "p2p-"

            def callback(label: Hashable, value: float, unit: str) -> None:
                parts = label if isinstance(label, tuple) else (label,)
                head = parts[0]
                if isinstance(head, str) and head.startswith(prefix):
                    parts = (head[len(prefix) :], *parts[1:])
                self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

            offloadable = "background" in where
            with context_meter.add_callback(callback, allow_offload=offloadable):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Make one ``shuffle_receive`` call to the peer at ``address``.

            ``shards`` is wrapped with ``to_serialize`` and tagged with this run's
            shuffle id and run id. Calls ``raise_if_closed`` first.
            """
            self.raise_if_closed()
            peer = self.rpc(address)
            response = await peer.shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
            return response
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to the worker at ``address``, retrying on failure.

            When the mean shard size is below 65536, the shards are pickled into
            one bytes blob before sending; otherwise they are sent as they are.
            Retries follow ``RETRY_COUNT``, ``RETRY_DELAY_MIN`` and
            ``RETRY_DELAY_MAX`` of this run.
            """
            # Small shards: don't send buffers individually over the tcp comms.
            # Merge everything into an opaque bytes blob, send it all at once,
            # and unpickle it on the other side.
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
            payload: list | bytes = (
                pickle.dumps(shards) if _mean_shard_size(shards) < 65536 else shards
            )

            def _attempt() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                # Every retry re-sends the same, already prepared payload.
                return self._send(address, payload)

            return await retry(
                _attempt,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on this run's executor and await it.

            The call is metered under the ``"offload"`` context_meter label.
            Raises via ``raise_if_closed`` if the run is already closed.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                outcome = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
            return outcome
    
        def heartbeat(self) -> dict[str, Any]:
            """Collect the current buffer statistics of this run.

            Returns a dict with the disk buffer heartbeat under ``"disk"``, the
            comm buffer heartbeat (extended with ``"read"`` = ``total_recvd``)
            under ``"comm"``, and ``start_time`` under ``"start"``.
            """
            comm_stats = self._comm_buffer.heartbeat()
            comm_stats["read"] = self.total_recvd
            disk_stats = self._disk_buffer.heartbeat()
            report = {
                "disk": disk_stats,
                "comm": comm_stats,
                "start": self.start_time,
            }
            return report
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Hand ``data`` to the comm buffer, unless the run is closed."""
            self.raise_if_closed()
            comm_buffer = self._comm_buffer
            await comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            """Write ``data`` to the disk buffer, unless the run is closed.

            Each index tuple is turned into a string key by joining its elements
            with ``"_"`` (e.g. ``(1, 2)`` becomes ``"1_2"``).
            """
            self.raise_if_closed()
            keyed = {"_".join(map(str, index)): shards for index, shards in data.items()}
            await self._disk_buffer.write(keyed)
    
        def raise_if_closed(self) -> None:
            """Raise if this run has been closed; otherwise do nothing.

            A stored ``_exception`` takes precedence; without one a
            ``ShuffleClosedError`` is raised.
            """
            if not self.closed:
                return
            if self._exception:
                raise self._exception
            raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark the transfer as finished and flush the comm buffer.

            Sets ``transferred`` so that ``add_partition`` rejects further input.
            If the comm buffer reports an exception after flushing, it is stored
            in ``_exception`` and re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as comm_error:
                # Remember the failure so later raise_if_closed() calls can surface it.
                self._exception = comm_error
                raise
    
        async def _flush_comm(self) -> None:
            """Flush the comm buffer, unless the run is closed."""
            self.raise_if_closed()
            comm_buffer = self._comm_buffer
            await comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush the disk buffer, unless the run is closed."""
            self.raise_if_closed()
            disk_buffer = self._disk_buffer
            await disk_buffer.flush()
    
        async def close(self) -> None:
            """Close the comm and disk buffers of this run.

            The first call flips ``closed``, closes the comm buffer and then the
            disk buffer, and finally sets ``_closed_event``. A call that finds
            ``closed`` already set only waits for that event.
            """
            already_closing = self.closed
            if already_closing:  # pragma: no cover
                await self._closed_event.wait()
                return

            # Flip the flag before the first await so later callers take the
            # waiting branch above.
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            """Record ``exception`` on this run; a closed run is left untouched."""
            if self.closed:
                return
            self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            """Read the data stored for index ``id`` from the disk buffer.

            The buffer key is the elements of ``id`` joined with ``"_"``, the same
            scheme ``_write_to_disk`` uses.
            """
            self.raise_if_closed()
            buffer_key = "_".join(map(str, id))
            return self._disk_buffer.read(buffer_key)
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Accept shards sent by a peer and pass them on to ``_receive``.

            ``data`` is either the shard list itself or the pickled blob produced
            by ``send()`` for small shards. Returns ``{"status": "OK"}`` on
            success; a ``P2PConsistencyError`` is converted into an error message
            instead of being raised.
            """
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    # NOTE(review): pickle.loads runs on bytes received from a
                    # peer; this presumes peers are trusted.
                    shards = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                else:
                    shards = data
                await self._receive(shards)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Make sure output partition ``i`` is being produced on its assigned worker.

            Returns normally when the assigned worker is this worker. Otherwise
            the scheduler is asked to restrict task ``key`` to the assigned worker
            and ``Reschedule`` is raised; an ``"error"`` reply from the scheduler
            becomes a ``RuntimeError`` carrying its message.
            """
            assigned_worker = self._get_assigned_worker(i)
            if assigned_worker == self.local_address:
                return

            reply = await self.scheduler.shuffle_restrict_task(
                id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
            )
            status = reply["status"]
            if status == "error":
                raise RuntimeError(reply["message"])
            assert status == "OK"
            raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition

            ``_ensure_output_worker`` compares the returned address with
            ``self.local_address`` to decide whether the task has to be
            rescheduled onto another worker.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run

            Called by ``receive`` with the shard list, after any pickled blob has
            been unpacked. A ``P2PConsistencyError`` raised here is turned into an
            error message by ``receive``.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            """Shard one input partition and write the shards to the comm buffer.

            Returns this run's ``run_id``.

            Raises
            ------
            RuntimeError
                If ``transferred`` is already set (``inputs_done`` sets it), i.e.
                no further partitions are accepted.
            """
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                # Only the sharding itself is timed by the two meters below.
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                # Runs inside the capture context but outside the two meters;
                # `sync` runs the coroutine on self._loop and waits for it.
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers

            Called by ``add_partition``; the returned mapping is passed unchanged
            to ``_write_to_comm``. The string keys are presumably worker addresses
            (TODO confirm against implementations).
            """
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Return the output partition ``partition_id`` for task ``key``.

            Extra keyword arguments are forwarded to ``_get_output_partition``.

            Raises
            ------
            Reschedule
                Raised by ``_ensure_output_worker`` when the partition is assigned
                to a different worker.
            RuntimeError
                If ``transferred`` has not been set yet.
            """
            self.raise_if_closed()
            # May raise Reschedule, before the barrier check below is reached.
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            # Flush the disk buffer before reading the output.
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run

            Called by ``get_output_partition`` after the disk buffer has been
            flushed; receives the same ``partition_id``, ``key`` and keyword
            arguments.
            """
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk

            NOTE(review): the ``int`` in the returned tuple is presumably the
            number of bytes read; confirm against the implementations.
            """
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards

            Must be provided by every concrete run type; the buffer format is
            defined by the implementation.
            """
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling.

            The base implementation performs no checks. ``add_partition`` calls
            this hook before sharding.
            """
            return None
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        """Return the plugin registered as ``"shuffle"`` on the current worker.

        Raises
        ------
        RuntimeError
            If ``get_worker()`` raises ``ValueError``, or if the worker has no
            plugin registered under the name ``"shuffle"``.
        """
        from distributed import get_worker

        try:
            worker = get_worker()
        except ValueError as exc:
            not_on_worker = (
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            )
            raise RuntimeError(not_on_worker) from exc

        try:
            plugin = worker.plugins["shuffle"]
        except KeyError as exc:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from exc
        return plugin  # type: ignore
    
    
    # Key prefix of barrier tasks: `barrier_key` prepends it to a shuffle ID and
    # `id_from_key` strips it again.
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        """Build the barrier task key for ``shuffle_id``.

        The key is the barrier prefix followed by the shuffle ID; ``id_from_key``
        is the inverse.
        """
        key = _BARRIER_PREFIX + shuffle_id
        return key
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        """Extract the shuffle ID from a barrier task key.

        Returns ``None`` when ``key`` is not a string or does not start with the
        barrier prefix; otherwise the part after the prefix as a ``ShuffleId``.
        """
        if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
            return ShuffleId(key.removeprefix(_BARRIER_PREFIX))
        return None
    
    
    class ShuffleType(Enum):
        """The two shuffle flavours named by its members: dataframe shuffle and array rechunk."""

        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        """Immutable description of a single run of a shuffle.

        Combines the shuffle's ``spec`` with a run ID, the mapping of output
        partitions to worker addresses (``worker_for``) and the ``span_id``.
        """

        # Not an ``__init__`` argument. Every new instance takes the next value of
        # one ``itertools.count(1)`` that is created when this class body executes,
        # so run IDs increase across all instances created in the process.
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        worker_for: dict[_T_partition_id, str]
        span_id: str | None
    
        @property
        def id(self) -> ShuffleId:
            """ID of the shuffle this run belongs to (taken from ``spec``)."""
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        """Abstract, immutable specification of a shuffle.

        Concrete subclasses define the output partitions, how partitions are
        assigned to workers and how a run is created on a worker.
        """

        id: ShuffleId
        # NOTE(review): presumably selects disk-backed vs. in-memory buffering of
        # received shards; confirm against the subclasses.
        disk: bool
    
        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""
    
        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""
    
        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            """Create the scheduler-side state for a new run of this shuffle.

            The new ``ShuffleRunSpec`` wraps this spec; the initial set of
            participating workers is every worker address that appears as a value
            in ``worker_for``.
            """
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )
    
        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        """Mutable scheduler-side state of one shuffle run.

        ``eq=False`` keeps the dataclass from generating ``__eq__``; hashing is
        defined explicitly below and is based on the run ID.
        """

        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        # None until the run is archived; see the `archived` property.
        # NOTE(review): presumably holds the stimulus ID that archived the run.
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False
    
        @property
        def id(self) -> ShuffleId:
            """ID of the shuffle, taken from ``run_spec``."""
            return self.run_spec.id
    
        @property
        def run_id(self) -> int:
            """ID of this run, taken from ``run_spec``."""
            return self.run_spec.run_id
    
        @property
        def archived(self) -> bool:
            """Whether ``_archived_by`` has been set."""
            return self._archived_by is not None
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
    
        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P 7ab19aa6fa0aa500648d26d50ba3cc34 failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

All 11 runs failed: test_rechunk_configuration[None-p2p] (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P f2e1cf610172a1187382e9e34d16198d failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Attach the plugin to ``scheduler``.

            Registers the four ``shuffle_*`` RPC handlers on the scheduler, adds
            this object as the scheduler plugin named ``"shuffle"`` and
            initialises the bookkeeping containers.
            """
            self.scheduler = scheduler
            rpc_handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler.handlers.update(rpc_handlers)
            # shuffle ID -> worker address -> heartbeat data
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a ``ShuffleWorkerPlugin`` on the workers under the name ``"shuffle"``.

            The registration goes through ``self.scheduler``; the ``scheduler``
            argument itself is not used.
            """
            serialized_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, serialized_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the IDs of all active shuffles."""
            active_ids = {*self.active_shuffles}
            return active_ids
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """RPC handler ("shuffle_barrier"): tell all participants that inputs are done.

            Broadcasts ``shuffle_inputs_done`` to every participating worker of
            the active shuffle ``id`` and retries the workers whose call failed
            with an ``OSError``.

            Parameters
            ----------
            id
                ID of the shuffle; must be present in ``active_shuffles``.
            run_id
                Run the caller refers to; must equal the active run's ID.
            consistent
                ``False`` restarts the shuffle instead of broadcasting.

            Raises
            ------
            KeyError
                If no shuffle with ``id`` is active.
            ValueError
                If ``run_id`` does not match the active run.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler while the shuffle
                is not archived.
            RuntimeError
                If a worker returned a non-``OSError`` error, a failing worker
                left the cluster, or the broadcast stopped making progress.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            # Rounds in which the number of failing workers did not shrink;
            # never reset, so the rounds need not be consecutive.
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Rebuild the retry list: a `None` result means success, an
                # OSError is retried, anything else aborts the barrier.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """RPC handler ("shuffle_restrict_task"): restrict task ``key`` to ``worker``.

            The restriction is applied only when ``run_id`` equals the run ID of
            the active shuffle ``id``. A mismatch is reported as an error message
            rather than raised. Returns ``{"status": "OK"}`` on success.
            """
            try:
                shuffle = self.active_shuffles[id]
                current_run_id = shuffle.run_id
                if current_run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if current_run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge the heartbeat data sent by worker ``ws`` into ``self.heartbeats``.

            ``data`` maps shuffle IDs to heartbeat dicts. Entries for shuffles
            that are not active are ignored; the others update
            ``self.heartbeats[shuffle_id][ws.address]``.
            """
            # Build the set of active shuffle IDs once rather than per entry.
            active_ids = self.shuffle_ids()
            for shuffle_id, d in data.items():
                if shuffle_id in active_ids:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """RPC handler ("shuffle_get"): return the run spec of an active shuffle.

            On success the reply is ``{"status": "OK", "run_spec": ...}`` with the
            run spec wrapped in ``ToPickle``. If no shuffle with ``id`` is active,
            or ``_get`` raises ``P2PConsistencyError``, an error message is
            returned instead of raising.
            """
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    # Re-raise the failed lookup as a consistency error so that
                    # the outer handler turns it into an error message.
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Return the run spec of active shuffle ``id`` and record ``worker`` as participant.

            Raises ``P2PConsistencyError`` if ``worker`` is unknown to the
            scheduler and ``KeyError`` if no shuffle with ``id`` is active.
            """
            known_workers = self.scheduler.workers
            if worker not in known_workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover

            shuffle_state = self.active_shuffles[id]
            shuffle_state.participating_workers.add(worker)
            return shuffle_state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            """Return the ``ShuffleSpec`` stored on the barrier task of ``shuffle_id``.

            The barrier task is looked up in ``scheduler.tasks`` by its barrier
            key; its ``run_spec`` must be a ``P2PBarrierTask``.
            """
            barrier_task = self.scheduler.tasks[barrier_key(shuffle_id)]
            barrier_task_spec = barrier_task.run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            """Create and register a new run for ``shuffle_id``; return its run spec.

            ``key`` is the task that triggered the creation and ``worker`` the
            worker it runs on; that worker is added to the participating workers.
            The new state replaces any entry for ``shuffle_id`` in
            ``active_shuffles`` and is also added to ``_shuffles[shuffle_id]``.
            """
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            # The span ID comes from the task group of the triggering task.
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Attach the plugin to ``scheduler``.

            Registers the four ``shuffle_*`` RPC handlers on the scheduler, adds
            this object as the scheduler plugin named ``"shuffle"`` and
            initialises the bookkeeping containers.
            """
            self.scheduler = scheduler
            rpc_handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler.handlers.update(rpc_handlers)
            # shuffle ID -> worker address -> heartbeat data
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a ``ShuffleWorkerPlugin`` on the workers under the name ``"shuffle"``.

            The registration goes through ``self.scheduler``; the ``scheduler``
            argument itself is not used.
            """
            serialized_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, serialized_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the IDs of all active shuffles."""
            active_ids = {*self.active_shuffles}
            return active_ids
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """RPC handler ("shuffle_barrier"): tell all participants that inputs are done.

            Broadcasts ``shuffle_inputs_done`` to every participating worker of
            the active shuffle ``id`` and retries the workers whose call failed
            with an ``OSError``.

            Parameters
            ----------
            id
                ID of the shuffle; must be present in ``active_shuffles``.
            run_id
                Run the caller refers to; must equal the active run's ID.
            consistent
                ``False`` restarts the shuffle instead of broadcasting.

            Raises
            ------
            KeyError
                If no shuffle with ``id`` is active.
            ValueError
                If ``run_id`` does not match the active run.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler while the shuffle
                is not archived.
            RuntimeError
                If a worker returned a non-``OSError`` error, a failing worker
                left the cluster, or the broadcast stopped making progress.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            # Rounds in which the number of failing workers did not shrink;
            # never reset, so the rounds need not be consecutive.
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Rebuild the retry list: a `None` result means success, an
                # OSError is retried, anything else aborts the barrier.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """RPC handler ("shuffle_restrict_task"): restrict task ``key`` to ``worker``.

            The restriction is applied only when ``run_id`` equals the run ID of
            the active shuffle ``id``. A mismatch is reported as an error message
            rather than raised. Returns ``{"status": "OK"}`` on success.
            """
            try:
                shuffle = self.active_shuffles[id]
                current_run_id = shuffle.run_id
                if current_run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if current_run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge the heartbeat data sent by worker ``ws`` into ``self.heartbeats``.

            ``data`` maps shuffle IDs to heartbeat dicts. Entries for shuffles
            that are not active are ignored; the others update
            ``self.heartbeats[shuffle_id][ws.address]``.
            """
            # Build the set of active shuffle IDs once rather than per entry.
            active_ids = self.shuffle_ids()
            for shuffle_id, d in data.items():
                if shuffle_id in active_ids:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """RPC handler ("shuffle_get"): return the run spec of an active shuffle.

            On success the reply is ``{"status": "OK", "run_spec": ...}`` with the
            run spec wrapped in ``ToPickle``. If no shuffle with ``id`` is active,
            or ``_get`` raises ``P2PConsistencyError``, an error message is
            returned instead of raising.
            """
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    # Re-raise the failed lookup as a consistency error so that
                    # the outer handler turns it into an error message.
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: 'f2e1cf610172a1187382e9e34d16198d'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        """``OKMessage`` that additionally carries a shuffle run spec."""

        # Either the run spec itself or the run spec wrapped in ``ToPickle``.
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Store the run's identity and collaborators and create its buffers.

            Parameters
            ----------
            id, run_id, span_id, local_address
                Stored unchanged on the instance.
            directory
                Passed to ``DiskShardsBuffer``; only used when *disk* is true.
            executor, rpc, digest_metric, scheduler, loop
                Stored unchanged on the instance.
            memory_limiter_disk
                Memory limiter handed to ``DiskShardsBuffer`` (only when *disk*
                is true).
            memory_limiter_comms
                Memory limiter handed to ``CommShardsBuffer``.
            disk
                If true the shard buffer is a ``DiskShardsBuffer``, otherwise a
                ``MemoryShardsBuffer``.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        # NOTE(review): ``self.read`` is not defined in the visible
                        # part of this class -- presumably supplied by subclasses.
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        # NOTE(review): ``self.deserialize`` is likewise not visible
                        # here -- presumably supplied by subclasses.
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    # Size and concurrency limits of the comm buffer come from the
                    # dask configuration.
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            # Bookkeeping state of the run.
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            # Retry policy used by ``send``; the delays are parsed with seconds as
            # the default unit.
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            """Return ``run_id`` as the hash of this run."""
            # NOTE(review): no matching ``__eq__`` is visible in this part of the
            # class, so equality presumably stays identity-based -- confirm.
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Forward context_meter metrics emitted inside the ``with`` body to
            ``self.digest_metric``.

            Every metric is recorded under the key

                ('p2p', <span id>, <where>, *label, unit)

            where a leading ``"p2p-"`` on the first label element is stripped.
            The callback is registered with ``allow_offload=True`` only when
            *where* contains ``"background"``.

            **Note 1:** Metrics not logged by a background task
            (where='foreground') are, by design, duplicated under

                {('execute', <span id>, <task prefix>, label, unit): value}

            so that the whole shuffle process can be looked at in one place.

            **Note 2:** Metrics are written to Worker.digests immediately rather
            than being stored on the ShuffleRun first; otherwise those recorded
            between the last heartbeat and the deletion of the ShuffleRun at the
            end of a run would be lost.
            """
            prefix = "p2p-"
            offloadable = "background" in where

            def callback(label: Hashable, value: float, unit: str) -> None:
                # Normalise the label to a tuple so it can be spliced into the key.
                parts = label if isinstance(label, tuple) else (label,)
                head = parts[0]
                if isinstance(head, str) and head.startswith(prefix):
                    parts = (head[len(prefix) :], *parts[1:])
                self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

            with context_meter.add_callback(callback, allow_offload=offloadable):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Send *shards* to the worker at *address* via ``shuffle_receive``.

            The payload is wrapped with ``to_serialize`` and tagged with this run's
            shuffle id and run id. Raises whatever ``raise_if_closed`` raises when
            the run is closed.
            """
            self.raise_if_closed()
            peer = self.rpc(address)
            payload = to_serialize(shards)
            return await peer.shuffle_receive(
                data=payload,
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send *shards* to *address*, retrying according to the run's retry policy.

            When the mean shard size is below 65536 the shards are pickled into a
            single opaque bytes blob (unpickled again by ``receive``); otherwise the
            list is sent as is.
            """
            # Small buffers are not sent individually over the tcp comms; they are
            # merged into one blob and sent at once.
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
            merge_into_blob = _mean_shard_size(shards) < 65536
            shards_or_bytes: list | bytes = (
                pickle.dumps(shards) if merge_into_blob else shards
            )

            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                # A fresh coroutine is created for every attempt made by ``retry``.
                return self._send(address, shards_or_bytes)

            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

            The wait is metered under the ``"offload"`` label. Raises whatever
            ``raise_if_closed`` raises when the run is closed.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                pending = run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
                return await pending
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Hand incoming shards to ``_receive`` and report the outcome.

            A ``bytes`` payload is the opaque blob produced by ``send`` and is
            unpickled first. Returns ``{"status": "OK"}`` on success and an error
            message if a ``P2PConsistencyError`` is raised.
            """
            try:
                shards = data
                if isinstance(shards, bytes):
                    # Unpack opaque blob. See send()
                    # NOTE(review): ``pickle.loads`` runs on data received from a
                    # peer; this presumes peers are trusted -- confirm.
                    shards = cast(
                        list[tuple[_T_partition_id, Any]], pickle.loads(shards)
                    )
                await self._receive(shards)
            except P2PConsistencyError as exc:
                return error_message(exc)
            else:
                return {"status": "OK"}
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition

            Abstract hook with no implementation here; ``_ensure_output_worker``
            compares the returned address with ``self.local_address``.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run

            Abstract hook with no implementation here; ``receive`` awaits it with
            the (already unpickled) list of ``(partition id, shard)`` tuples.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…    shuffle_run = await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Install the shuffle RPC handlers on *scheduler* and register the plugin."""
            self.scheduler = scheduler
            # Expose the plugin's methods as scheduler RPC operations.
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            # Nested mapping filled by ``heartbeat``:
            # shuffle id -> worker address -> heartbeat data.
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            # NOTE(review): the plugin is registered before ``_shuffles``,
            # ``_archived_by_stimulus`` and ``_shift_counter`` exist -- presumably
            # ``add_plugin`` does not touch them; confirm before reordering.
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

            The registration goes through ``self.scheduler``; the *scheduler*
            argument is not used.
            """
            plugin_bytes = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, plugin_bytes, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of active shuffle *id* for *worker* as an RPC reply.

            On success the reply is ``{"status": "OK", "run_spec": ...}`` with the
            spec wrapped in ``ToPickle``. An unknown shuffle id (``KeyError``) is
            turned into a ``P2PConsistencyError``, which like any other
            ``P2PConsistencyError`` is returned as an error message.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    reply = {"status": "OK", "run_spec": ToPickle(spec)}
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as exc:
                return error_message(exc)
            return reply
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:36603', workers: 0, cores: 0, tasks: 0>
config_value = 'p2p', keyword = None
ws = (<Worker 'tcp://127.0.0.1:46485', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:40951', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[0.76190721, 0.18421237, 0.32989737, 0.72996263, 0.7042225 ,
        0.62118023, 0.80522491, 0.4608107 , 0.0737..., 0.61995515, 0.14544883, 0.11477718, 0.08683625,
        0.15604437, 0.68832368, 0.58877247, 0.85997467, 0.76332141]])

    @pytest.mark.parametrize("config_value", ["tasks", "p2p", None])
    @pytest.mark.parametrize("keyword", ["tasks", "p2p", None])
    @gen_cluster(client=True)
    async def test_rechunk_configuration(c, s, *ws, config_value, keyword):
        """Try rechunking a random 1d matrix
    
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_1d
        """
        a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
        x = da.from_array(a, chunks=(10, 1))
        new = ((1,) * 10, (10,))
        config = {"array.rechunk.method": config_value} if config_value is not None else {}
        with dask.config.set(config):
            x2 = rechunk(x, chunks=new, method=keyword)
        expected_algorithm = keyword if keyword is not None else config_value
        if expected_algorithm == "p2p":
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
        elif expected_algorithm == "tasks":
            assert not any(
                key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
            )
        # Neither is specified, so we choose the best one (see test_rechunk_heuristic for a full test of the heuristic)
        else:
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
    
        assert x2.chunks == new
>       assert np.all(await c.compute(x2) == a)

distributed/shuffle/tests/test_rechunk.py:210: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
    self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        """``OKMessage`` that additionally carries a shuffle run spec."""

        # Either the run spec itself or the run spec wrapped in ``ToPickle``.
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Store the run's identity and collaborators and create its buffers.

            Parameters
            ----------
            id, run_id, span_id, local_address
                Stored unchanged on the instance.
            directory
                Passed to ``DiskShardsBuffer``; only used when *disk* is true.
            executor, rpc, digest_metric, scheduler, loop
                Stored unchanged on the instance.
            memory_limiter_disk
                Memory limiter handed to ``DiskShardsBuffer`` (only when *disk*
                is true).
            memory_limiter_comms
                Memory limiter handed to ``CommShardsBuffer``.
            disk
                If true the shard buffer is a ``DiskShardsBuffer``, otherwise a
                ``MemoryShardsBuffer``.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        # NOTE(review): ``self.read`` is not defined in the visible
                        # part of this class -- presumably supplied by subclasses.
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        # NOTE(review): ``self.deserialize`` is likewise not visible
                        # here -- presumably supplied by subclasses.
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    # Size and concurrency limits of the comm buffer come from the
                    # dask configuration.
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            # Bookkeeping state of the run.
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            # Retry policy used by ``send``; the delays are parsed with seconds as
            # the default unit.
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            """Return ``run_id`` as the hash of this run."""
            # NOTE(review): no matching ``__eq__`` is visible in this part of the
            # class, so equality presumably stays identity-based -- confirm.
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Forward context_meter metrics emitted inside the ``with`` body to
            ``self.digest_metric``.

            Every metric is recorded under the key

                ('p2p', <span id>, <where>, *label, unit)

            where a leading ``"p2p-"`` on the first label element is stripped.
            The callback is registered with ``allow_offload=True`` only when
            *where* contains ``"background"``.

            **Note 1:** Metrics not logged by a background task
            (where='foreground') are, by design, duplicated under

                {('execute', <span id>, <task prefix>, label, unit): value}

            so that the whole shuffle process can be looked at in one place.

            **Note 2:** Metrics are written to Worker.digests immediately rather
            than being stored on the ShuffleRun first; otherwise those recorded
            between the last heartbeat and the deletion of the ShuffleRun at the
            end of a run would be lost.
            """
            prefix = "p2p-"
            offloadable = "background" in where

            def callback(label: Hashable, value: float, unit: str) -> None:
                # Normalise the label to a tuple so it can be spliced into the key.
                parts = label if isinstance(label, tuple) else (label,)
                head = parts[0]
                if isinstance(head, str) and head.startswith(prefix):
                    parts = (head[len(prefix) :], *parts[1:])
                self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

            with context_meter.add_callback(callback, allow_offload=offloadable):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Report to the scheduler that the barrier was reached.

            ``consistent`` is true only when every entry of ``run_ids`` equals this
            run's own ``run_id``. Returns this run's ``run_id``.
            """
            self.raise_if_closed()
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            all_match = all(other == self.run_id for other in run_ids)
            await self.scheduler.shuffle_barrier(
                id=self.id,
                run_id=self.run_id,
                consistent=all_match,
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Make a single ``shuffle_receive`` call to the worker at ``address``.

            ``shards`` is either the list of shards or an already pickled blob;
            both are wrapped with ``to_serialize`` before being sent.
            """
            self.raise_if_closed()
            remote = self.rpc(address)
            response = await remote.shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
            return response
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address``, retrying as configured.

            Retries use ``RETRY_COUNT``, ``RETRY_DELAY_MIN`` and ``RETRY_DELAY_MAX``.
            """
            # Don't send buffers individually over the tcp comms when they are small.
            # Instead, merge everything into an opaque bytes blob, send it all at once,
            # and unpickle it on the other side.
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
            payload: list | bytes = (
                pickle.dumps(shards) if _mean_shard_size(shards) < 65536 else shards
            )

            return await retry(
                lambda: self._send(address, payload),
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

            The call is timed under the ``"offload"`` context meter.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                outcome = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
            return outcome
    
        def heartbeat(self) -> dict[str, Any]:
            """Collect buffer statistics and the start time of this run.

            The comm buffer's heartbeat is extended with a ``"read"`` entry holding
            ``self.total_recvd``.
            """
            comm_stats = self._comm_buffer.heartbeat()
            comm_stats["read"] = self.total_recvd
            disk_stats = self._disk_buffer.heartbeat()
            return {"disk": disk_stats, "comm": comm_stats, "start": self.start_time}
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Hand ``data`` to the comm buffer; raises if the run is closed."""
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            """Write ``data`` to the disk buffer.

            Each index tuple is turned into a string key by joining its elements
            with ``"_"``, the same encoding ``_read_from_disk`` uses.
            """
            self.raise_if_closed()
            keyed = {}
            for index, shards in data.items():
                keyed["_".join(map(str, index))] = shards
            await self._disk_buffer.write(keyed)
    
        def raise_if_closed(self) -> None:
            """Raise if this run has been closed.

            Re-raises the stored ``_exception`` when there is one, otherwise a
            ``ShuffleClosedError``. Does nothing while the run is open.
            """
            if not self.closed:
                return
            if self._exception:
                raise self._exception
            raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark the transfer as finished and flush outstanding comm writes.

            Any exception surfaced by the comm buffer after the flush is stored
            in ``self._exception`` and re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as comm_error:
                # Remember the failure so later raise_if_closed() calls can surface it.
                self._exception = comm_error
                raise
    
        async def _flush_comm(self) -> None:
            """Flush the comm buffer; raises if the run is closed."""
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush the disk buffer; raises if the run is closed."""
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            """Close the comm and disk buffers of this run.

            The first caller performs the shutdown; any later caller waits on
            ``_closed_event`` until that shutdown has finished.

            The disk buffer is closed even if closing the comm buffer raises, and
            ``_closed_event`` is always set, so that a failure during shutdown can
            neither leak the disk buffer nor leave concurrent callers waiting
            forever. The original exception still propagates to the first caller.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            self.closed = True
            try:
                try:
                    await self._comm_buffer.close()
                finally:
                    await self._disk_buffer.close()
            finally:
                # Always release waiters, including on failure.
                self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            """Record ``exception`` as the failure of this run.

            Ignored once the run is already closed.
            """
            if self.closed:
                return
            self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            """Read what the disk buffer holds for index ``id``.

            The lookup key is the index elements joined with ``"_"``, the same
            encoding ``_write_to_disk`` uses.
            """
            self.raise_if_closed()
            disk_key = "_".join(map(str, id))
            return self._disk_buffer.read(disk_key)
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Accept shards sent by a peer and pass them to ``_receive``.

            Returns ``{"status": "OK"}`` on success. A ``P2PConsistencyError`` is
            converted into an error message instead of being raised.
            """
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    # NOTE(review): pickle.loads on data received over the wire;
                    # only safe while every peer is trusted.
                    shards = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                else:
                    shards = data
                await self._receive(shards)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Make sure output partition ``i`` is computed on its assigned worker.

            Returns normally when this worker is the assigned one. Otherwise asks
            the scheduler to restrict task ``key`` to the assigned worker and
            raises ``Reschedule``; a scheduler error is raised as ``RuntimeError``.
            """
            target = self._get_assigned_worker(i)
            if target == self.local_address:
                return

            reply = await self.scheduler.shuffle_restrict_task(
                id=self.id, run_id=self.run_id, key=key, worker=target
            )
            if reply["status"] == "error":
                raise RuntimeError(reply["message"])
            assert reply["status"] == "OK"
            raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition

            ``_ensure_output_worker`` compares the returned address against
            ``self.local_address`` to decide whether a task must be rescheduled.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run

            Called by ``receive`` with the already unpickled list of shards.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            """Shard an input partition and queue the shards for sending.

            Raises ``RuntimeError`` once the transfer phase has completed.
            Returns this run's ``run_id``.
            """
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            self.validate_data(data)
            # Log metrics both in the "execute" and in the "p2p" contexts
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    sharded = self._shard_partition(data, partition_id)
                # Writing happens on the event loop; block until it is done.
                sync(self._loop, self._write_to_comm, sharded)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers

            The result is handed to ``_write_to_comm`` by ``add_partition``;
            presumably its keys are worker addresses (TODO confirm).
            """
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Return the output partition ``partition_id`` for task ``key``.

            First ensures the task runs on the assigned worker (which may raise
            ``Reschedule``), then flushes received data before the partition is
            assembled. Raises ``RuntimeError`` if called before the transfer
            phase has completed.
            """
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                partition = self._get_output_partition(partition_id, key, **kwargs)
            return partition
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run

            Invoked by ``get_output_partition`` after received data was flushed.
            """
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk

            NOTE(review): the ``int`` in the returned tuple is presumably the
            number of bytes read -- confirm against the implementations.
            """
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards from ``buffer``"""
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling

            The base implementation accepts everything; ``add_partition`` calls
            this hook before sharding.
            """
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        """Return the ``"shuffle"`` plugin of the worker running the current task.

        Raises ``RuntimeError`` when not running on a worker or when that
        worker has no such plugin.
        """
        from distributed import get_worker

        try:
            current = get_worker()
        except ValueError as e:
            not_on_worker = (
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            )
            raise RuntimeError(not_on_worker) from e

        try:
            return current.plugins["shuffle"]  # type: ignore
        except KeyError as e:
            raise RuntimeError(
                f"The worker {current.address} does not have a P2P shuffle plugin."
            ) from e
    
    
    # Prefix shared by the keys of all P2P barrier tasks.
    _BARRIER_PREFIX = "shuffle-barrier-"


    def barrier_key(shuffle_id: ShuffleId) -> str:
        """Return the key of the barrier task belonging to ``shuffle_id``."""
        return "".join((_BARRIER_PREFIX, shuffle_id))


    def id_from_key(key: Key) -> ShuffleId | None:
        """Extract the shuffle id from a barrier task key.

        Returns ``None`` for keys that are not strings or do not carry the
        barrier prefix.
        """
        if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
            return ShuffleId(key[len(_BARRIER_PREFIX) :])
        return None
    
    
    class ShuffleType(Enum):
        """Kinds of P2P operations, identified by a string value."""

        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        """Immutable description of a single run of a shuffle."""

        # Not an __init__ argument: every new instance draws the next value from
        # one counter that is created once with the class and starts at 1.
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        # Maps each partition id to a worker address string.
        worker_for: dict[_T_partition_id, str]
        span_id: str | None

        @property
        def id(self) -> ShuffleId:
            """Shuffle id, taken from the underlying ``spec``."""
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        """Abstract, immutable specification of a shuffle.

        Subclasses define the output partitions, how workers are picked for
        them and how a run is instantiated on a worker.
        """

        id: ShuffleId
        disk: bool

        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""

        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""

        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            """Create the scheduler-side state for a new run of this shuffle.

            The participating workers are initialised to the set of workers that
            appear as values of ``worker_for``.
            """
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )

        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    # eq=False: no field-based __eq__ is generated; hashing is by run_id (below).
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        """Mutable scheduler-side state of one shuffle run."""

        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        # None until the state is archived; see ``archived``.
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False

        @property
        def id(self) -> ShuffleId:
            """Shuffle id of the underlying run spec."""
            return self.run_spec.id

        @property
        def run_id(self) -> int:
            """Run id of the underlying run spec."""
            return self.run_spec.run_id

        @property
        def archived(self) -> bool:
            """Whether ``_archived_by`` has been set."""
            return self._archived_by is not None

        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"

        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P f2e1cf610172a1187382e9e34d16198d failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

All 11 runs failed: test_rechunk_configuration[None-None] (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P e0eda9e4c6daa20db3080bf5f59ffb75 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Attach the plugin to ``scheduler`` and initialise its bookkeeping."""
            self.scheduler = scheduler
            # Expose the plugin's entry points as scheduler handlers.
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            # Nested mapping: shuffle id -> worker address -> heartbeat data
            # (``heartbeat`` indexes it in that order).
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

            The ``scheduler`` argument is not used; registration goes through
            ``self.scheduler``.
            """
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set with the ids of all currently active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Notify all participating workers that the inputs of a run are done.

            Parameters
            ----------
            id, run_id
                Identify the run; ``run_id`` must match the active run of ``id``.
            consistent
                If false, the shuffle is restarted instead of being completed.

            Raises
            ------
            ValueError
                If ``run_id`` does not match the active shuffle.
            RuntimeError
                If a worker answers with an unexpected error, if a worker left
                during the shuffle, or if repeated broadcasts make no progress.
            P2PIllegalStateError
                If a participating worker is unknown to the scheduler while the
                shuffle has not been archived.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Keep only the workers whose broadcast failed with an OSError;
                # they are retried. Any other non-None result is fatal.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge heartbeat ``data`` reported by worker ``ws``.

            ``data`` maps shuffle ids to per-shuffle heartbeat dicts. Entries for
            shuffles that are not currently active are ignored.
            """
            # Build the id set once instead of once per reported shuffle.
            active_ids = self.shuffle_ids()
            for shuffle_id, d in data.items():
                if shuffle_id in active_ids:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of the active shuffle ``id`` to ``worker``.

            A missing shuffle (``KeyError``) and any ``P2PConsistencyError`` are
            reported as an error message instead of being raised.
            """
            try:
                try:
                    found = self._get(id, worker)
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
                return {"status": "OK", "run_spec": ToPickle(found)}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Look up the active run of ``id`` and add ``worker`` as a participant.

            Raises ``KeyError`` if there is no active shuffle with that id.
            """
            if worker in self.scheduler.workers:
                state = self.active_shuffles[id]
                state.participating_workers.add(worker)
                return state.run_spec
            # This should never happen
            raise P2PConsistencyError(
                f"Scheduler is unaware of this worker {worker!r}"
            )  # pragma: nocover
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            """Return the ``ShuffleSpec`` stored on the barrier task of ``shuffle_id``.

            The barrier task is looked up by ``barrier_key(shuffle_id)``; a
            ``KeyError`` propagates if the scheduler does not know that task.
            """
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            """Create and register a new run for ``shuffle_id``.

            ``key`` is the task that triggered the creation and ``worker`` the
            worker executing it; the worker is added to the participants of the
            new run. Returns the run spec of the new run.
            """
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            # Replaces any previous entry for this shuffle id; every state ever
            # created for the id is additionally kept in ``_shuffles``.
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Attach the plugin to ``scheduler`` and initialise its bookkeeping."""
            self.scheduler = scheduler
            # Expose the plugin's entry points as scheduler handlers.
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            # Nested mapping: shuffle id -> worker address -> heartbeat data
            # (``heartbeat`` indexes it in that order).
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

            The ``scheduler`` argument is not used; registration goes through
            ``self.scheduler``.
            """
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set with the ids of all currently active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Notify all participating workers that the inputs of a run are done.

            Parameters
            ----------
            id, run_id
                Identify the run; ``run_id`` must match the active run of ``id``.
            consistent
                If false, the shuffle is restarted instead of being completed.

            Raises
            ------
            ValueError
                If ``run_id`` does not match the active shuffle.
            RuntimeError
                If a worker answers with an unexpected error, if a worker left
                during the shuffle, or if repeated broadcasts make no progress.
            P2PIllegalStateError
                If a participating worker is unknown to the scheduler while the
                shuffle has not been archived.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Keep only the workers whose broadcast failed with an OSError;
                # they are retried. Any other non-None result is fatal.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` for the active run of shuffle ``id``.

            Returns ``{"status": "OK"}`` on success, or an error message when
            ``run_id`` differs from the run id of the active shuffle.
            """
            try:
                shuffle = self.active_shuffles[id]
                active_run_id = shuffle.run_id
                if active_run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if active_run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                task_state = self.scheduler.tasks[key]
                self._set_restriction(task_state, worker)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge heartbeat ``data`` reported by worker ``ws``.

            ``data`` maps shuffle ids to per-shuffle heartbeat dicts. Entries for
            shuffles that are not currently active are ignored.
            """
            # Build the id set once instead of once per reported shuffle.
            active_ids = self.shuffle_ids()
            for shuffle_id, d in data.items():
                if shuffle_id in active_ids:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of the active shuffle ``id`` to ``worker``.

            A missing shuffle (``KeyError``) and any ``P2PConsistencyError`` are
            reported as an error message instead of being raised.
            """
            try:
                try:
                    found = self._get(id, worker)
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
                return {"status": "OK", "run_spec": ToPickle(found)}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: 'e0eda9e4c6daa20db3080bf5f59ffb75'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        """``OKMessage`` extended with the run specification of a shuffle."""

        # Either the spec itself or the spec wrapped in ``ToPickle``.
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Store collaborators, create the shard buffers and read retry config.

            Parameters
            ----------
            id, run_id, span_id, local_address
                Stored unchanged on the instance.
            directory
                Passed to ``DiskShardsBuffer``; only used when ``disk`` is true.
            executor, rpc, digest_metric, scheduler
                Stored unchanged on the instance.
            memory_limiter_disk
                Limiter handed to ``DiskShardsBuffer``; unused when ``disk`` is false.
            memory_limiter_comms
                Limiter handed to ``CommShardsBuffer``.
            disk
                Selects ``DiskShardsBuffer`` (true) or ``MemoryShardsBuffer`` (false).
            loop
                Stored as ``self._loop``.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    # NOTE(review): self.read / self.deserialize are not defined in
                    # the visible part of this class — presumably supplied by
                    # subclasses; verify.
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    # The comm buffer delivers its batches through self.send
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            # Lifecycle / bookkeeping state
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            # Retry settings consumed by send()
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            """Hash the run by its ``run_id``."""
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Forward context_meter metrics to ``digest_metric`` while active.

            Each metric is recorded under the key

                ('p2p', <span id>, <where>, *label, unit)

            where a leading ``"p2p-"`` on the first label element is stripped.

            **Note 1:** For metrics not logged by a background task
            (where='foreground') the same value is also recorded under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This duplication is intentional: it gives a holistic view of the
            whole shuffle process.

            **Note 2:** Metrics are written to Worker.digests right away rather
            than being buffered on the ShuffleRun, because anything recorded
            between the last heartbeat and the deletion of the ShuffleRun at the
            end of a run would otherwise be lost.
            """
            prefix = "p2p-"

            def callback(label: Hashable, value: float, unit: str) -> None:
                parts = label if isinstance(label, tuple) else (label,)
                first = parts[0]
                if isinstance(first, str) and first.startswith(prefix):
                    parts = (first[len(prefix) :], *parts[1:])
                self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

            offloadable = "background" in where
            with context_meter.add_callback(callback, allow_offload=offloadable):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Deliver one payload to the peer at ``address`` via ``shuffle_receive``.

            Raises if the run has been closed. Returns the peer's response.
            """
            self.raise_if_closed()
            receive = self.rpc(address).shuffle_receive
            payload = to_serialize(shards)
            return await receive(
                data=payload,
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address``, retrying per the ``RETRY_*`` settings.

            When the mean shard size is below 65536 the whole list is pickled
            into one opaque bytes blob that the receiver unpickles; otherwise the
            shards are sent as they are.
            """
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
            is_small = _mean_shard_size(shards) < 65536
            # Small shards: don't send buffers individually over the tcp comms,
            # merge everything into a single blob instead.
            shards_or_bytes: list | bytes = pickle.dumps(shards) if is_small else shards

            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)

            retry_options = {
                "count": self.RETRY_COUNT,
                "delay_min": self.RETRY_DELAY_MIN,
                "delay_max": self.RETRY_DELAY_MAX,
            }
            return await retry(_send, **retry_options)
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

            Raises if the run has been closed. The wait is metered under the
            ``"offload"`` label of ``context_meter``.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Write ``data`` to the comm buffer; raises if the run has been closed."""
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark the transfer as finished and flush the comm buffer.

            Raises if the run has been closed. An exception raised by the comm
            buffer's ``raise_on_exception`` is stored in ``self._exception`` and
            re-raised.
            """
            self.raise_if_closed()
            # From here on add_partition() refuses further input
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                # Remember the failure so later raise_if_closed() calls can
                # surface it once the run is closed
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            """Flush the comm buffer; raises if the run has been closed."""
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush the disk buffer; raises if the run has been closed."""
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            """Read from the disk buffer under the key for ``id``.

            The key is built by joining the elements of ``id`` with ``"_"``, the
            same scheme ``_write_to_disk`` uses. Raises if the run has been closed.
            """
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Accept shards sent by a peer and hand them to ``_receive``.

            ``data`` is either a list of ``(partition id, shard)`` tuples or the
            pickled bytes blob produced by ``send``. Returns ``{"status": "OK"}``
            on success; a ``P2PConsistencyError`` is returned in its
            ``error_message`` form, any other exception propagates.
            """
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    # NOTE(review): pickle.loads can execute arbitrary code from
                    # the payload; this presumably relies on peers being
                    # trusted — confirm.
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        # NOTE(review): the enclosing class derives only from ``Generic``, so
        # ``abc.abstractmethod`` is presumably not enforced at instantiation —
        # confirm whether an ABC base/metaclass is intended.
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition

            Used by ``_ensure_output_worker``, which compares the result with
            ``self.local_address``.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run

            Called by ``receive`` after it has unpickled a bytes payload, so
            ``data`` is always a list of ``(partition id, shard)`` tuples here.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…
    shuffle_run = await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Attach the plugin to ``scheduler``.

            Registers the ``shuffle_*`` handlers on the scheduler, adds this
            object as the scheduler plugin named ``"shuffle"`` and creates the
            empty bookkeeping containers.
            """
            self.scheduler = scheduler
            # Expose the plugin's request handlers under their op names
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            # heartbeats[shuffle_id][worker_address] -> dict, filled by heartbeat()
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

            The ``scheduler`` argument is not used; registration goes through
            ``self.scheduler``.
            """
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the IDs of all active shuffles."""
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Handle a ``shuffle_barrier`` request.

            If ``consistent`` is false the shuffle is restarted. Otherwise
            ``shuffle_inputs_done`` is broadcast to all participating workers;
            workers whose result is not ``None`` are retried until none is left.

            Raises
            ------
            KeyError
                If ``id`` is not in ``active_shuffles``.
            ValueError
                If ``run_id`` differs from the active shuffle's run ID.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler while the
                shuffle is not archived.
            RuntimeError
                On a broadcast result that is neither ``None`` nor an
                ``OSError``, when a failing worker is unknown to the scheduler,
                or once the number of failing workers has stayed unchanged in
                three retry rounds.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            # Number of retry rounds in which the failure count did not change.
            # NOTE(review): never reset, so the rounds need not be consecutive.
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Collect the workers to retry: only OSError results are
                # retried, anything else that is not None aborts the barrier.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                # NOTE(review): redundant — every result that is neither None
                # nor an OSError has already raised above, so this rebuilds the
                # same list.
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        # NOTE(review): the message lists all failing workers,
                        # not only those unknown to the scheduler.
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` for the active run of shuffle ``id``.

            Returns ``{"status": "OK"}`` on success. When ``run_id`` does not
            match the active shuffle's run ID, the resulting
            ``P2PConsistencyError`` is returned in its ``error_message`` form.
            A ``KeyError`` for an unknown shuffle or task propagates.
            """
            try:
                shuffle = self.active_shuffles[id]
                active_run_id = shuffle.run_id
                if active_run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if active_run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                task_state = self.scheduler.tasks[key]
                self._set_restriction(task_state, worker)
                return {"status": "OK"}
            except P2PConsistencyError as err:
                return error_message(err)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge a worker's per-shuffle heartbeat payload into ``self.heartbeats``.

            Parameters
            ----------
            ws
                State of the reporting worker; only ``ws.address`` is read.
            data
                Mapping of shuffle ID to that shuffle's heartbeat dict.

            Entries whose shuffle ID is not returned by ``self.shuffle_ids()`` are
            ignored.
            """
            # shuffle_ids() builds a fresh set on every call, so take one
            # snapshot instead of rebuilding it for every entry of ``data``.
            active_ids = self.shuffle_ids()
            for shuffle_id, d in data.items():
                if shuffle_id in active_ids:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Handle a ``shuffle_get`` request for an existing shuffle.

            Returns an OK message whose ``run_spec`` is wrapped in ``ToPickle``.
            A ``KeyError`` raised while building the reply is re-raised as a
            chained ``P2PConsistencyError``; any ``P2PConsistencyError`` is
            returned in its ``error_message`` form instead of propagating.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    reply: RunSpecMessage = {"status": "OK", "run_spec": ToPickle(spec)}
                    return reply
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as err:
                return error_message(err)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Return the run spec of active shuffle ``id`` and record ``worker``.

            ``worker`` is added to the shuffle's ``participating_workers``.

            Raises
            ------
            P2PConsistencyError
                If ``worker`` is not known to the scheduler.
            KeyError
                If ``id`` is not in ``active_shuffles``.
            """
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:41971', workers: 0, cores: 0, tasks: 0>
config_value = None, keyword = None
ws = (<Worker 'tcp://127.0.0.1:40879', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:46379', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[3.12737881e-01, 4.60867855e-01, 9.73577295e-01, 3.15562528e-01,
        9.61726398e-01, 7.24117584e-01, 6.9431...e-01,
        4.99845514e-01, 7.43740630e-01, 5.46473780e-01, 4.42117179e-02,
        1.22833879e-01, 3.27800041e-01]])

    @pytest.mark.parametrize("config_value", ["tasks", "p2p", None])
    @pytest.mark.parametrize("keyword", ["tasks", "p2p", None])
    @gen_cluster(client=True)
    async def test_rechunk_configuration(c, s, *ws, config_value, keyword):
        """Try rechunking a random 1d matrix
    
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_1d
        """
        a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
        x = da.from_array(a, chunks=(10, 1))
        new = ((1,) * 10, (10,))
        config = {"array.rechunk.method": config_value} if config_value is not None else {}
        with dask.config.set(config):
            x2 = rechunk(x, chunks=new, method=keyword)
        expected_algorithm = keyword if keyword is not None else config_value
        if expected_algorithm == "p2p":
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
        elif expected_algorithm == "tasks":
            assert not any(
                key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
            )
        # Neither is specified, so we choose the best one (see test_rechunk_heuristic for a full test of the heuristic)
        else:
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
    
        assert x2.chunks == new
>       assert np.all(await c.compute(x2) == a)

distributed/shuffle/tests/test_rechunk.py:210: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
    self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        """``OKMessage`` extended with the run specification of a shuffle."""

        # Either the spec itself or the spec wrapped in ``ToPickle``.
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Store collaborators, create the shard buffers and read retry config.

            Parameters
            ----------
            id, run_id, span_id, local_address
                Stored unchanged on the instance.
            directory
                Passed to ``DiskShardsBuffer``; only used when ``disk`` is true.
            executor, rpc, digest_metric, scheduler
                Stored unchanged on the instance.
            memory_limiter_disk
                Limiter handed to ``DiskShardsBuffer``; unused when ``disk`` is false.
            memory_limiter_comms
                Limiter handed to ``CommShardsBuffer``.
            disk
                Selects ``DiskShardsBuffer`` (true) or ``MemoryShardsBuffer`` (false).
            loop
                Stored as ``self._loop``.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    # NOTE(review): self.read / self.deserialize are not defined in
                    # the visible part of this class — presumably supplied by
                    # subclasses; verify.
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    # The comm buffer delivers its batches through self.send
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            # Lifecycle / bookkeeping state
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            # Retry settings read from dask config
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            """Hash the run by its ``run_id``."""
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Forward context_meter metrics to ``digest_metric`` while active.

            Each metric is recorded under the key

                ('p2p', <span id>, <where>, *label, unit)

            where a leading ``"p2p-"`` on the first label element is stripped.

            **Note 1:** For metrics not logged by a background task
            (where='foreground') the same value is also recorded under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This duplication is intentional: it gives a holistic view of the
            whole shuffle process.

            **Note 2:** Metrics are written to Worker.digests right away rather
            than being buffered on the ShuffleRun, because anything recorded
            between the last heartbeat and the deletion of the ShuffleRun at the
            end of a run would otherwise be lost.
            """
            prefix = "p2p-"

            def callback(label: Hashable, value: float, unit: str) -> None:
                parts = label if isinstance(label, tuple) else (label,)
                first = parts[0]
                if isinstance(first, str) and first.startswith(prefix):
                    parts = (first[len(prefix) :], *parts[1:])
                self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

            offloadable = "background" in where
            with context_meter.add_callback(callback, allow_offload=offloadable):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Report the barrier to the scheduler and return this run's id.

            Parameters
            ----------
            run_ids
                Run ids to compare against ``self.run_id``. The barrier is
                flagged as consistent only if none of them differs.

            Returns
            -------
            int
                ``self.run_id``

            Raises
            ------
            Whatever ``raise_if_closed`` raises when the run is closed.
            """
            self.raise_if_closed()
            mismatched = [run_id for run_id in run_ids if run_id != self.run_id]
            is_consistent = not mismatched
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=is_consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Deliver ``shards`` to the ``shuffle_receive`` handler at ``address``.

            The payload is wrapped with ``to_serialize`` and tagged with this
            run's shuffle id and run id. Raises via ``raise_if_closed`` if the
            run is closed.
            """
            self.raise_if_closed()
            receive_handler = self.rpc(address).shuffle_receive
            payload = to_serialize(shards)
            return await receive_handler(
                data=payload,
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address``, retrying according to the RETRY_* settings.

            When the mean shard size is below 65536 the shards are pickled into
            a single bytes blob before sending; otherwise they are sent as-is.
            """
            # Don't send small buffers individually over the tcp comms.
            # Instead, merge everything into an opaque bytes blob, send it all at once,
            # and unpickle it on the other side.
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
            shards_or_bytes: list | bytes = (
                pickle.dumps(shards) if _mean_shard_size(shards) < 65536 else shards
            )

            def _attempt() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                # A fresh coroutine is created for every retry attempt
                return self._send(address, shards_or_bytes)

            return await retry(
                _attempt,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

            The call is measured under the ``"offload"`` context meter.
            Raises via ``raise_if_closed`` if the run is closed.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                outcome = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
                return outcome
    
        def heartbeat(self) -> dict[str, Any]:
            """Collect heartbeat information for this run.

            Returns
            -------
            dict
                ``"disk"``: heartbeat of the disk buffer,
                ``"comm"``: heartbeat of the comm buffer with an extra ``"read"``
                entry holding ``self.total_recvd``,
                ``"start"``: ``self.start_time``.
            """
            comm_stats = self._comm_buffer.heartbeat()
            comm_stats["read"] = self.total_recvd
            disk_stats = self._disk_buffer.heartbeat()
            return {"disk": disk_stats, "comm": comm_stats, "start": self.start_time}
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            """Raise if this run has been closed; do nothing otherwise.

            Raises
            ------
            The stored ``self._exception`` if one is set, otherwise
            ``ShuffleClosedError``.
            """
            if not self.closed:
                return
            if self._exception:
                raise self._exception
            raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark the run as transferred and flush the comm buffer.

            After the flush, any exception reported by the comm buffer's
            ``raise_on_exception`` is stored on ``self._exception`` and re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            comm_buffer = self._comm_buffer
            try:
                comm_buffer.raise_on_exception()
            except Exception as err:
                # Remember the failure so later raise_if_closed() calls can surface it
                self._exception = err
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush the disk buffer after checking the run is still open."""
            self.raise_if_closed()
            disk_buffer = self._disk_buffer
            await disk_buffer.flush()
    
        async def close(self) -> None:
            """Close the run and both of its buffers.

            The first caller marks the run as closed and closes the comm buffer
            followed by the disk buffer. Any caller arriving while the run is
            already marked closed waits until the first call has finished.

            The disk buffer is closed even if closing the comm buffer raises,
            and the closed event is always set, so that waiting callers are
            never blocked forever by a failing buffer. The exception itself
            still propagates to the first caller.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            self.closed = True
            try:
                try:
                    await self._comm_buffer.close()
                finally:
                    # Release disk resources regardless of the comm buffer outcome
                    await self._disk_buffer.close()
            finally:
                # Wake up concurrent close() callers even if a buffer failed to close
                self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            """Record ``exception`` on the run unless the run is already closed."""
            if self.closed:
                return
            self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Accept shards sent by ``send`` and pass them to ``_receive``.

            A ``bytes`` payload is the pickled blob produced by ``send`` and is
            unpickled first.

            Returns
            -------
            ``{"status": "OK"}`` on success, or ``error_message(e)`` when a
            ``P2PConsistencyError`` is raised while handling the payload.
            """
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    shards = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                else:
                    shards = data
                await self._receive(shards)
            except P2PConsistencyError as err:
                return error_message(err)
            return {"status": "OK"}
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition

            Parameters
            ----------
            i
                Identifier of the output partition.

            Returns
            -------
            str
                Worker address; ``_ensure_output_worker`` compares it against
                ``self.local_address``.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run

            Called by ``receive`` once a pickled payload, if any, has been
            unpacked. A ``P2PConsistencyError`` raised here is turned into an
            error message by ``receive``.

            Parameters
            ----------
            data
                Pairs of partition id and shard payload.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            """Shard an input partition and hand the shards to the comm buffer.

            Parameters
            ----------
            data
                Input partition; checked with ``validate_data`` first.
            partition_id
                Identifier of the input partition.

            Returns
            -------
            int
                ``self.run_id``

            Raises
            ------
            RuntimeError
                If the run has already been marked as transferred.
            """
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            self.validate_data(data)
            # Log metrics both in the "execute" and in the "p2p" contexts
            with self._capture_metrics("foreground"):
                # Only the sharding itself is timed by the two meters below
                with context_meter.meter("p2p-shard-partition-noncpu"):
                    with context_meter.meter("p2p-shard-partition-cpu", func=thread_time):
                        sharded = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, sharded)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers

            Parameters
            ----------
            data
                The input partition to split.
            partition_id
                Identifier of the input partition.

            Returns
            -------
            dict
                Mapping passed on to ``_write_to_comm`` by ``add_partition``;
                presumably keyed by worker address (confirm in subclasses).
            """
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Return the output partition ``partition_id`` of this shuffle run.

            First makes sure the partition is handled on its assigned worker
            (``_ensure_output_worker``), then flushes the disk buffer and
            delegates to ``_get_output_partition`` with metrics captured.

            Raises
            ------
            RuntimeError
                If the run has not been marked as transferred yet.
            """
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            # Log metrics both in the "execute" and in the "p2p" contexts
            with self._capture_metrics("foreground"):
                with context_meter.meter("p2p-get-output-noncpu"):
                    with context_meter.meter("p2p-get-output-cpu", func=thread_time):
                        return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run

            Invoked by ``get_output_partition`` after the disk buffer has been
            flushed; receives the same ``partition_id``, ``key`` and keyword
            arguments.
            """
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk

            Parameters
            ----------
            path
                Location of the file to read.

            Returns
            -------
            tuple
                A pair of the data read and an integer; presumably the number
                of bytes read (confirm in subclasses).
            """
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards

            This method is handed to ``MemoryShardsBuffer`` as its
            ``deserialize`` callable when the run is constructed.
            """
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling

            The base implementation accepts everything. ``add_partition`` calls
            this hook before sharding, so an override can reject ``data`` by
            raising.
            """
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        """Return the ``"shuffle"`` plugin of the worker running the current task.

        Raises
        ------
        RuntimeError
            If ``get_worker()`` raises ``ValueError`` (the code is not running
            on a worker), or if the worker has no plugin named ``"shuffle"``.
        """
        from distributed import get_worker

        try:
            current_worker = get_worker()
        except ValueError as err:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from err

        try:
            plugin = current_worker.plugins["shuffle"]
        except KeyError as err:
            raise RuntimeError(
                f"The worker {current_worker.address} does not have a P2P shuffle plugin."
            ) from err
        return plugin  # type: ignore
    
    
    # Prefix of barrier task keys; barrier_key() prepends it to a shuffle id and
    # id_from_key() strips it again.
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        """Build the barrier task key for ``shuffle_id`` by prepending the barrier prefix."""
        key = _BARRIER_PREFIX + shuffle_id
        return key
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        """Extract the shuffle id from a barrier task key.

        Returns ``None`` when ``key`` is not a string or does not start with
        the barrier prefix; otherwise the remainder of the key after the prefix.
        """
        if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
            suffix = key[len(_BARRIER_PREFIX) :]
            return ShuffleId(suffix)
        return None
    
    
    class ShuffleType(Enum):
        """Kinds of P2P operation, identified by a string value."""

        # Shuffle of a dataframe
        DATAFRAME = "DataFrameShuffle"
        # Rechunking of an array
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        """Immutable description of a single run of a shuffle.

        ``run_id`` is not an init argument: every new instance draws the next
        value from a shared counter that starts at 1.
        """

        # Drawn from a shared counter starting at 1; never passed to __init__
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        # The shuffle specification this run belongs to
        spec: ShuffleSpec
        # Mapping from output partition id to a worker address
        worker_for: dict[_T_partition_id, str]
        # Span id handed in by the creator of the run, if any
        span_id: str | None

        @property
        def id(self) -> ShuffleId:
            """Shuffle id, taken from the underlying ``spec``."""
            spec = self.spec
            return spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        """Abstract, immutable specification of a P2P shuffle.

        Subclasses define the output partitions, how workers are picked for
        them and how a run is created on a worker.
        """

        # Identifier of the shuffle
        id: ShuffleId
        # Flag stored with the spec; NOTE(review): presumably selects disk-backed
        # buffering on the worker - confirm against the run constructor
        disk: bool

        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Generate the identifiers of all output partitions"""

        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Choose one of ``workers`` for the given ``partition``"""

        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            """Create the scheduler-side state for a new run of this shuffle.

            The participating workers are initialised to the set of worker
            addresses appearing as values of ``worker_for``.
            """
            new_run_spec = ShuffleRunSpec(
                spec=self, worker_for=worker_for, span_id=span_id
            )
            participants = set(worker_for.values())
            return SchedulerShuffleState(
                run_spec=new_run_spec,
                participating_workers=participants,
            )

        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Instantiate the worker-side run object for this shuffle"""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        """Scheduler-side bookkeeping for one run of a shuffle.

        Instances compare by identity (``eq=False``) and hash by their run id.
        """

        # Specification of the run this state tracks
        run_spec: ShuffleRunSpec
        # Addresses of the workers taking part in the run
        participating_workers: set[str]
        # Set once the state has been archived; not an init argument
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False

        @property
        def id(self) -> ShuffleId:
            """Shuffle id of the tracked run."""
            run_spec = self.run_spec
            return run_spec.id

        @property
        def run_id(self) -> int:
            """Run id of the tracked run."""
            run_spec = self.run_spec
            return run_spec.run_id

        @property
        def archived(self) -> bool:
            """Whether ``_archived_by`` has been set."""
            archived_by = self._archived_by
            return archived_by is not None

        def __str__(self) -> str:
            cls_name = self.__class__.__name__
            return f"{cls_name}<{self.id}[{self.run_id}]>"

        def __hash__(self) -> int:
            run_id = self.run_id
            return hash(run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P e0eda9e4c6daa20db3080bf5f59ffb75 failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

All 11 runs failed: test_rechunk_heuristic[new0-p2p] (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 90af7b3bcf0d14cd2d09aae30fd664cf failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` named ``"shuffle"`` via the scheduler.

            Registration goes through ``self.scheduler``; the ``scheduler``
            argument itself is not used.
            """
            pickled_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, pickled_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set with the ids of all currently active shuffles."""
            return {shuffle_id for shuffle_id in self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Handle a ``shuffle_barrier`` request for an active shuffle.

            If the caller reports inconsistent data the shuffle is restarted.
            Otherwise a ``shuffle_inputs_done`` message is broadcast to every
            participating worker, and the broadcast is repeated for workers that
            answered with an ``OSError`` until all of them succeeded.

            Parameters
            ----------
            id
                Id of the shuffle; must be a key of ``active_shuffles``.
            run_id
                Run id the caller belongs to; must match the active run.
            consistent
                Whether the caller observed consistent run ids.

            Raises
            ------
            KeyError
                If there is no active shuffle with this ``id``.
            ValueError
                If ``run_id`` does not match the active run.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler while the
                shuffle is not archived.
            RuntimeError
                If a worker answers with an error other than ``OSError``, if a
                failing worker has left the scheduler, or if the number of
                failing workers stayed unchanged in three retries.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Collect the workers to retry: a None response means success, an
                # OSError is retried, anything else aborts the barrier.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    # Count consecutive-or-not retries in which no worker recovered
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` on behalf of a shuffle run.

            Returns
            -------
            ``{"status": "OK"}`` on success, or ``error_message(e)`` when the
            request's ``run_id`` differs from the run id of the active shuffle.

            Raises
            ------
            KeyError
                If the shuffle id or the task key is unknown.
            """
            try:
                shuffle = self.active_shuffles[id]
                active_run_id = shuffle.run_id
                if active_run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if active_run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                task_state = self.scheduler.tasks[key]
                self._set_restriction(task_state, worker)
                return {"status": "OK"}
            except P2PConsistencyError as err:
                return error_message(err)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge a worker's shuffle heartbeat into ``self.heartbeats``.

            Parameters
            ----------
            ws
                State of the reporting worker; its ``address`` is used as key.
            data
                Mapping from shuffle id to a dict of heartbeat values. Entries
                for shuffles that are not active are ignored.
            """
            # The set of active ids does not change inside the loop, so build it
            # once instead of once per reported shuffle.
            active_ids = self.shuffle_ids()
            for shuffle_id, d in data.items():
                if shuffle_id in active_ids:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of the active shuffle ``id`` to ``worker``.

            Returns
            -------
            ``{"status": "OK", "run_spec": ToPickle(run_spec)}`` on success,
            otherwise ``error_message(e)`` for the ``P2PConsistencyError``
            describing the failure.
            """
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    # A missing shuffle id is reported as a consistency error so that
                    # the outer handler converts it into an error message as well.
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            """Return the shuffle spec stored on the barrier task of ``shuffle_id``.

            The barrier task is looked up in ``scheduler.tasks`` by its barrier
            key and its ``run_spec`` is asserted to be a ``P2PBarrierTask``.
            """
            barrier_ts = self.scheduler.tasks[barrier_key(shuffle_id)]
            barrier_run_spec = barrier_ts.run_spec
            assert isinstance(barrier_run_spec, P2PBarrierTask)
            return barrier_run_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            """Create a new run for ``shuffle_id`` and register it as the active one.

            ``worker`` is added to the participating workers of the new run and
            the span id is taken from the task group of ``key``.

            Returns
            -------
            ShuffleRunSpec
                Run spec of the newly created state.
            """
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            shuffle_spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(shuffle_spec)
            self._ensure_output_tasks_are_non_rootish(shuffle_spec)
            span_id = self.scheduler.tasks[key].group.span_id
            new_state = shuffle_spec.create_new_run(worker_for=worker_for, span_id=span_id)
            self.active_shuffles[shuffle_id] = new_state
            self._shuffles[shuffle_id].add(new_state)
            new_state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return new_state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` named ``"shuffle"`` via the scheduler.

            Registration goes through ``self.scheduler``; the ``scheduler``
            argument itself is not used.
            """
            pickled_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, pickled_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set with the ids of all currently active shuffles."""
            return {shuffle_id for shuffle_id in self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Handle a ``shuffle_barrier`` request for an active shuffle.

            If the caller reports inconsistent data the shuffle is restarted.
            Otherwise a ``shuffle_inputs_done`` message is broadcast to every
            participating worker, and the broadcast is repeated for workers that
            answered with an ``OSError`` until all of them succeeded.

            Parameters
            ----------
            id
                Id of the shuffle; must be a key of ``active_shuffles``.
            run_id
                Run id the caller belongs to; must match the active run.
            consistent
                Whether the caller observed consistent run ids.

            Raises
            ------
            KeyError
                If there is no active shuffle with this ``id``.
            ValueError
                If ``run_id`` does not match the active run.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler while the
                shuffle is not archived.
            RuntimeError
                If a worker answers with an error other than ``OSError``, if a
                failing worker has left the scheduler, or if the number of
                failing workers stayed unchanged in three retries.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Collect the workers to retry: a None response means success, an
                # OSError is retried, anything else aborts the barrier.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    # Count consecutive-or-not retries in which no worker recovered
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` on behalf of a shuffle run.

            Returns
            -------
            ``{"status": "OK"}`` on success, or ``error_message(e)`` when the
            request's ``run_id`` differs from the run id of the active shuffle.

            Raises
            ------
            KeyError
                If the shuffle id or the task key is unknown.
            """
            try:
                shuffle = self.active_shuffles[id]
                active_run_id = shuffle.run_id
                if active_run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if active_run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                task_state = self.scheduler.tasks[key]
                self._set_restriction(task_state, worker)
                return {"status": "OK"}
            except P2PConsistencyError as err:
                return error_message(err)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge a worker's shuffle heartbeat into ``self.heartbeats``.

            Parameters
            ----------
            ws
                State of the reporting worker; its ``address`` is used as key.
            data
                Mapping from shuffle id to a dict of heartbeat values. Entries
                for shuffles that are not active are ignored.
            """
            # The set of active ids does not change inside the loop, so build it
            # once instead of once per reported shuffle.
            active_ids = self.shuffle_ids()
            for shuffle_id, d in data.items():
                if shuffle_id in active_ids:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of the active shuffle ``id`` to ``worker``.

            Returns
            -------
            ``{"status": "OK", "run_spec": ToPickle(run_spec)}`` on success,
            otherwise ``error_message(e)`` for the ``P2PConsistencyError``
            describing the failure.
            """
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    # A missing shuffle id is reported as a consistency error so that
                    # the outer handler converts it into an error message as well.
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: '90af7b3bcf0d14cd2d09aae30fd664cf'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…ate_shuffle(id)
distributed/shuffle/_worker_plugin.py:411: in get_or_create_shuffle
    return sync(
distributed/utils.py:439: in sync
    raise error
distributed/utils.py:413: in f
    result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.12/site-packages/tornado/gen.py:766: in run
    value = future.result()
distributed/shuffle/_worker_plugin.py:145: in get_or_create
    shuffle_run = await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:46527', workers: 0, cores: 0, tasks: 0>
a = array([[0.09938915, 0.49322808, 0.60598825, ..., 0.4802486 , 0.23827154,
        0.81719055],
       [0.5558567 , 0.27...52,
        0.69894025],
       [0.25175323, 0.82966441, 0.95636455, ..., 0.76615423, 0.78213394,
        0.65567332]])
b = <Worker 'tcp://127.0.0.1:38815', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>
new = ((1, 1, 1, 1, 1, 1, ...), (100,)), expected_algorithm = 'p2p'

    @pytest.mark.parametrize(
        ["new", "expected_algorithm"],
        [
            # All-to-all rechunking defaults to P2P
            (((1,) * 100, (100,)), "p2p"),
            # Localized rechunking defaults to tasks
            (((50, 50), (2,) * 50), "tasks"),
            # Less local rechunking first defaults to tasks,
            (((25, 25, 25, 25), (4,) * 25), "tasks"),
            # then switches to p2p
            (((10,) * 10, (10,) * 10), "p2p"),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_heuristic(c, s, a, b, new, expected_algorithm):
        a = np.random.default_rng().uniform(0, 1, 10000).reshape((100, 100))
        x = da.from_array(a, chunks=(100, 1))
        x2 = rechunk(x, chunks=new)
        if expected_algorithm == "p2p":
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
        else:
            assert not any(
                key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
            )
    
        assert x2.chunks == new
>       assert np.all(await c.compute(x2) == a)

distributed/shuffle/tests/test_rechunk.py:239: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
    self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            """Hash the run by its ``run_id``."""
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Forward context_meter metrics recorded inside the block as

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            **Note 1:** Metrics that are not logged by a background task
            (where='foreground') are additionally recorded under

                {('execute', <span id>, <task prefix>, label, unit): value}

            The duplication is intentional: it gives a holistic view of the whole
            shuffle process.

            **Note 2:** Metrics go straight to Worker.digests. Buffering them on the
            ShuffleRun would lose whatever is recorded between the heartbeat and the
            deletion of the ShuffleRun object at the end of a run.
            """
            prefix = "p2p-"

            def _record(label: Hashable, value: float, unit: str) -> None:
                # Normalise the label to a tuple and strip a leading "p2p-" marker
                parts = label if isinstance(label, tuple) else (label,)
                head = parts[0]
                if isinstance(head, str) and head.startswith(prefix):
                    parts = (head[len(prefix) :], *parts[1:])
                self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

            offloadable = "background" in where
            with context_meter.add_callback(_record, allow_offload=offloadable):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Report the barrier for this run to the scheduler and return the run id.

            ``consistent`` is sent as True only if every entry of ``run_ids`` equals
            this run's ``run_id``.
            """
            self.raise_if_closed()
            mismatch = any(rid != self.run_id for rid in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=not mismatch
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Deliver ``shards`` to ``address`` through one ``shuffle_receive`` RPC call."""
            self.raise_if_closed()
            receive_op = self.rpc(address).shuffle_receive
            payload = to_serialize(shards)
            return await receive_op(
                data=payload,
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address``, retrying per the configured retry policy."""
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
            small_shards = _mean_shard_size(shards) < 65536
            # Small shards are not sent as individual buffers over the tcp comms:
            # everything is merged into one opaque bytes blob, sent at once and
            # unpickled on the other side.
            payload: list | bytes = pickle.dumps(shards) if small_shards else shards

            return await retry(
                lambda: self._send(address, payload),
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor`` under the "offload" meter."""
            self.raise_if_closed()
            with context_meter.meter("offload"):
                outcome = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
            return outcome
    
        def heartbeat(self) -> dict[str, Any]:
            """Collect buffer statistics, bytes-received total and start time of this run."""
            comm_stats = self._comm_buffer.heartbeat()
            # The received total is reported alongside the comm buffer statistics
            comm_stats["read"] = self.total_recvd
            disk_stats = self._disk_buffer.heartbeat()
            return dict(disk=disk_stats, comm=comm_stats, start=self.start_time)
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Write ``data`` to the comm buffer; raises first if the run is closed."""
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            """Raise if the run is closed: the recorded exception if any, else ShuffleClosedError."""
            if not self.closed:
                return
            if self._exception:
                raise self._exception
            raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark the run as transferred and flush the comm buffer.

            Any exception raised by the comm buffer's ``raise_on_exception`` is
            stored on the run (``self._exception``) and re-raised.
            """
            self.raise_if_closed()
            # Set before flushing: add_partition refuses new input once this is True
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                # Remember the failure so raise_if_closed can surface it later
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            """Flush the comm buffer; raises first if the run is closed."""
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush the disk buffer; raises first if the run is closed."""
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            """Close the comm and disk buffers and mark the run as closed.

            A call made while (or after) another call is closing the run waits for
            that close to finish instead of closing the buffers a second time.

            Any exception raised while closing a buffer propagates to the caller,
            but the disk buffer is still closed and the closed-event is still set,
            so that concurrent or later ``close()`` calls never wait forever.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            self.closed = True
            try:
                try:
                    await self._comm_buffer.close()
                finally:
                    # Release the disk buffer even if closing the comm buffer failed
                    await self._disk_buffer.close()
            finally:
                # Always wake up waiters; otherwise a failed close would deadlock them
                self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            """Record ``exception`` on the run unless the run is already closed."""
            if self.closed:
                return
            self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Hand incoming shards to ``_receive`` and report the outcome as a message.

            Returns ``{"status": "OK"}`` on success; a ``P2PConsistencyError`` is
            converted into an error message instead of being raised.
            """
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    # NOTE(review): pickle.loads runs on bytes received from a peer;
                    # confirm peers are trusted.
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition ``i``."""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run.

            ``data`` is a list of ``(partition id, shard)`` pairs.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            """Shard an input partition, hand the shards to the comm buffer, return the run id.

            Raises ``RuntimeError`` if the run has already been marked as transferred.
            """
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            self.validate_data(data)
            # Log metrics both in the "execute" and in the "p2p" contexts
            with self._capture_metrics("foreground"):
                with context_meter.meter("p2p-shard-partition-noncpu"):
                    with context_meter.meter("p2p-shard-partition-cpu", func=thread_time):
                        sharded = self._shard_partition(data, partition_id)
                # Writing happens on the event loop, outside the two sharding meters
                sync(self._loop, self._write_to_comm, sharded)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers.

            The result is passed unchanged to ``_write_to_comm`` by ``add_partition``.
            """
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Return output partition ``partition_id`` once all received data is flushed.

            Raises ``RuntimeError`` if called before the run was marked as transferred.
            """
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            # Log metrics both in the "execute" and in the "p2p" contexts
            with self._capture_metrics("foreground"):
                with context_meter.meter("p2p-get-output-noncpu"):
                    with context_meter.meter("p2p-get-output-cpu", func=thread_time):
                        return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run.

            Called by ``get_output_partition`` after the disk buffer has been flushed.
            """
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk.

            Passed as the ``read`` callback of the disk shards buffer.
            """
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards.

            Passed as the ``deserialize`` callback of the in-memory shards buffer.
            """
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling.

            This implementation performs no checks. ``add_partition`` calls it
            before sharding the data.
            """
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        """Return the "shuffle" plugin of the worker the current task runs on.

        Raises
        ------
        RuntimeError
            If no worker can be found, or if the worker has no "shuffle" plugin.
        """
        from distributed import get_worker

        try:
            current_worker = get_worker()
        except ValueError as exc:
            not_on_worker = (
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            )
            raise RuntimeError(not_on_worker) from exc

        try:
            plugin = current_worker.plugins["shuffle"]
        except KeyError as exc:
            raise RuntimeError(
                f"The worker {current_worker.address} does not have a P2P shuffle plugin."
            ) from exc
        return plugin  # type: ignore
    
    
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        """Return the barrier task key for ``shuffle_id`` (the barrier prefix + the id)."""
        return _BARRIER_PREFIX + shuffle_id
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        """Extract the shuffle id from a barrier task key.

        Returns ``None`` when ``key`` is not a string starting with the barrier prefix.
        """
        if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
            prefix_length = len(_BARRIER_PREFIX)
            return ShuffleId(key[prefix_length:])
        return None
    
    
    class ShuffleType(Enum):
        """Enumeration of shuffle kinds: dataframe shuffle and array rechunk."""

        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        """Immutable description of a single run of a shuffle."""

        # Not an ``__init__`` argument: every new instance takes the next value of
        # one counter shared by all instances, starting at 1.
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        # The shuffle specification this run belongs to
        spec: ShuffleSpec
        # Maps a partition id to a worker (a string; presumably its address)
        worker_for: dict[_T_partition_id, str]
        span_id: str | None

        @property
        def id(self) -> ShuffleId:
            """Shuffle id, taken from the underlying spec."""
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        """Abstract, immutable specification of a shuffle, independent of any run."""

        id: ShuffleId
        # NOTE(review): presumably selects disk-backed vs in-memory buffering on the
        # worker — confirm against the run implementation.
        disk: bool

        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""

        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""

        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            """Create the scheduler-side state for a new run of this shuffle.

            The new state wraps a fresh ``ShuffleRunSpec`` and starts with the
            distinct values of ``worker_for`` as its participating workers.
            """
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )

        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        """Scheduler-side state of one shuffle run.

        ``eq=False`` keeps identity-based equality; hashing uses ``run_id``.
        """

        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        # ``None`` until set; ``archived`` is True once this is not None
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False

        @property
        def id(self) -> ShuffleId:
            """Shuffle id of the underlying run spec."""
            return self.run_spec.id

        @property
        def run_id(self) -> int:
            """Run id of the underlying run spec."""
            return self.run_spec.run_id

        @property
        def archived(self) -> bool:
            """Whether ``_archived_by`` has been set."""
            return self._archived_by is not None

        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"

        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P 90af7b3bcf0d14cd2d09aae30fd664cf failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

All 11 runs failed: test_rechunk_heuristic[new3-p2p] (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 8bebd7a505fd78dd1b1dcb229dea41a8 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a new ``ShuffleWorkerPlugin`` under the name "shuffle".

            The registration goes through ``self.scheduler``; the ``scheduler``
            argument is not used.
            """
            serialized_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, serialized_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the ids of all active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Broadcast ``shuffle_inputs_done`` to all participating workers of a shuffle.

            Parameters
            ----------
            id
                Id of the active shuffle.
            run_id
                Run id the caller believes to be current.
            consistent
                False if the caller detected inconsistent run ids; the shuffle is
                then restarted instead of broadcasting.

            Raises
            ------
            ValueError
                If ``run_id`` does not match the active run of the shuffle.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler while the shuffle is
                not archived.
            RuntimeError
                If a worker returns an error other than ``OSError``, if a failing
                worker left the cluster, or if the broadcast stops making progress.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Keep only the workers whose broadcast failed with an OSError; these
                # are retried. A None result is a success, anything else is fatal.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    # Count the rounds in which no additional worker succeeded
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` if ``run_id`` matches the active run.

            A run id mismatch is reported back as an error message, not raised.
            """
            try:
                shuffle = self.active_shuffles[id]
                current_run_id = shuffle.run_id
                if current_run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if current_run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge per-shuffle heartbeat data reported by worker ``ws``.

            ``data`` maps shuffle ids to dicts of heartbeat values. Entries for
            shuffles that are not currently active are ignored.
            """
            # Test membership on the dict itself instead of building a fresh id set
            # for every entry.
            active = self.active_shuffles
            for shuffle_id, d in data.items():
                if shuffle_id in active:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of the active shuffle ``id`` wrapped in an OK message.

            A missing shuffle (``KeyError``) is translated into a
            ``P2PConsistencyError``; every ``P2PConsistencyError`` is returned as an
            error message rather than raised.
            """
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    # Unknown shuffle id: report as a consistency problem
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            """Read the shuffle spec off the run spec of the shuffle's barrier task."""
            barrier_task = self.scheduler.tasks[barrier_key(shuffle_id)]
            run_spec = barrier_task.run_spec
            assert isinstance(run_spec, P2PBarrierTask)
            return run_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            """Create a new run of ``shuffle_id``, make it the active one, return its spec.

            ``key`` is the task that triggered the creation and ``worker`` the worker
            it runs on; the worker is added to the participating workers.
            """
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            # Replaces any previously active run of this shuffle id
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a new ``ShuffleWorkerPlugin`` under the name "shuffle".

            The registration goes through ``self.scheduler``; the ``scheduler``
            argument is not used.
            """
            serialized_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, serialized_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the ids of all active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Broadcast ``shuffle_inputs_done`` to all participating workers of a shuffle.

            Parameters
            ----------
            id
                Id of the active shuffle.
            run_id
                Run id the caller believes to be current.
            consistent
                False if the caller detected inconsistent run ids; the shuffle is
                then restarted instead of broadcasting.

            Raises
            ------
            ValueError
                If ``run_id`` does not match the active run of the shuffle.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler while the shuffle is
                not archived.
            RuntimeError
                If a worker returns an error other than ``OSError``, if a failing
                worker left the cluster, or if the broadcast stops making progress.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Keep only the workers whose broadcast failed with an OSError; these
                # are retried. A None result is a success, anything else is fatal.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    # Count the rounds in which no additional worker succeeded
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` if ``run_id`` matches the active run.

            A run id mismatch is reported back as an error message, not raised.
            """
            try:
                shuffle = self.active_shuffles[id]
                current_run_id = shuffle.run_id
                if current_run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if current_run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge the per-shuffle heartbeat payloads reported by one worker.

            Parameters
            ----------
            ws
                State of the reporting worker; only ``ws.address`` is read.
            data
                Mapping of shuffle id to a dict of heartbeat values.

            Entries for shuffle ids that are not currently active are ignored.
            """
            # Resolve the active ids once instead of once per entry of ``data``.
            active_ids = self.shuffle_ids()
            for shuffle_id, d in data.items():
                if shuffle_id in active_ids:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of the active shuffle ``id`` for ``worker``.

            Returns
            -------
            ``{"status": "OK", "run_spec": ToPickle(run_spec)}`` on success.
            A ``KeyError`` raised during the lookup is re-raised as
            ``P2PConsistencyError``, and every ``P2PConsistencyError`` is returned
            as the result of ``error_message`` instead of propagating.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    response: RunSpecMessage = {
                        "status": "OK",
                        "run_spec": ToPickle(spec),
                    }
                    return response
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as inconsistency:
                return error_message(inconsistency)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: '8bebd7a505fd78dd1b1dcb229dea41a8'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        """``OKMessage`` extended with the run spec of a shuffle."""

        # Either the spec itself or the spec wrapped in ``ToPickle``.
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Set up the state, buffers and retry settings of a shuffle run.

            Parameters
            ----------
            id, run_id, span_id, local_address, executor, rpc, digest_metric, scheduler
                Stored unchanged on the instance under the same names.
            directory
                Passed to ``DiskShardsBuffer``; only used when ``disk`` is true.
            memory_limiter_disk
                Passed to ``DiskShardsBuffer``; only used when ``disk`` is true.
            memory_limiter_comms
                Passed to ``CommShardsBuffer``.
            disk
                If true, shards are buffered in a ``DiskShardsBuffer``, otherwise in
                a ``MemoryShardsBuffer``.
            loop
                Stored as ``self._loop``.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    # NOTE(review): ``self.read`` / ``self.deserialize`` are not
                    # defined in the visible part of this class; presumably they
                    # are provided by subclasses. Confirm.
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    # Both limits of the comm buffer come from the dask config.
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            # Set to True by ``inputs_done``.
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            # Error re-raised by ``raise_if_closed`` once the run is closed.
            self._exception = None
            # Set at the end of ``close`` so concurrent ``close`` calls can wait on it.
            self._closed_event = asyncio.Event()
            self._loop = loop

            # Retry settings used by ``send``; the delays are parsed with seconds
            # as the default unit.
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            """Hash a run by its ``run_id``."""
            # NOTE(review): no ``__eq__`` is defined in the visible part of this
            # class, so equality presumably stays identity-based. Confirm.
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Forward context_meter metrics to ``self.digest_metric`` as

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            A leading ``"p2p-"`` on the first label element is stripped before the
            name is built.

            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This is by design so that one can have a holistic view of the whole shuffle
            process.

            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
            prefix = "p2p-"

            def record(label: Hashable, value: float, unit: str) -> None:
                # Normalise scalar labels to a 1-tuple so they can be spliced below.
                parts = label if isinstance(label, tuple) else (label,)
                head = parts[0]
                if isinstance(head, str) and head.startswith(prefix):
                    parts = (head[len(prefix) :], *parts[1:])
                self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

            offloadable = "background" in where
            with context_meter.add_callback(record, allow_offload=offloadable):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Report the barrier of this run to the scheduler and return ``run_id``.

            ``consistent`` is sent as true only if every entry of ``run_ids``
            equals this run's ``run_id``.

            Raises whatever ``raise_if_closed`` raises when the run is closed.
            """
            self.raise_if_closed()
            own_run_id = self.run_id
            consistent = not any(other != own_run_id for other in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=own_run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to the worker at ``address`` with one ``shuffle_receive`` call.

            The payload is wrapped with ``to_serialize`` and tagged with this run's
            shuffle id and run id. Raises whatever ``raise_if_closed`` raises when
            the run is closed.
            """
            self.raise_if_closed()
            receive_on_peer = self.rpc(address).shuffle_receive
            payload = to_serialize(shards)
            return await receive_on_peer(
                data=payload, shuffle_id=self.id, run_id=self.run_id
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address``, retrying according to the RETRY_* settings.

            When ``_mean_shard_size(shards)`` is below 65536 the shards are not sent
            individually over the tcp comms; everything is merged into one opaque
            pickled bytes blob that is unpickled on the other side (see ``receive``).
            """
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
            payload: list | bytes = (
                pickle.dumps(shards) if _mean_shard_size(shards) < 65536 else shards
            )

            def attempt() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                # A fresh coroutine is needed for every retry.
                return self._send(address, payload)

            return await retry(
                attempt,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

            The call goes through ``run_in_executor_with_context`` and is metered
            under the ``"offload"`` label. Raises whatever ``raise_if_closed``
            raises when the run is closed.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                outcome = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
            return outcome
    
        def heartbeat(self) -> dict[str, Any]:
            """Collect heartbeat data of this run.

            Returns a dict with the disk buffer's heartbeat under ``"disk"``, the
            comm buffer's heartbeat (plus ``"read"`` set to ``total_recvd``) under
            ``"comm"``, and ``start_time`` under ``"start"``.
            """
            comm_stats = self._comm_buffer.heartbeat()
            comm_stats["read"] = self.total_recvd
            disk_stats = self._disk_buffer.heartbeat()
            return {"disk": disk_stats, "comm": comm_stats, "start": self.start_time}
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Write ``data`` to the comm buffer.

            Raises whatever ``raise_if_closed`` raises when the run is closed.
            """
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            """Raise if this run has been closed; do nothing otherwise.

            Raises
            ------
            Exception
                The stored ``self._exception`` if the run is closed and one is set.
            ShuffleClosedError
                If the run is closed and no exception is stored.
            """
            if not self.closed:
                return
            if self._exception:
                raise self._exception
            raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark the transfer as finished and flush the comm buffer.

            Sets ``self.transferred`` (after which ``add_partition`` refuses new
            partitions), flushes the comm buffer and surfaces any exception the
            comm buffer reports.

            Raises whatever ``raise_if_closed`` raises when the run is closed, or
            the exception raised by ``self._comm_buffer.raise_on_exception()``.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                # Remember the failure so ``raise_if_closed`` can re-raise it
                # after the run has been closed.
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            """Flush the comm buffer.

            Raises whatever ``raise_if_closed`` raises when the run is closed.
            """
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush the disk buffer.

            Raises whatever ``raise_if_closed`` raises when the run is closed.
            """
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            """Close the comm and disk buffers of this run.

            The first call marks the run as closed and closes both buffers. Any
            call made while or after that only waits until the first call has
            finished.

            Both buffers are closed and ``_closed_event`` is set even if closing
            one of the buffers raises; the exception still propagates to the
            caller. Without this, a failure while closing the comm buffer would
            leave the disk buffer open and make every concurrent ``close`` call
            wait forever.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            self.closed = True
            try:
                await self._comm_buffer.close()
            finally:
                try:
                    await self._disk_buffer.close()
                finally:
                    # Always release waiters, also on failure.
                    self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            """Store ``exception`` as this run's failure unless the run is already closed."""
            if self.closed:
                return
            self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Accept shards sent by another worker and pass them to ``_receive``.

            ``bytes`` input is the opaque blob produced by ``send`` and is
            unpickled first.

            Returns ``{"status": "OK"}``, or the result of ``error_message`` when a
            ``P2PConsistencyError`` is raised while handling the data.
            """
            try:
                # NOTE(review): ``pickle.loads`` runs on data received from a peer;
                # this is only safe if peers are trusted. Confirm.
                shards = (
                    cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                    if isinstance(data, bytes)
                    else data
                )
                await self._receive(shards)
            except P2PConsistencyError as exc:
                return error_message(exc)
            return {"status": "OK"}
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition

            Parameters
            ----------
            i
                Identifier of the output partition.

            Returns
            -------
            str
                Address of the assigned worker; ``_ensure_output_worker`` compares
                it against ``self.local_address``.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run

            Called by ``receive`` after any pickled blob has been unpacked. A
            ``P2PConsistencyError`` raised here is turned into an error message by
            ``receive``.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…shuffle/_worker_plugin.py:411: in get_or_create_shuffle
    return sync(
distributed/utils.py:439: in sync
    raise error
distributed/utils.py:413: in f
    result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.12/site-packages/tornado/gen.py:766: in run
    value = future.result()
distributed/shuffle/_worker_plugin.py:145: in get_or_create
    shuffle_run = await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Register the shuffle handlers and this plugin on ``scheduler``."""
            self.scheduler = scheduler
            # Expose the plugin's methods as scheduler handlers under "shuffle_*" names.
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            # shuffle id -> worker address -> heartbeat data (filled by ``heartbeat``)
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` under the name "shuffle".

            The ``scheduler`` argument is not used; registration goes through
            ``self.scheduler``.
            """
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the ids of all currently active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Tell all participating workers that the inputs of a shuffle are done.

            Parameters
            ----------
            id
                Id of an active shuffle.
            run_id
                Run id the caller believes to be current.
            consistent
                If false, the shuffle is restarted instead of being completed.

            The ``shuffle_inputs_done`` message is broadcast to the participating
            workers. Workers whose broadcast failed with an ``OSError`` are retried
            after a short sleep until none are left.

            Raises
            ------
            KeyError
                If ``id`` is not an active shuffle.
            ValueError
                If ``run_id`` does not match the active run.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler while the shuffle
                has not been archived.
            RuntimeError
                If a worker returned an error other than ``OSError``, a failing
                worker left the cluster, or three broadcast rounds made no
                progress.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            # Counts the rounds (not necessarily consecutive) in which every
            # contacted worker failed again.
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Keep only the workers that failed with an OSError for the next
                # round; a ``None`` result means the worker succeeded and any
                # other result aborts the barrier.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            # Built from adjacent literals so that no newlines or
                            # source indentation end up in the message.
                            raise RuntimeError(
                                f"Broadcast not making progress for {shuffle}. "
                                "Aborting. This is possibly due to overloaded "
                                "workers. Increasing config "
                                "`distributed.comm.timeouts.connect` timeout may "
                                "help."
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict the task ``key`` to ``worker`` for the active run of a shuffle.

            The request is honoured only if ``run_id`` equals the run id of the
            shuffle currently stored in ``self.active_shuffles[id]``.

            Returns
            -------
            ``{"status": "OK"}`` once ``self._set_restriction`` has been called for
            the task, or the result of ``error_message`` when the run ids differ.

            Raises
            ------
            KeyError
                If ``id`` is not an active shuffle or ``key`` is not a known task;
                only ``P2PConsistencyError`` is converted into an error message.
            """
            try:
                shuffle = self.active_shuffles[id]
                known_run_id = shuffle.run_id
                if known_run_id > run_id:
                    # The scheduler already moved on to a newer run.
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if known_run_id < run_id:
                    # The caller refers to a run the scheduler does not know yet.
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                task_state = self.scheduler.tasks[key]
                self._set_restriction(task_state, worker)
            except P2PConsistencyError as exc:
                return error_message(exc)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge the per-shuffle heartbeat payloads reported by one worker.

            Parameters
            ----------
            ws
                State of the reporting worker; only ``ws.address`` is read.
            data
                Mapping of shuffle id to a dict of heartbeat values.

            Entries for shuffle ids that are not currently active are ignored.
            """
            # ``shuffle_ids()`` builds a new set on every call, so resolve the
            # active ids once instead of once per entry of ``data``.
            active_ids = self.shuffle_ids()
            for shuffle_id, d in data.items():
                if shuffle_id in active_ids:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of the active shuffle ``id`` for ``worker``.

            Returns
            -------
            ``{"status": "OK", "run_spec": ToPickle(run_spec)}`` on success.
            A ``KeyError`` raised during the lookup is re-raised as
            ``P2PConsistencyError``, and every ``P2PConsistencyError`` is returned
            as the result of ``error_message`` instead of propagating.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    response: RunSpecMessage = {
                        "status": "OK",
                        "run_spec": ToPickle(spec),
                    }
                    return response
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as inconsistency:
                return error_message(inconsistency)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:38873', workers: 0, cores: 0, tasks: 0>
a = array([[0.17099363, 0.84982634, 0.68269121, ..., 0.58142871, 0.20141963,
        0.70392131],
       [0.93343985, 0.96...35,
        0.77540941],
       [0.75007872, 0.23444326, 0.77904639, ..., 0.47615695, 0.47363149,
        0.06353038]])
b = <Worker 'tcp://127.0.0.1:35891', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>
new = ((10, 10, 10, 10, 10, 10, ...), (10, 10, 10, 10, 10, 10, ...))
expected_algorithm = 'p2p'

    @pytest.mark.parametrize(
        ["new", "expected_algorithm"],
        [
            # All-to-all rechunking defaults to P2P
            (((1,) * 100, (100,)), "p2p"),
            # Localized rechunking defaults to tasks
            (((50, 50), (2,) * 50), "tasks"),
            # Less local rechunking first defaults to tasks,
            (((25, 25, 25, 25), (4,) * 25), "tasks"),
            # then switches to p2p
            (((10,) * 10, (10,) * 10), "p2p"),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_heuristic(c, s, a, b, new, expected_algorithm):
        a = np.random.default_rng().uniform(0, 1, 10000).reshape((100, 100))
        x = da.from_array(a, chunks=(100, 1))
        x2 = rechunk(x, chunks=new)
        if expected_algorithm == "p2p":
            assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__())
        else:
            assert not any(
                key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()
            )
    
        assert x2.chunks == new
>       assert np.all(await c.compute(x2) == a)

distributed/shuffle/tests/test_rechunk.py:239: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
    self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        """``OKMessage`` extended with the run spec of a shuffle."""

        # Either the spec itself or the spec wrapped in ``ToPickle``.
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Set up the state, buffers and retry settings of a shuffle run.

            Parameters
            ----------
            id, run_id, span_id, local_address, executor, rpc, digest_metric, scheduler
                Stored unchanged on the instance under the same names.
            directory
                Passed to ``DiskShardsBuffer``; only used when ``disk`` is true.
            memory_limiter_disk
                Passed to ``DiskShardsBuffer``; only used when ``disk`` is true.
            memory_limiter_comms
                Passed to ``CommShardsBuffer``.
            disk
                If true, shards are buffered in a ``DiskShardsBuffer``, otherwise in
                a ``MemoryShardsBuffer``.
            loop
                Stored as ``self._loop``.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    # NOTE(review): ``self.read`` / ``self.deserialize`` are not
                    # defined in the visible part of this class; presumably they
                    # are provided by subclasses. Confirm.
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    # Both limits of the comm buffer come from the dask config.
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            # Shown as a state flag by ``__repr__``.
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            # Retry settings read from the dask config; the delays are parsed with
            # seconds as the default unit.
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            """Return a debug representation with identity and lifecycle flags."""
            shown = ("id", "run_id", "local_address", "closed", "transferred")
            details = ", ".join(f"{attr}={getattr(self, attr)!r}" for attr in shown)
            return f"<{self.__class__.__name__}: {details}>"
    
        def __str__(self) -> str:
            """Return a short human-readable label: class, id, run id and address."""
            cls_name = self.__class__.__name__
            return f"{cls_name}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            """Hash the run by its ``run_id``."""
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Forward context_meter metrics recorded inside the block as

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            A leading ``"p2p-"`` on the first label element is stripped before the
            metric is passed to ``self.digest_metric``.

            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This is by design so that one can have a holistic view of the whole shuffle
            process.

            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
            prefix = "p2p-"

            def _forward(label: Hashable, value: float, unit: str) -> None:
                # Normalise scalar labels to a 1-tuple so they can be spliced below
                parts = label if isinstance(label, tuple) else (label,)
                head = parts[0]
                if isinstance(head, str) and head.startswith(prefix):
                    parts = (head[len(prefix) :], *parts[1:])
                self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

            # Only background captures allow offloading
            offloadable = "background" in where
            with context_meter.add_callback(_forward, allow_offload=offloadable):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Report the barrier to the scheduler and return this run's id.

            ``consistent`` is true only if every entry of ``run_ids`` equals
            ``self.run_id``. Raises (via ``raise_if_closed``) if the run is closed.
            """
            self.raise_if_closed()
            mismatch = any(other != self.run_id for other in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id,
                run_id=self.run_id,
                consistent=not mismatch,
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Deliver ``shards`` to the worker at ``address`` in a single RPC call.

            The payload is wrapped with ``to_serialize`` and tagged with this
            run's shuffle id and run id. Raises if the run is closed.
            """
            self.raise_if_closed()
            remote_receive = self.rpc(address).shuffle_receive
            return await remote_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address``, retrying according to the RETRY_* settings.

            When the mean shard size is below 65536 the shards are not sent as
            individual buffers over the tcp comms: everything is pickled into one
            opaque bytes blob, sent at once and unpickled on the other side
            (see ``receive``).
            Performance tests informing the size threshold:
            https://github.com/dask/distributed/pull/8318
            """
            is_small = _mean_shard_size(shards) < 65536
            payload: list | bytes = pickle.dumps(shards) if is_small else shards

            return await retry(
                lambda: self._send(address, payload),
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

            The call is timed under the ``"offload"`` context meter. Raises if
            the run is closed.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                outcome = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
            return outcome
    
        def heartbeat(self) -> dict[str, Any]:
            """Return a status snapshot of this run.

            The result has the keys ``"disk"`` (disk buffer heartbeat), ``"comm"``
            (comm buffer heartbeat, extended in place with ``"read"`` set to
            ``total_recvd``) and ``"start"`` (``start_time``).
            """
            comm_stats = self._comm_buffer.heartbeat()
            comm_stats["read"] = self.total_recvd
            disk_stats = self._disk_buffer.heartbeat()
            return {"disk": disk_stats, "comm": comm_stats, "start": self.start_time}
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Hand ``data`` to the comm buffer; raises if the run is closed."""
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            """Write ``data`` to the disk buffer, keyed by the ``"_"``-joined index.

            Raises if the run is closed.
            """
            self.raise_if_closed()
            keyed = {}
            for index, shard in data.items():
                keyed["_".join(map(str, index))] = shard
            await self._disk_buffer.write(keyed)
    
        def raise_if_closed(self) -> None:
            """Raise if this run has been closed; otherwise do nothing.

            A stored ``_exception`` takes precedence over the generic
            ``ShuffleClosedError``.
            """
            if not self.closed:
                return
            stored = self._exception
            if stored:
                raise stored
            raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark the transfer phase as finished and flush outgoing shards.

            Sets ``transferred`` and flushes the comm buffer. If the comm buffer
            then reports an exception, it is remembered in ``_exception`` and
            re-raised. Raises if the run is closed.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                # Keep the failure so that raise_if_closed can surface it later
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            """Flush the comm buffer; raises if the run is closed."""
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush the disk buffer; raises if the run is closed."""
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            """Close the comm buffer and then the disk buffer.

            Only the first caller performs the shutdown; any later caller waits
            until ``_closed_event`` is set and then returns.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            # Flip the flag before the first await so that concurrent callers take
            # the waiting branch above
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            # NOTE(review): if either close() above raises, the event is never set
            # here and later callers of close() keep waiting -- confirm intended
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            """Remember ``exception`` as the failure of this run.

            Has no effect once the run is closed. The stored exception is what
            ``raise_if_closed`` raises after the run has been closed.
            """
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            """Read from the disk buffer under the ``"_"``-joined ``id``.

            This is the same key format ``_write_to_disk`` writes with. Raises if
            the run is closed.
            """
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Accept shards sent by another participant and pass them to ``_receive``.

            Returns ``{"status": "OK"}`` on success. A ``P2PConsistencyError`` is
            converted into an error message instead of being raised; any other
            exception propagates.
            """
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    # NOTE(review): pickle.loads executes arbitrary code from the
                    # payload; this assumes the sender is a trusted peer -- confirm
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Make sure output partition ``i`` is produced on its assigned worker.

            Returns normally when the assigned worker is this worker. Otherwise
            the scheduler is asked to restrict task ``key`` to the assigned
            worker and ``Reschedule`` is raised. A scheduler reply with status
            ``"error"`` raises ``RuntimeError`` carrying its message.
            """
            target = self._get_assigned_worker(i)
            if target == self.local_address:
                return

            reply = await self.scheduler.shuffle_restrict_task(
                id=self.id, run_id=self.run_id, key=key, worker=target
            )
            status = reply["status"]
            if status == "error":
                raise RuntimeError(reply["message"])
            assert status == "OK"
            raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition

            ``_ensure_output_worker`` compares the result with ``self.local_address``.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run

            Called by ``receive`` once an opaque bytes payload, if any, has been
            unpickled into the list form.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            """Shard an input partition and queue the shards for sending.

            Returns this run's ``run_id``.

            Raises
            ------
            RuntimeError
                If ``transferred`` is already set, i.e. ``inputs_done`` has run.
            Exception
                Whatever ``raise_if_closed`` raises when the run is closed.
            """
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Validation happens outside the metrics capture below
            self.validate_data(data)
            # Log metrics both in the "execute" and in the "p2p" contexts
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                # Writing to the comm buffer goes through ``sync`` on ``self._loop``;
                # it is still inside the "foreground" capture but outside the
                # shard-partition meters
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers

            Called by ``add_partition``; the returned mapping is passed on to
            ``_write_to_comm``.
            """
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Return the output partition ``partition_id`` of this shuffle run.

            The steps run in this order: closed check, worker assignment check
            (``_ensure_output_worker`` may raise ``Reschedule``), ``transferred``
            check, flush of the disk buffer, then the subclass hook
            ``_get_output_partition`` under the "foreground" metrics capture.

            Raises
            ------
            RuntimeError
                If called before ``transferred`` has been set.
            """
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run

            Called by ``get_output_partition`` after the disk buffer has been
            flushed; extra keyword arguments are forwarded unchanged.
            """
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk

            Passed as the ``read`` callback to ``DiskShardsBuffer`` in ``__init__``.
            NOTE(review): the ``int`` in the result is presumably a byte count --
            confirm against ``DiskShardsBuffer``.
            """
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards

            Passed as the ``deserialize`` callback to ``MemoryShardsBuffer`` in
            ``__init__``.
            """
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling

            The base implementation accepts everything. ``add_partition`` calls
            this hook before sharding.
            """
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        """Return the ``"shuffle"`` plugin of the worker running the current task.

        Raises
        ------
        RuntimeError
            If ``get_worker()`` raises ``ValueError`` (not running on a worker),
            or if the worker has no plugin registered under ``"shuffle"``. The
            original exception is chained as the cause in both cases.
        """
        # Imported inside the function rather than at module level
        from distributed import get_worker

        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e
        try:
            return worker.plugins["shuffle"]  # type: ignore
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
    
    
    # Prefix shared by the keys of all shuffle barrier tasks
    _BARRIER_PREFIX = "shuffle-barrier-"


    def barrier_key(shuffle_id: ShuffleId) -> str:
        """Return the barrier task key for ``shuffle_id``."""
        return _BARRIER_PREFIX + shuffle_id


    def id_from_key(key: Key) -> ShuffleId | None:
        """Inverse of ``barrier_key``.

        Returns the shuffle id encoded in ``key``, or ``None`` when ``key`` is
        not a string starting with the barrier prefix.
        """
        if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
            return ShuffleId(key.removeprefix(_BARRIER_PREFIX))
        return None
    
    
    class ShuffleType(Enum):
        """Kinds of P2P operation; each value is the operation's string name."""

        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        """Immutable description of one run of a shuffle."""

        # Not an ``__init__`` argument: every new instance takes the next value of
        # a single ``itertools.count(1)`` that is created when this class body is
        # evaluated.
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        # The shuffle specification this run belongs to
        spec: ShuffleSpec
        # Mapping of output partition id to worker address
        worker_for: dict[_T_partition_id, str]
        span_id: str | None

        @property
        def id(self) -> ShuffleId:
            """Id of the shuffle, taken from ``spec``."""
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        """Abstract, immutable specification of a shuffle.

        Subclasses provide the output partitions, the worker picking strategy
        and the construction of the worker-side run object.
        """

        id: ShuffleId
        # Forwarded as the ``disk`` flag of the run. NOTE(review): presumably by
        # ``create_run_on_worker`` implementations -- confirm
        disk: bool

        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""

        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""

        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            """Create the scheduler-side state for a new run of this shuffle.

            A new ``ShuffleRunSpec`` is built for the run; the participating
            workers start out as the set of values of ``worker_for``.
            """
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )

        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        """Mutable scheduler-side state of one shuffle run.

        Declared with ``eq=False``, so no ``__eq__`` is generated by the
        dataclass; hashing is based on the run id.
        """

        run_spec: ShuffleRunSpec
        # Addresses of the workers taking part in this run
        participating_workers: set[str]
        # ``None`` until the run is archived; excluded from ``__init__``
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False

        @property
        def id(self) -> ShuffleId:
            """Id of the shuffle, taken from ``run_spec``."""
            return self.run_spec.id

        @property
        def run_id(self) -> int:
            """Id of this run, taken from ``run_spec``."""
            return self.run_spec.run_id

        @property
        def archived(self) -> bool:
            """Whether ``_archived_by`` has been set."""
            return self._archived_by is not None

        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"

        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P 8bebd7a505fd78dd1b1dcb229dea41a8 failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

All 11 runs failed: test_cull_p2p_rechunk_independent_partitions (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
assert 57 < (228 / 4)
 +  where 57 = len(HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x7fc8fda72900>\n 0. getitem-7d63926232ec9c9ad09ece51f431ba62\n)
 +    where HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x7fc8fda72900>\n 0. getitem-7d63926232ec9c9ad09ece51f431ba62\n = dask.array<getitem, shape=(5, 2, 10), dtype=float64, chunksize=(5, 1, 10), chunktype=numpy.ndarray>.dask
 +  and   228 = len(HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x7fc8fd666e40>\n 0. rechunk-p2p-b4936ca757976705fec9ffc12017b806\n)
 +    where HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x7fc8fd666e40>\n 0. rechunk-p2p-b4936ca757976705fec9ffc12017b806\n = dask.array<rechunk-p2p, shape=(10, 10, 10), dtype=float64, chunksize=(5, 1, 10), chunktype=numpy.ndarray>.dask
c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:35601', workers: 0, cores: 0, tasks: 0>
ws = (<Worker 'tcp://127.0.0.1:34703', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:44387', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[[0.599171  , 0.17502798, 0.88209158, 0.21115568, 0.46142343,
         0.03521772, 0.44125943, 0.5110587 , 0.64...0.8850746 , 0.48018381, 0.09855244, 0.71004424,
         0.51442559, 0.78168598, 0.52234215, 0.70874625, 0.70225798]]])
x = dask.array<array, shape=(10, 10, 10), dtype=float64, chunksize=(1, 5, 1), chunktype=numpy.ndarray>
new = (5, 1, -1)

    @gen_cluster(client=True)
    async def test_cull_p2p_rechunk_independent_partitions(c, s, *ws):
        a = np.random.default_rng().uniform(0, 1, 1000).reshape((10, 10, 10))
        x = da.from_array(a, chunks=(1, 5, 1))
        new = (5, 1, -1)
        rechunked = rechunk(x, chunks=new, method="p2p")
        (dsk,) = dask.optimize(rechunked)
        culled = rechunked[:5, :2]
        (dsk_culled,) = dask.optimize(culled)
    
        # The culled graph requires only 1/4 of the input tasks
        n_inputs = len(
            [1 for key in dsk.dask.get_all_dependencies() if key[0].startswith("array-")]
        )
        n_culled_inputs = len(
            [
                1
                for key in dsk_culled.dask.get_all_dependencies()
                if key[0].startswith("array-")
            ]
        )
        assert n_culled_inputs == n_inputs / 4
        # The culled graph should also have less than 1/4 the tasks
>       assert len(dsk_culled.dask) < len(dsk.dask) / 4
E       assert 57 < (228 / 4)
E        +  where 57 = len(HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x7fc8fda72900>\n 0. getitem-7d63926232ec9c9ad09ece51f431ba62\n)
E        +    where HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x7fc8fda72900>\n 0. getitem-7d63926232ec9c9ad09ece51f431ba62\n = dask.array<getitem, shape=(5, 2, 10), dtype=float64, chunksize=(5, 1, 10), chunktype=numpy.ndarray>.dask
E        +  and   228 = len(HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x7fc8fd666e40>\n 0. rechunk-p2p-b4936ca757976705fec9ffc12017b806\n)
E        +    where HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x7fc8fd666e40>\n 0. rechunk-p2p-b4936ca757976705fec9ffc12017b806\n = dask.array<rechunk-p2p, shape=(10, 10, 10), dtype=float64, chunksize=(5, 1, 10), chunktype=numpy.ndarray>.dask

distributed/shuffle/tests/test_rechunk.py:265: AssertionError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

All 11 runs failed: test_rechunk_expand (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 16159121174d0a3cf3a9d8fa1f4d944f failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Attach the plugin to ``scheduler``.

            Registers the ``shuffle_*`` handlers on the scheduler, adds this
            object as the scheduler plugin named ``"shuffle"`` and initialises
            the empty bookkeeping containers.
            """
            self.scheduler = scheduler
            # Operations that can be invoked on the scheduler by name
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            # Two-level mapping: shuffle id -> worker address -> heartbeat dict
            # (see ``heartbeat``)
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a new ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

            The ``scheduler`` argument is not used; registration goes through
            ``self.scheduler``.
            """
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the ids of all active shuffles."""
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Handle the barrier of a shuffle run and notify participating workers.

            If ``consistent`` is false, the shuffle is restarted and nothing is
            broadcast. Otherwise ``shuffle_inputs_done`` is broadcast to all
            participating workers. Workers whose broadcast failed with an
            ``OSError`` are retried after a short sleep until all have
            answered.

            Raises
            ------
            KeyError
                If ``id`` is not an active shuffle.
            ValueError
                If ``run_id`` does not match the active run of the shuffle.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler while the
                shuffle has not been archived.
            RuntimeError
                If a worker answers with an error other than ``OSError``, if
                a failing worker has left the scheduler, or if three retry
                rounds did not reduce the number of failing workers.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Classify the responses: ``None`` means success, ``OSError`` means
                # the worker is retried, anything else aborts the barrier. After
                # this loop ``workers`` holds exactly the workers to retry, so it
                # does not need to be recomputed from ``res``.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    # Give up after three rounds that did not shrink the retry list
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` for the active run of shuffle ``id``.

            Returns ``{"status": "OK"}`` on success. If ``run_id`` differs from
            the active run's id, an error message built from a
            ``P2PConsistencyError`` is returned instead.
            """
            try:
                shuffle = self.active_shuffles[id]
                active_run_id = shuffle.run_id
                if active_run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if active_run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge a worker's per-shuffle heartbeat payload into ``self.heartbeats``.

            ``data`` maps shuffle ids to dicts. Entries for shuffles that are
            not active are ignored; the others update the record stored under
            ``heartbeats[shuffle_id][ws.address]``.
            """
            # Build the id set once instead of once per entry of ``data``
            active_ids = self.shuffle_ids()
            for shuffle_id, d in data.items():
                if shuffle_id in active_ids:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of the active shuffle ``id`` to ``worker``.

            On success the run spec is wrapped in ``ToPickle``. A ``KeyError``
            from the lookup is converted into a ``P2PConsistencyError``, and
            every ``P2PConsistencyError`` is returned as an error message
            rather than raised.
            """
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    # Re-raised inside the outer ``try`` so that the handler below
                    # turns it into an error message
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Look up the active run of shuffle ``id`` and add ``worker`` to it.

            Raises
            ------
            P2PConsistencyError
                If ``worker`` is not known to the scheduler.
            KeyError
                If there is no active shuffle with this id.
            """
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            # Every worker that asks for the run spec becomes a participant
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            """Return the ``ShuffleSpec`` carried by the shuffle's barrier task.

            The barrier task is looked up in the scheduler by ``barrier_key``;
            its ``run_spec`` is asserted to be a ``P2PBarrierTask``.
            """
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            """Create a new run for ``shuffle_id``, make it active and return its spec.

            The run is requested by task ``key`` executing on ``worker``; that
            worker is added to the participating workers. The span id is taken
            from the group of task ``key``.
            """
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            # Replaces any previous entry in ``active_shuffles`` for this id;
            # ``_shuffles`` collects every state created for the id
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` under the name "shuffle".

            The registration goes through ``self.scheduler``; the ``scheduler``
            argument itself is not used.
            """
            pickled_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None,
                pickled_plugin,
                name="shuffle",
                idempotent=False,
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the ids of all active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Handle the barrier of shuffle ``id`` for run ``run_id``.

            If ``consistent`` is false the shuffle is restarted. Otherwise a
            ``shuffle_inputs_done`` message is broadcast to every participating
            worker, and the broadcast is repeated for the workers that answered
            with an ``OSError`` until none is left.

            Raises
            ------
            KeyError
                If ``id`` is not an active shuffle.
            ValueError
                If ``run_id`` differs from the run id of the active shuffle.
            P2PIllegalStateError
                If a worker that still needs a retry is unknown to the scheduler
                although the shuffle is not archived.
            RuntimeError
                If a worker answers with anything other than ``None`` or an
                ``OSError``, if a worker that needs a retry has left, or if the
                number of failing workers did not shrink in three retry rounds.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Classify the answers: ``None`` means the worker is done, an
                # OSError means the worker is retried, anything else is fatal.
                # The list built here is the retry list; it is not recomputed.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    # The counter is never reset, so the three stalled rounds
                    # need not be consecutive.
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` on behalf of shuffle ``id``.

            The restriction is applied only when ``run_id`` equals the run id of
            the active shuffle. A mismatch is reported as an error message built
            from a ``P2PConsistencyError`` instead of being raised.
            """
            try:
                active = self.active_shuffles[id]
                if active.run_id > run_id:
                    # The request belongs to an older run than the active one
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {active}"
                    )
                if active.run_id < run_id:
                    # The request claims a newer run than the active one
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {active}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge the per-shuffle heartbeat payload reported by worker ``ws``.

            ``data`` maps shuffle ids to dicts; each dict is merged into
            ``self.heartbeats[shuffle_id][ws.address]``. Entries whose shuffle id
            is not currently active are ignored.
            """
            # Build the set of active ids once instead of once per entry.
            active_ids = self.shuffle_ids()
            for shuffle_id, d in data.items():
                if shuffle_id in active_ids:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of active shuffle ``id`` as a message.

            On success the message has status ``"OK"`` and carries the run spec
            wrapped in ``ToPickle``. A ``KeyError`` raised inside the inner block
            is converted into a ``P2PConsistencyError``; every
            ``P2PConsistencyError`` is returned as an error message rather than
            raised.
            """
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    # Translate the missing-key failure into the P2P error type
                    # so the outer handler reports it like any other
                    # consistency error.
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: '16159121174d0a3cf3a9d8fa1f4d944f'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        """``OKMessage`` that additionally carries a shuffle run spec."""

        # Either the spec itself or the spec wrapped in ``ToPickle``
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Store the run's identity and collaborators and create its buffers.

            Parameters
            ----------
            id, run_id, span_id, local_address, executor, rpc, digest_metric, scheduler, loop
                Stored on the instance unchanged.
            directory
                Passed to ``DiskShardsBuffer``; only used when ``disk`` is true.
            memory_limiter_disk
                Passed to ``DiskShardsBuffer``; only used when ``disk`` is true.
            memory_limiter_comms
                Passed to ``CommShardsBuffer``.
            disk
                Selects ``DiskShardsBuffer`` (true) or ``MemoryShardsBuffer``
                (false) as ``self._disk_buffer``.

            The comm buffer limits and the retry settings are read from the
            ``distributed.p2p.comm.*`` dask configuration keys.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            # Lifecycle and progress state of this run
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            # Retry policy used by send()
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            """Hash a run by its ``run_id``."""
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Forward context_meter metrics recorded inside the block to
            ``self.digest_metric`` under the key

                ('p2p', <span id>, 'foreground|background...', label, unit)

            A leading ``"p2p-"`` on the first label element is dropped, and
            offloading is allowed only when ``where`` contains "background".

            **Note 1:** A metric that is not logged by a background task
            (where='foreground') is additionally recorded under

                {('execute', <span id>, <task prefix>, label, unit): value}

            The duplication is intentional: it gives a holistic view of the whole
            shuffle process.

            **Note 2:** Metrics are written to Worker.digests right away. Keeping
            them on the ShuffleRun first would lose whatever is recorded between
            the last heartbeat and the deletion of the ShuffleRun object at the
            end of a run.
            """

            def callback(label: Hashable, value: float, unit: str) -> None:
                parts = label if isinstance(label, tuple) else (label,)
                head = parts[0]
                if isinstance(head, str) and head.startswith("p2p-"):
                    parts = (head[len("p2p-") :], *parts[1:])
                self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

            offloadable = "background" in where
            with context_meter.add_callback(callback, allow_offload=offloadable):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Report the barrier of this run to the scheduler and return ``run_id``.

            The run is flagged as consistent only if every entry of ``run_ids``
            equals this run's own ``run_id``.
            """
            self.raise_if_closed()
            is_consistent = all(other == self.run_id for other in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id,
                run_id=self.run_id,
                consistent=is_consistent,
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Call ``shuffle_receive`` on the peer at ``address`` once.

            ``shards`` is wrapped with ``to_serialize`` and sent together with
            this run's shuffle id and run id. No retry happens here.
            """
            self.raise_if_closed()
            peer = self.rpc(address)
            response = await peer.shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
            return response
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address``, retrying per the configured policy.

            When the mean shard size is below 65536 the shards are pickled into
            one opaque bytes blob that is unpickled on the receiving side;
            otherwise the list is sent as is.
            """
            # Don't send small buffers individually over the tcp comms; merging
            # them into one blob and sending it all at once is cheaper.
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
            merge_into_blob = _mean_shard_size(shards) < 65536
            shards_or_bytes: list | bytes = (
                pickle.dumps(shards) if merge_into_blob else shards
            )

            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)

            outcome = await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
            return outcome
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

            The call is metered under the ``"offload"`` label.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                outcome = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
                return outcome
    
        def heartbeat(self) -> dict[str, Any]:
            """Return the disk and comm buffer heartbeats plus the start time.

            The comm heartbeat is extended in place with a ``"read"`` entry
            holding ``self.total_recvd``.
            """
            comm_stats = self._comm_buffer.heartbeat()
            comm_stats["read"] = self.total_recvd
            disk_stats = self._disk_buffer.heartbeat()
            return dict(disk=disk_stats, comm=comm_stats, start=self.start_time)
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Hand ``data`` to the comm buffer.

            Raises whatever ``raise_if_closed`` raises when this run is closed.
            """
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            """Raise if this run is closed; do nothing otherwise.

            A stored exception is re-raised when there is one, else a
            ``ShuffleClosedError`` is raised.
            """
            if not self.closed:
                return
            stored = self._exception
            if stored:
                raise stored
            raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark the transfer as finished and flush the comm buffer.

            ``transferred`` is set before the flush. An exception surfaced by the
            comm buffer afterwards is stored in ``self._exception`` and re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                # Remember the failure so raise_if_closed() can re-raise it
                # once the run is closed.
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            """Flush the comm buffer, unless this run is closed."""
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush the disk buffer, unless this run is closed."""
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            """Close the comm buffer, then the disk buffer, and signal completion.

            A call made after ``closed`` has been set only waits for the closed
            event and returns.
            """
            if self.closed:  # pragma: no cover
                # Another call already started closing; wait until it is done.
                await self._closed_event.wait()
                return

            self.closed = True
            # NOTE(review): if either buffer close raises, the event below is
            # never set and callers waiting in the branch above keep waiting.
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            """Store ``exception`` on the run; ignored once the run is closed.

            The stored exception is what ``raise_if_closed`` re-raises after the
            run has been closed.
            """
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Accept shards sent by a peer and pass them to ``_receive``.

            ``data`` is either the shard list itself or the pickled blob produced
            by ``send()``. Returns an OK message, or an error message when a
            ``P2PConsistencyError`` is raised while handling the data.
            """
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    # NOTE(review): pickle.loads runs on bytes received over
                    # the wire; this is only safe if every peer is trusted.
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition

            ``_ensure_output_worker`` compares the returned address with
            ``self.local_address``.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run

            Called by ``receive`` after a pickled blob, if any, has been unpacked.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…ib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
>           yield

distributed/shuffle/_core.py:523: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/shuffle/_rechunk.py:170: in rechunk_transfer
    return get_worker_plugin().add_partition(
distributed/shuffle/_worker_plugin.py:348: in add_partition
    shuffle_run = self.get_or_create_shuffle(id)
distributed/shuffle/_worker_plugin.py:411: in get_or_create_shuffle
    return sync(
distributed/utils.py:439: in sync
    raise error
distributed/utils.py:413: in f
    result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.12/site-packages/tornado/gen.py:766: in run
    value = future.result()
distributed/shuffle/_worker_plugin.py:145: in get_or_create
    shuffle_run = await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` under the name "shuffle".

            The registration goes through ``self.scheduler``; the ``scheduler``
            argument itself is not used.
            """
            pickled_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None,
                pickled_plugin,
                name="shuffle",
                idempotent=False,
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the ids of all active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Handle the barrier of shuffle ``id`` for run ``run_id``.

            If ``consistent`` is false the shuffle is restarted. Otherwise a
            ``shuffle_inputs_done`` message is broadcast to every participating
            worker, and the broadcast is repeated for the workers that answered
            with an ``OSError`` until none is left.

            Raises
            ------
            KeyError
                If ``id`` is not an active shuffle.
            ValueError
                If ``run_id`` differs from the run id of the active shuffle.
            P2PIllegalStateError
                If a worker that still needs a retry is unknown to the scheduler
                although the shuffle is not archived.
            RuntimeError
                If a worker answers with anything other than ``None`` or an
                ``OSError``, if a worker that needs a retry has left, or if the
                number of failing workers did not shrink in three retry rounds.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Classify the answers: ``None`` means the worker is done, an
                # OSError means the worker is retried, anything else is fatal.
                # The list built here is the retry list; it is not recomputed.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    # The counter is never reset, so the three stalled rounds
                    # need not be consecutive.
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` on behalf of shuffle ``id``.

            The restriction is applied only when ``run_id`` equals the run id of
            the active shuffle. A mismatch is reported as an error message built
            from a ``P2PConsistencyError`` instead of being raised.
            """
            try:
                active = self.active_shuffles[id]
                if active.run_id > run_id:
                    # The request belongs to an older run than the active one
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {active}"
                    )
                if active.run_id < run_id:
                    # The request claims a newer run than the active one
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {active}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge the per-shuffle heartbeat payload reported by worker ``ws``.

            ``data`` maps shuffle ids to dicts; each dict is merged into
            ``self.heartbeats[shuffle_id][ws.address]``. Entries whose shuffle id
            is not currently active are ignored.
            """
            # Build the set of active ids once instead of once per entry.
            active_ids = self.shuffle_ids()
            for shuffle_id, d in data.items():
                if shuffle_id in active_ids:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of active shuffle ``id`` as a message.

            On success the message has status ``"OK"`` and carries the run spec
            wrapped in ``ToPickle``. A ``KeyError`` raised inside the inner block
            is converted into a ``P2PConsistencyError``; every
            ``P2PConsistencyError`` is returned as an error message rather than
            raised.
            """
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    # Translate the missing-key failure into the P2P error type
                    # so the outer handler reports it like any other
                    # consistency error.
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Return the run spec of active shuffle ``id`` and record ``worker``
            as one of its participating workers.

            Raises
            ------
            P2PConsistencyError
                If ``worker`` is not known to the scheduler.
            KeyError
                If there is no active shuffle with this ``id``.
            """
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:39835', workers: 0, cores: 0, tasks: 0>
ws = (<Worker 'tcp://127.0.0.1:45529', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:46747', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = array([[0.03029662, 0.67497307, 0.54696325, 0.5664905 , 0.27160949,
        0.56769018, 0.65528007, 0.42471041, 0.7749..., 0.04748013, 0.53053804, 0.17881593, 0.05720855,
        0.00615196, 0.73041112, 0.96641799, 0.98946603, 0.54375914]])
x = dask.array<array, shape=(10, 10), dtype=float64, chunksize=(5, 5), chunktype=numpy.ndarray>
y = dask.array<rechunk-p2p, shape=(10, 10), dtype=float64, chunksize=(3, 3), chunktype=numpy.ndarray>

    @gen_cluster(client=True)
    async def test_rechunk_expand(c, s, *ws):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_expand
        """
        a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10))
        x = da.from_array(a, chunks=(5, 5))
        y = x.rechunk(chunks=((3, 3, 3, 1), (3, 3, 3, 1)), method="p2p")
>       assert np.all(await c.compute(y) == a)

distributed/shuffle/tests/test_rechunk.py:377: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
    self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        from distributed import get_worker
    
        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e
        try:
            return worker.plugins["shuffle"]  # type: ignore
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
    
    
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        return _BARRIER_PREFIX + shuffle_id
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        if not isinstance(key, str) or not key.startswith(_BARRIER_PREFIX):
            return None
        return ShuffleId(key[len(_BARRIER_PREFIX) :])
    
    
    class ShuffleType(Enum):
        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        worker_for: dict[_T_partition_id, str]
        span_id: str | None
    
        @property
        def id(self) -> ShuffleId:
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        id: ShuffleId
        disk: bool
    
        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""
    
        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""
    
        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )
    
        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False
    
        @property
        def id(self) -> ShuffleId:
            return self.run_spec.id
    
        @property
        def run_id(self) -> int:
            return self.run_spec.run_id
    
        @property
        def archived(self) -> bool:
            return self._archived_by is not None
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
    
        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P 16159121174d0a3cf3a9d8fa1f4d944f failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

All 11 runs failed: test_rechunk_expand2 (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-numpy-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P bc0b6e3c445f0bbb5a37bf95be33b322 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: 'bc0b6e3c445f0bbb5a37bf95be33b322'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…le_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
>           yield

distributed/shuffle/_core.py:523: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/shuffle/_rechunk.py:170: in rechunk_transfer
    return get_worker_plugin().add_partition(
distributed/shuffle/_worker_plugin.py:348: in add_partition
    shuffle_run = self.get_or_create_shuffle(id)
distributed/shuffle/_worker_plugin.py:411: in get_or_create_shuffle
    return sync(
distributed/utils.py:439: in sync
    raise error
distributed/utils.py:413: in f
    result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.12/site-packages/tornado/gen.py:766: in run
    value = future.result()
distributed/shuffle/_worker_plugin.py:145: in get_or_create
    shuffle_run = await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:43843', workers: 0, cores: 0, tasks: 0>
ws = (<Worker 'tcp://127.0.0.1:37269', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:33729', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
a = 3, b = 2
orig = array([[0.42729343, 0.06459062, 0.34592701],
       [0.86826575, 0.07615254, 0.2244809 ],
       [0.0676913 , 0.43340265, 0.93535141]])

    @gen_cluster(client=True)
    async def test_rechunk_expand2(c, s, *ws):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_expand2
        """
        (a, b) = (3, 2)
        orig = np.random.default_rng().uniform(0, 1, a**b).reshape((a,) * b)
        for off, off2 in product(range(1, a - 1), range(1, a - 1)):
            old = ((a - off, off),) * b
            x = da.from_array(orig, chunks=old)
            new = ((a - off2, off2),) * b
            assert np.all(await c.compute(x.rechunk(chunks=new, method="p2p")) == orig)
            if a - off - off2 > 0:
                new = ((off, a - off2 - off, off2),) * b
>               y = await c.compute(x.rechunk(chunks=new, method="p2p"))

distributed/shuffle/tests/test_rechunk.py:396: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
    self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    # Distinct static type for shuffle identifiers; a plain ``str`` at runtime.
    ShuffleId = NewType("ShuffleId", str)
    # Tuple of integers of arbitrary length.
    NDIndex: TypeAlias = tuple[int, ...]


    # Type variables used to parameterize the generic classes in this module.
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        """``OKMessage`` that additionally carries a shuffle run spec."""

        # Either the spec itself or the spec wrapped in ``ToPickle``.
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        """State and common operations of one shuffle run.

        A run is identified by ``id`` and ``run_id`` and owns a comm buffer for
        outgoing shards plus a disk or in-memory buffer for received shards.
        Subclasses provide the methods decorated with ``abc.abstractmethod``.

        NOTE(review): this class derives from ``Generic`` only, not from
        ``abc.ABC``, so the ``abstractmethod`` markers are not enforced at
        instantiation time -- confirm this is intended.
        """

        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop

        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float

        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Store the run's identity and collaborators and create its buffers.

            ``disk`` selects between a ``DiskShardsBuffer`` rooted at ``directory``
            and a ``MemoryShardsBuffer``.  Comm limits and retry settings are read
            from ``dask.config`` under ``distributed.p2p.comm``.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            # Set to True by ``inputs_done``; no partitions may be added afterwards.
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            # Exception re-raised by ``raise_if_closed`` once the run is closed.
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )

        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"

        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"

        def __hash__(self) -> int:
            # The run ID is used directly as the hash value.
            return self.run_id

        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This is by design so that one can have a holistic view of the whole shuffle
            process.

            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """

            def callback(label: Hashable, value: float, unit: str) -> None:
                # Normalize the label to a tuple and strip a leading "p2p-" from its
                # first element, since "p2p" is already the first item of the name.
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)

                self.digest_metric(name, value)

            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield

        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Report the barrier to the scheduler and return this run's ``run_id``.

            ``consistent`` is True only if every entry of ``run_ids`` equals
            ``self.run_id``.
            """
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id

        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Make a single ``shuffle_receive`` RPC call to ``address``."""
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )

        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address`` through ``retry`` using the RETRY_* settings."""
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards

            # ``retry`` is handed a zero-argument callable that creates a fresh
            # coroutine for each call.
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)

            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )

        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor``, metered as "offload"."""
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )

        def heartbeat(self) -> dict[str, Any]:
            """Return the buffers' heartbeat data plus the run's start time."""
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }

        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)

        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            # Buffer keys are the index elements joined by "_".
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )

        def raise_if_closed(self) -> None:
            """Raise if the run is closed.

            Raises the stored exception if there is one, otherwise
            ``ShuffleClosedError``.
            """
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")

        async def inputs_done(self) -> None:
            """Mark the run as transferred and flush the comm buffer.

            An exception reported by the comm buffer is stored on the run and
            re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise

        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()

        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()

        async def close(self) -> None:
            """Close both buffers.

            A call made while the run is already marked closed waits for the first
            close to finish and then returns.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()

        def fail(self, exception: Exception) -> None:
            """Store ``exception`` for ``raise_if_closed``; ignored once closed."""
            if not self.closed:
                self._exception = exception

        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            # Same "_"-joined key format as in ``_write_to_disk``.
            return self._disk_buffer.read("_".join(str(i) for i in id))

        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Handle incoming shards and reply with an OK or error message.

            A ``P2PConsistencyError`` raised by ``_receive`` is returned as an
            error message; other exceptions propagate.
            """
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)

        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Ensure task ``key`` runs on the worker assigned to partition ``i``.

            If that worker is not this one, the scheduler is asked to restrict the
            task to the assigned worker and ``Reschedule`` is raised.
            """
            assigned_worker = self._get_assigned_worker(i)

            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()

        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""

        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""

        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            """Shard an input partition and write the shards to the comm buffer.

            Returns ``self.run_id``.  Raises ``RuntimeError`` if called after
            ``inputs_done`` has marked the run as transferred.
            """
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id

        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""

        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Return an output partition after flushing received shards.

            Raises ``RuntimeError`` if the run has not been marked as transferred.
            ``_ensure_output_worker`` may raise ``Reschedule`` before that check.
            """
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)

        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""

        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""

        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""

        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        """Return the plugin registered as ``"shuffle"`` on the current worker.

        Raises ``RuntimeError`` if ``get_worker`` raises ``ValueError`` or if the
        worker has no plugin registered under ``"shuffle"``.
        """
        from distributed import get_worker

        try:
            worker = get_worker()
        except ValueError as not_on_worker:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from not_on_worker

        registered = worker.plugins
        try:
            plugin = registered["shuffle"]
        except KeyError as missing:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from missing
        return plugin  # type: ignore
    
    
    # Prefix shared by the keys of all shuffle barrier tasks.
    _BARRIER_PREFIX = "shuffle-barrier-"


    def barrier_key(shuffle_id: ShuffleId) -> str:
        """Return the barrier task key for ``shuffle_id``."""
        key = _BARRIER_PREFIX + shuffle_id
        return key


    def id_from_key(key: Key) -> ShuffleId | None:
        """Return the shuffle ID encoded in a barrier key, else ``None``.

        ``None`` is returned for keys that are not strings or that do not start
        with the barrier prefix.
        """
        if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
            # The prefix is known to be present, so this strips exactly it.
            return ShuffleId(key.removeprefix(_BARRIER_PREFIX))
        return None
    
    
    class ShuffleType(Enum):
        """Closed set of shuffle kinds; each value is a string label."""

        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        """Immutable description of one run of a shuffle."""

        # Not an ``__init__`` argument: each new instance takes the next value from
        # a single counter (starting at 1) created once, when the class is defined.
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        # Maps each partition ID to a worker (a string, presumably its address).
        worker_for: dict[_T_partition_id, str]
        span_id: str | None

        @property
        def id(self) -> ShuffleId:
            """ID of the shuffle, taken from ``spec``."""
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        """Abstract, immutable description of a shuffle."""

        id: ShuffleId
        disk: bool

        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""

        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""

        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            """Build the scheduler-side state for a fresh run of this spec.

            The participating workers start out as the distinct values of
            ``worker_for``.
            """
            new_run = ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id)
            participants = set(worker_for.values())
            return SchedulerShuffleState(
                run_spec=new_run,
                participating_workers=participants,
            )

        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        """Mutable scheduler-side record of one shuffle run.

        Generated equality is disabled (``eq=False``); the hash is derived from
        the run ID.
        """

        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False

        @property
        def id(self) -> ShuffleId:
            """Shuffle ID, forwarded from the run spec."""
            spec = self.run_spec
            return spec.id

        @property
        def run_id(self) -> int:
            """Run ID, forwarded from the run spec."""
            spec = self.run_spec
            return spec.run_id

        @property
        def archived(self) -> bool:
            """Whether ``_archived_by`` has been set."""
            if self._archived_by is None:
                return False
            return True

        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"

        def __hash__(self) -> int:
            run_id = self.run_id
            return hash(run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P bc0b6e3c445f0bbb5a37bf95be33b322 failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

10 out of 11 runs failed: test_rechunk_unknown_from_pandas (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 427d2d9ead2057e7af39f331aad1826a failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    # Module-level logger named after this module.
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Register the shuffle RPC handlers and this plugin on ``scheduler``."""
            self.scheduler = scheduler
            # RPC operations served by this plugin.
            shuffle_handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            scheduler.handlers.update(shuffle_handlers)
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a new ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

            The plugin is serialized with ``dumps`` and passed to
            ``self.scheduler.register_worker_plugin``.
            """
            serialized_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None,
                serialized_plugin,
                name="shuffle",
                idempotent=False,
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the IDs of all active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Broadcast ``shuffle_inputs_done`` to all participants of a shuffle.

            Parameters
            ----------
            id
                ID of the shuffle; must be a key of ``self.active_shuffles``.
            run_id
                Run the caller refers to; must equal the active run's ID.
            consistent
                If false, the shuffle is restarted instead of broadcasting.

            Raises
            ------
            ValueError
                If ``run_id`` does not match the active run.
            RuntimeError
                If a worker answers with something other than ``None`` or an
                ``OSError``, if a pending worker is no longer known to the
                scheduler, or if three retry rounds fail to reduce the number of
                pending workers.
            P2PIllegalStateError
                If a pending worker is unknown to the scheduler while the shuffle
                is not yet archived.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Keep only workers that answered with an OSError; they are retried.
                # ``None`` means success and anything else aborts the barrier, so
                # no further filtering of ``res`` is needed after this loop.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    # Count rounds in which the number of pending workers did not
                    # shrink; the counter is never reset.
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` on behalf of an active shuffle run.

            The request must name the active run of shuffle ``id``.  A run ID
            mismatch is returned as an error message; success returns an OK
            message.
            """
            try:
                shuffle = self.active_shuffles[id]
                active_run_id = shuffle.run_id
                if active_run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if active_run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
                return {"status": "OK"}
            except P2PConsistencyError as inconsistent:
                return error_message(inconsistent)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge per-shuffle heartbeat payloads reported by a worker.

            Parameters
            ----------
            ws
                State of the reporting worker; only ``ws.address`` is read.
            data
                Mapping of shuffle ID to the heartbeat payload for that shuffle.

            Payloads for shuffles that are not currently active are ignored.
            """
            # The set of active IDs does not change while we iterate, so build it
            # once instead of once per entry in ``data``.
            active_ids = self.shuffle_ids()
            for shuffle_id, d in data.items():
                if shuffle_id in active_ids:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Answer a worker's request for the run spec of shuffle ``id``.

            Returns an OK message carrying the pickled run spec.  A ``KeyError``
            raised by the lookup is converted to ``P2PConsistencyError``, and any
            ``P2PConsistencyError`` is returned as an error message rather than
            raised.
            """
            try:
                try:
                    reply = {
                        "status": "OK",
                        "run_spec": ToPickle(self._get(id, worker)),
                    }
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
                return reply
            except P2PConsistencyError as inconsistent:
                return error_message(inconsistent)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Look up the active shuffle ``id`` and register ``worker`` with it.

            Raises ``KeyError`` if ``id`` is not in ``self.active_shuffles`` and
            ``P2PConsistencyError`` if the scheduler does not know ``worker``.
            """
            if worker in self.scheduler.workers:
                active = self.active_shuffles[id]
                active.participating_workers.add(worker)
                return active.run_spec
            # This should never happen
            raise P2PConsistencyError(
                f"Scheduler is unaware of this worker {worker!r}"
            )  # pragma: nocover
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            """Return the ``ShuffleSpec`` attached to the shuffle's barrier task.

            The barrier task is looked up in ``self.scheduler.tasks`` by
            ``barrier_key(shuffle_id)``; its ``run_spec`` is asserted to be a
            ``P2PBarrierTask``.
            """
            barrier_task = self.scheduler.tasks[barrier_key(shuffle_id)]
            run_spec = barrier_task.run_spec
            assert isinstance(run_spec, P2PBarrierTask)
            return run_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            """Create a new run of ``shuffle_id`` and register it as the active one.

            ``key`` is the task triggering creation and ``worker`` the worker it
            runs on; the worker is added to the run's participants.  Returns the
            new run spec.
            """
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)

            shuffle_spec = self._retrieve_spec(shuffle_id)
            assignment = self._calculate_worker_for(shuffle_spec)
            self._ensure_output_tasks_are_non_rootish(shuffle_spec)

            span_id = self.scheduler.tasks[key].group.span_id
            new_state = shuffle_spec.create_new_run(worker_for=assignment, span_id=span_id)

            # The new run replaces any previous active run of this shuffle and is
            # also added to the per-shuffle set of runs.
            self.active_shuffles[shuffle_id] = new_state
            self._shuffles[shuffle_id].add(new_state)
            new_state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return new_state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Attach the plugin to ``scheduler``.

            Registers the four ``shuffle_*`` handlers on ``scheduler.handlers``,
            adds this object as the scheduler plugin named "shuffle" and
            initializes the plugin's bookkeeping containers.
            """
            self.scheduler = scheduler
            # Handlers served by this plugin.
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            # heartbeats[shuffle_id][worker_address] -> dict of heartbeat data
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` named "shuffle" via the scheduler."""
            payload = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, payload, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return the IDs of all currently active shuffles as a new set."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Handle a barrier request for the active shuffle ``id``.

            If ``consistent`` is false the shuffle is restarted.  Otherwise
            ``shuffle_inputs_done`` is broadcast to all participating workers;
            workers whose broadcast failed with an ``OSError`` are retried after
            a short sleep.

            Raises
            ------
            KeyError
                If there is no active shuffle with this ``id``.
            ValueError
                If ``run_id`` does not match the active run.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler although the
                shuffle has not been archived.
            RuntimeError
                On a non-``OSError`` broadcast failure, if a failing worker is
                no longer known to the scheduler, or once three retry rounds
                made no progress.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Keep only the workers whose broadcast failed with a retryable
                # OSError; any other error aborts the barrier.  (A second pass
                # over ``res`` is not needed: after this loop every non-None
                # result is an OSError.)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` for the active run of shuffle ``id``.

            The request's ``run_id`` must equal the active run's; otherwise a
            ``P2PConsistencyError`` is produced and returned as an error message.
            Returns ``{"status": "OK"}`` on success.
            """
            try:
                shuffle = self.active_shuffles[id]
                active_run_id = shuffle.run_id
                if active_run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if active_run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge heartbeat ``data`` sent by worker ``ws`` into ``self.heartbeats``.

            ``data`` maps shuffle IDs to dicts of heartbeat values.  Entries for
            shuffles that are not currently active are ignored.
            """
            # Test membership on the dict directly instead of building a fresh
            # set of all active IDs for every entry.
            active = self.active_shuffles
            for shuffle_id, d in data.items():
                if shuffle_id in active:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of the active shuffle ``id`` to ``worker``.

            The lookup is delegated to ``_get``.  A ``KeyError`` from the lookup
            is re-raised as ``P2PConsistencyError``, and every
            ``P2PConsistencyError`` is returned as an error message rather than
            raised.
            """
            try:
                try:
                    spec = self._get(id, worker)
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
                else:
                    return {"status": "OK", "run_spec": ToPickle(spec)}
            except P2PConsistencyError as err:
                return error_message(err)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: '427d2d9ead2057e7af39f331aad1826a'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    # Distinct static type for shuffle identifiers (a ``str`` at runtime).
    ShuffleId = NewType("ShuffleId", str)
    # Tuple-of-ints index; presumably an N-dimensional chunk index -- TODO confirm.
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    # ``_T_partition_id`` / ``_T_partition_type`` parametrize ``ShuffleRun``;
    # ``_T`` is a general-purpose type variable.
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        """``OKMessage`` that additionally carries a shuffle run spec."""

        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Set up the state and buffers of a single shuffle run.

            Parameters
            ----------
            id, run_id, span_id, local_address, executor, rpc, digest_metric, scheduler
                Stored unchanged on the instance under the same names.
            directory
                Passed to the ``DiskShardsBuffer``; only used when ``disk`` is true.
            memory_limiter_disk
                Limiter handed to the ``DiskShardsBuffer``; only used when
                ``disk`` is true.
            memory_limiter_comms
                Limiter handed to the ``CommShardsBuffer``.
            disk
                If true, received shards are buffered in a ``DiskShardsBuffer``,
                otherwise in a ``MemoryShardsBuffer``.
            loop
                Stored as ``self._loop``.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            # Set to True by ``inputs_done``; ``add_partition`` refuses new input
            # afterwards.
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            # Retry settings used by ``send``, read from the dask config.
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            """Return a debug representation with id, run_id, address and state flags."""
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            """Return a short description: class name, id, run_id and local address."""
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            """Hash the run by its ``run_id``."""
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Forward context_meter metrics to ``digest_metric`` as

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            A leading ``"p2p-"`` prefix on the first label element is stripped.

            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This is by design so that one can have a holistic view of the whole shuffle
            process.

            **Note 2:** Metrics are written to Worker.digests immediately.
            They are not stored temporarily on the ShuffleRun, as those recorded
            between the heartbeat and the deletion of the ShuffleRun object at the
            end of a run would be lost.
            """

            def forward(label: Hashable, value: float, unit: str) -> None:
                parts = label if isinstance(label, tuple) else (label,)
                head = parts[0]
                if isinstance(head, str) and head.startswith("p2p-"):
                    parts = (head[len("p2p-") :], *parts[1:])
                self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

            offloadable = "background" in where
            with context_meter.add_callback(forward, allow_offload=offloadable):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Report the barrier to the scheduler and return this run's ``run_id``.

            The run is reported as consistent only if every entry of ``run_ids``
            equals this run's ``run_id``.
            """
            self.raise_if_closed()
            own_run_id = self.run_id
            consistent = not any(other != own_run_id for other in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to the worker at ``address`` via ``shuffle_receive``."""
            self.raise_if_closed()
            peer = self.rpc(address)
            return await peer.shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address``, retrying per the configured retry settings.

            Shards with a small mean size are pickled into a single bytes blob
            before sending; ``receive`` unpickles it on the other side.
            """
            # Don't send small buffers individually over the tcp comms.
            # Instead, merge everything into an opaque bytes blob, send it all at once,
            # and unpickle it on the other side.
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
            small = _mean_shard_size(shards) < 65536
            payload: list | bytes = pickle.dumps(shards) if small else shards

            def attempt() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, payload)

            return await retry(
                attempt,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

            The call is metered under the ``"offload"`` label.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                result = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
            return result
    
        def heartbeat(self) -> dict[str, Any]:
            """Return heartbeat data of both buffers plus the run's start time.

            The comm buffer's heartbeat is extended with a ``"read"`` entry
            holding ``total_recvd``.
            """
            comm_stats = self._comm_buffer.heartbeat()
            comm_stats["read"] = self.total_recvd
            disk_stats = self._disk_buffer.heartbeat()
            return {"disk": disk_stats, "comm": comm_stats, "start": self.start_time}
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            """Raise if this run is closed.

            Raises the stored exception if there is one, otherwise a
            ``ShuffleClosedError``.
            """
            if not self.closed:
                return
            if self._exception:
                raise self._exception
            raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark the run as transferred and flush the comm buffer.

            An exception surfaced by the comm buffer is stored in
            ``self._exception`` and re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            comm_buffer = self._comm_buffer
            try:
                comm_buffer.raise_on_exception()
            except Exception as exc:
                self._exception = exc
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush the disk buffer; raises if the run is closed."""
            self.raise_if_closed()
            disk_buffer = self._disk_buffer
            await disk_buffer.flush()
    
        async def close(self) -> None:
            """Close the comm buffer, then the disk buffer, and signal completion.

            If the run is already marked closed, only wait for the first
            ``close`` call to finish.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            self.closed = True
            for buffer in (self._comm_buffer, self._disk_buffer):
                await buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            """Store ``exception`` on the run unless the run is already closed."""
            if self.closed:
                return
            self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Accept shards sent by a peer and hand them to ``_receive``.

            ``data`` is either the list of shards or a pickled blob of that list
            (see ``send``).  Returns ``{"status": "OK"}`` on success; a
            ``P2PConsistencyError`` is returned as an error message instead of
            being raised.
            """
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    # NOTE(review): pickle.loads can execute arbitrary code on
                    # malicious input; presumably this only ever sees blobs
                    # produced by send() on peer workers -- confirm.
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition

            Abstract hook.  ``_ensure_output_worker`` compares the returned
            address against ``local_address``.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run

            Abstract hook called by ``receive`` with the (already unpickled)
            list of shards.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…le/_worker_plugin.py:411: in get_or_create_shuffle
    return sync(
distributed/utils.py:439: in sync
    raise error
distributed/utils.py:413: in f
    result = yield future
../../../miniconda3/envs/dask-distributed/lib/python3.12/site-packages/tornado/gen.py:766: in run
    value = future.result()
distributed/shuffle/_worker_plugin.py:145: in get_or_create
    shuffle_run = await self._refresh(
distributed/shuffle/_worker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Attach the plugin to ``scheduler``.

            Registers the four ``shuffle_*`` handlers on ``scheduler.handlers``,
            adds this object as the scheduler plugin named "shuffle" and
            initializes the plugin's bookkeeping containers.
            """
            self.scheduler = scheduler
            # Handlers served by this plugin.
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            # heartbeats[shuffle_id][worker_address] -> dict of heartbeat data
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` named "shuffle" via the scheduler."""
            payload = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, payload, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return the IDs of all currently active shuffles as a new set."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Handle a barrier request for the active shuffle ``id``.

            If ``consistent`` is false the shuffle is restarted.  Otherwise
            ``shuffle_inputs_done`` is broadcast to all participating workers;
            workers whose broadcast failed with an ``OSError`` are retried after
            a short sleep.

            Raises
            ------
            KeyError
                If there is no active shuffle with this ``id``.
            ValueError
                If ``run_id`` does not match the active run.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler although the
                shuffle has not been archived.
            RuntimeError
                On a non-``OSError`` broadcast failure, if a failing worker is
                no longer known to the scheduler, or once three retry rounds
                made no progress.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Keep only the workers whose broadcast failed with a retryable
                # OSError; any other error aborts the barrier.  (A second pass
                # over ``res`` is not needed: after this loop every non-None
                # result is an OSError.)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` for the active run of shuffle ``id``.

            The request's ``run_id`` must equal the active run's; otherwise a
            ``P2PConsistencyError`` is produced and returned as an error message.
            Returns ``{"status": "OK"}`` on success.
            """
            try:
                shuffle = self.active_shuffles[id]
                active_run_id = shuffle.run_id
                if active_run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if active_run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge heartbeat ``data`` sent by worker ``ws`` into ``self.heartbeats``.

            ``data`` maps shuffle IDs to dicts of heartbeat values.  Entries for
            shuffles that are not currently active are ignored.
            """
            # Test membership on the dict directly instead of building a fresh
            # set of all active IDs for every entry.
            active = self.active_shuffles
            for shuffle_id, d in data.items():
                if shuffle_id in active:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of the active shuffle ``id`` to ``worker``.

            The lookup is delegated to ``_get``.  A ``KeyError`` from the lookup
            is re-raised as ``P2PConsistencyError``, and every
            ``P2PConsistencyError`` is returned as an error message rather than
            raised.
            """
            try:
                try:
                    spec = self._get(id, worker)
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
                else:
                    return {"status": "OK", "run_spec": ToPickle(spec)}
            except P2PConsistencyError as err:
                return error_message(err)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:33963', workers: 0, cores: 0, tasks: 0>
ws = (<Worker 'tcp://127.0.0.1:41295', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:39915', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
pd = <module 'pandas' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.12/site-packages/pandas/__init__.py'>
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.12/site-packages/dask/dataframe/__init__.py'>
arr = array([[-1.10870194e-01,  1.24996567e+00,  9.07049911e-02,
         1.24946836e+00, -5.82149888e-01,  1.31636619e+00,
... 9.54350494e-02, -1.54228652e-01,
         1.41285044e+00, -3.73567619e-01,  2.63043451e-01,
        -1.00381024e+00]])

    @gen_cluster(client=True)
    async def test_rechunk_unknown_from_pandas(c, s, *ws):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_unknown_from_pandas
        """
        pd = pytest.importorskip("pandas")
        dd = pytest.importorskip("dask.dataframe")
    
        arr = np.random.default_rng().standard_normal((50, 10))
        x = dd.from_pandas(pd.DataFrame(arr), 2).values
        result = x.rechunk((None, (5, 5)), method="p2p")
        assert np.isnan(x.chunks[0]).all()
        assert np.isnan(result.chunks[0]).all()
        assert result.chunks[1] == (5, 5)
        expected = da.from_array(arr, chunks=((25, 25), (10,))).rechunk(
            (None, (5, 5)), method="p2p"
        )
>       assert_eq(await c.compute(result), await c.compute(expected))

distributed/shuffle/tests/test_rechunk.py:706: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
    self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    # Names defined under TYPE_CHECKING exist only for static type checkers.
    # The module starts with ``from __future__ import annotations``, so using
    # them inside annotations does not require them to exist at runtime.
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias

        # Parameter specification used to type ``ShuffleRun.offload``
        _P = ParamSpec("_P")

        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin

    # Distinct static type for shuffle identifiers; a plain ``str`` at runtime
    ShuffleId = NewType("ShuffleId", str)
    # Tuple of integers of arbitrary length
    NDIndex: TypeAlias = tuple[int, ...]


    # Type variables for the generic classes and helpers in this module
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        """``OKMessage`` extended with the run spec of a shuffle."""

        # Either the spec itself or the spec wrapped in ``ToPickle``
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        """Base class for one run of a P2P shuffle.

        The class is generic over the partition identifier type and the
        partition payload type. It owns a ``CommShardsBuffer`` that sends shards
        through :meth:`send` and a disk or in-memory shards buffer, and it talks
        to the scheduler and to peers through the RPC callables passed to the
        constructor. Subclasses provide the shuffle-specific behaviour through
        the abstract methods declared below.
        """

        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop

        # Retry settings for ``send``; read from the dask config in ``__init__``
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float

        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Store the run's identity and collaborators and create its buffers.

            ``disk`` selects between a ``DiskShardsBuffer`` writing below
            ``directory`` and a ``MemoryShardsBuffer``. Message size limit,
            comm concurrency and retry settings are read from the
            ``distributed.p2p.comm.*`` dask configuration keys.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            # Set to True by ``inputs_done``; checked by ``add_partition`` and
            # ``get_output_partition``
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )

        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"

        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"

        def __hash__(self) -> int:
            # The run ID alone serves as the hash value
            return self.run_id

        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This is by design so that one can have a holistic view of the whole shuffle
            process.

            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """

            def callback(label: Hashable, value: float, unit: str) -> None:
                # Normalise the label to a tuple and drop a leading "p2p-" prefix,
                # since "p2p" is already the first element of the metric name
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)

                self.digest_metric(name, value)

            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield

        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Call ``shuffle_barrier`` on the scheduler and return this run's ID.

            ``consistent`` is reported as True only when every entry of
            ``run_ids`` equals ``self.run_id``.
            """
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id

        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Make a single ``shuffle_receive`` call to the peer at ``address``."""
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )

        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address`` through ``retry`` with the RETRY_* settings.

            When the mean shard size is below 65536 the shards are pickled into
            a single bytes blob first; ``receive`` unpickles such blobs.
            """
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards

            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)

            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )

        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor`` and await the result.

            The call is metered under the ``"offload"`` label.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )

        def heartbeat(self) -> dict[str, Any]:
            """Return the heartbeats of both buffers plus the run's start time.

            The comm heartbeat additionally carries ``total_recvd`` under
            the ``"read"`` key.
            """
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }

        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Write ``data`` to the comm buffer."""
            self.raise_if_closed()
            await self._comm_buffer.write(data)

        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            """Write ``data`` to the disk buffer, keyed by the "_"-joined index."""
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )

        def raise_if_closed(self) -> None:
            """Raise if the run is closed.

            The stored exception is raised when there is one, otherwise a
            ``ShuffleClosedError``.
            """
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")

        async def inputs_done(self) -> None:
            """Mark the run as transferred and flush the comm buffer.

            An exception raised by the comm buffer's ``raise_on_exception`` is
            stored on the run and re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise

        async def _flush_comm(self) -> None:
            """Flush the comm buffer."""
            self.raise_if_closed()
            await self._comm_buffer.flush()

        async def flush_receive(self) -> None:
            """Flush the disk buffer."""
            self.raise_if_closed()
            await self._disk_buffer.flush()

        async def close(self) -> None:
            """Close both buffers.

            If the run is already marked closed, wait for ``_closed_event``
            (set at the end of the first call) and return.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()

        def fail(self, exception: Exception) -> None:
            """Store ``exception`` on the run; ignored once the run is closed."""
            if not self.closed:
                self._exception = exception

        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            """Read from the disk buffer under the "_"-joined form of ``id``."""
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))

        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Pass incoming shards to ``_receive`` and return a status message.

            A bytes payload is unpickled first (see ``send``). A
            ``P2PConsistencyError`` is returned as an error message instead of
            being raised.
            """
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)

        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Raise ``Reschedule`` unless this worker is assigned partition ``i``.

            When the assigned worker differs from ``local_address``, the
            scheduler is first asked via ``shuffle_restrict_task`` to restrict
            ``key`` to the assigned worker; an error reply is raised as
            ``RuntimeError``.
            """
            assigned_worker = self._get_assigned_worker(i)

            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()

        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""

        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""

        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            """Shard an input partition, write the shards to the comm buffer and
            return the run ID.

            Raises ``RuntimeError`` once ``inputs_done`` has marked the run as
            transferred. The write goes through ``sync`` on ``self._loop``.
            """
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id

        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""

        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Return the output partition ``partition_id``.

            Steps, in order: check the worker assignment (may raise
            ``Reschedule``), require that the run has been transferred, flush
            the disk buffer, then delegate to ``_get_output_partition``.
            """
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)

        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""

        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""

        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""

        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        """Return the plugin registered as ``"shuffle"`` on the current worker.

        Raises ``RuntimeError`` when ``get_worker()`` fails with ``ValueError``
        or when the worker has no ``"shuffle"`` entry in its plugins.
        """
        from distributed import get_worker

        try:
            current = get_worker()
        except ValueError as not_on_worker:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from not_on_worker

        try:
            plugin = current.plugins["shuffle"]
        except KeyError as missing:
            raise RuntimeError(
                f"The worker {current.address} does not have a P2P shuffle plugin."
            ) from missing
        return plugin  # type: ignore
    
    
    # Prefix shared by the keys of all barrier tasks
    _BARRIER_PREFIX = "shuffle-barrier-"


    def barrier_key(shuffle_id: ShuffleId) -> str:
        """Return the barrier task key belonging to ``shuffle_id``."""
        return _BARRIER_PREFIX + shuffle_id


    def id_from_key(key: Key) -> ShuffleId | None:
        """Return the shuffle ID encoded in a barrier key.

        Inverse of :func:`barrier_key`. Returns ``None`` when ``key`` is not a
        string or does not start with the barrier prefix.
        """
        if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
            prefix_length = len(_BARRIER_PREFIX)
            return ShuffleId(key[prefix_length:])
        return None
    
    
    class ShuffleType(Enum):
        """Enumeration of the P2P operation kinds named in this module."""

        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        """Immutable description of one run of a shuffle.

        Combines the ``ShuffleSpec`` with the partition-to-worker mapping, the
        span ID and an automatically assigned run ID.
        """

        # Not an ``__init__`` argument: taken from a single ``itertools.count(1)``
        # created when the class is defined, so each new instance gets the next
        # integer.
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        # Maps a partition ID to a worker (a string)
        worker_for: dict[_T_partition_id, str]
        span_id: str | None

        @property
        def id(self) -> ShuffleId:
            """ID of the shuffle, taken from ``spec``."""
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        """Abstract, immutable specification of a shuffle.

        Concrete subclasses define the output partitions, how a worker is
        picked for a partition, and how a run is created on a worker.
        """

        id: ShuffleId
        # NOTE(review): presumably selects disk-backed vs. in-memory buffering
        # for the run; confirm against the ``create_run_on_worker``
        # implementations.
        disk: bool

        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""

        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""

        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            """Create the scheduler-side state for a new run of this spec.

            A fresh ``ShuffleRunSpec`` is built from this spec, ``worker_for``
            and ``span_id``. The participating workers start out as the set of
            all workers appearing as values of ``worker_for``.
            """
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )

        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        """Mutable state kept for one shuffle run.

        ``eq=False`` keeps the default identity-based equality; the hash is
        derived from the run ID (see ``__hash__``).
        """

        run_spec: ShuffleRunSpec
        # Worker addresses taking part in this run; callers add to this set
        participating_workers: set[str]
        # ``None`` until the run is archived (see ``archived``); not an
        # ``__init__`` argument
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False

        @property
        def id(self) -> ShuffleId:
            """ID of the shuffle, taken from ``run_spec``."""
            return self.run_spec.id

        @property
        def run_id(self) -> int:
            """ID of the run, taken from ``run_spec``."""
            return self.run_spec.run_id

        @property
        def archived(self) -> bool:
            """True once ``_archived_by`` has been set to a non-``None`` value."""
            return self._archived_by is not None

        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"

        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P 427d2d9ead2057e7af39f331aad1826a failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x0-chunks0] (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P d9513d99181485ac4237b30c4b28d743 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Install the shuffle RPC handlers and register this plugin.

            The four ``shuffle_*`` handlers are added to ``scheduler.handlers``
            and the plugin registers itself with the scheduler under the name
            ``"shuffle"``.
            """
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            # heartbeats[shuffle_id][worker_address] -> merged heartbeat dict
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            # Every state created for a shuffle ID (``_create`` adds to this)
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a serialized ``ShuffleWorkerPlugin`` named ``"shuffle"``.

            The ``scheduler`` argument is not used; registration goes through
            ``self.scheduler``.
            """
            serialized_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, serialized_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the IDs of all active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Broadcast ``shuffle_inputs_done`` to all participating workers.

            If ``consistent`` is false the shuffle is restarted through
            ``_restart_shuffle`` instead. Otherwise the message is broadcast to
            the participating workers; workers whose broadcast came back as an
            ``OSError`` are retried after a short sleep until none are left.

            Raises
            ------
            ValueError
                If ``run_id`` does not match the active run of shuffle ``id``.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler while the
                shuffle has not been archived.
            RuntimeError
                If a worker returned an error other than ``OSError``, if a
                failing worker left during an archived shuffle, or if the
                number of failing workers stayed the same for the third time.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Collect the workers to retry. A ``None`` response needs no
                # retry, an ``OSError`` is retried and anything else aborts the
                # barrier, so after this loop ``workers`` holds exactly the
                # workers with a non-``None`` response.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    # The counter is cumulative: it is never reset when a later
                    # round does make progress.
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Apply a worker restriction to task ``key`` for the active shuffle ``id``.

            The restriction is handed to ``self._set_restriction`` only when
            ``run_id`` equals the run ID of the shuffle stored in
            ``self.active_shuffles``. A mismatch in either direction is reported
            as an error message built from a ``P2PConsistencyError`` rather than
            raised. Failed lookups in ``active_shuffles`` or ``scheduler.tasks``
            are not handled here and propagate to the caller.
            """
            try:
                shuffle = self.active_shuffles[id]
                current_run_id = shuffle.run_id
                if current_run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if current_run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
            except P2PConsistencyError as exc:
                return error_message(exc)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge per-shuffle heartbeat payloads reported by worker ``ws``.

            ``data`` maps shuffle IDs to dictionaries. Entries whose shuffle ID is
            not returned by ``self.shuffle_ids()`` are ignored; the others are
            merged into ``self.heartbeats[shuffle_id][ws.address]``.
            """
            # ``shuffle_ids()`` builds a new set on every call, so evaluate it
            # once instead of once per reported shuffle.
            active_ids = self.shuffle_ids()
            for shuffle_id, d in data.items():
                if shuffle_id in active_ids:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of shuffle ``id`` wrapped in a reply message.

            On success the reply has status ``"OK"`` and the run spec wrapped in
            ``ToPickle``. A ``KeyError`` from the lookup is converted into a
            ``P2PConsistencyError``; every ``P2PConsistencyError`` is returned as
            an error message instead of being raised.
            """
            try:
                try:
                    spec = self._get(id, worker)
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
                return {"status": "OK", "run_spec": ToPickle(spec)}
            except P2PConsistencyError as err:
                return error_message(err)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: 'd9513d99181485ac4237b30c4b28d743'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…rker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of the active shuffle ``id`` to ``worker``.

            The reply is an OK message carrying the pickled run spec. A missing
            shuffle or any other consistency problem is returned as an error
            message instead of being raised.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    reply: RunSpecMessage = {
                        "status": "OK",
                        "run_spec": ToPickle(spec),
                    }
                    return reply
                except KeyError as missing:
                    # An unknown shuffle id is reported as a consistency error
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as inconsistency:
                return error_message(inconsistency)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Look up the active shuffle ``id`` and register ``worker`` with it.

            Raises ``KeyError`` if there is no active shuffle with this id and
            ``P2PConsistencyError`` if the scheduler does not know ``worker``.
            """
            if worker in self.scheduler.workers:
                shuffle_state = self.active_shuffles[id]
                # Every worker that asks for the spec takes part in the run
                shuffle_state.participating_workers.add(worker)
                return shuffle_state.run_spec
            # This should never happen
            raise P2PConsistencyError(
                f"Scheduler is unaware of this worker {worker!r}"
            )  # pragma: nocover
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:35113', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(50, 10), dtype=float64, chunksize=(25, 10), chunktype=numpy.ndarray>
chunks = (None, 5)
ws = (<Worker 'tcp://127.0.0.1:40495', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:40935', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.12/site-packages/dask/dataframe/__init__.py'>

    @pytest.mark.parametrize(
        "x, chunks",
        [
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
            (da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
            (da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
        """
        pytest.importorskip("pandas")
        dd = pytest.importorskip("dask.dataframe")
        y = dd.from_array(x).values
        result = y.rechunk(chunks, method="p2p")
        expected = x.rechunk(chunks, method="p2p")
    
        assert_chunks_match(result.chunks, expected.chunks)
>       assert_eq(await c.compute(result), await c.compute(expected))

distributed/shuffle/tests/test_rechunk.py:757: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
    self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        """Worker-side state and operations of a single run of a P2P shuffle.

        A run owns one buffer for outgoing shards (``_comm_buffer``) and one for
        received shards (``_disk_buffer``, backed by disk or by memory depending
        on the ``disk`` constructor argument). Subclasses provide the sharding,
        receiving, (de)serialization and output assembly through the abstract
        methods declared below.
        """

        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop

        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float

        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Set up buffers, bookkeeping and retry settings for this run.

            ``disk`` selects a ``DiskShardsBuffer`` writing below ``directory``
            (True) or a ``MemoryShardsBuffer`` (False) for received shards.
            Message size, concurrency and retry settings are read from the
            ``distributed.p2p.comm`` section of the dask config.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )

        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"

        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"

        def __hash__(self) -> int:
            # Hash by run id only; no __eq__ is defined in this class
            return self.run_id

        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This is by design so that one can have a holistic view of the whole shuffle
            process.

            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """

            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                # Drop the "p2p-" prefix: the metric name already starts with "p2p"
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)

                self.digest_metric(name, value)

            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield

        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Report to the scheduler that the barrier of this run was reached.

            ``consistent`` is sent as True only if every entry of ``run_ids``
            equals this run's ``run_id``. Returns this run's ``run_id``.
            """
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id

        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Make a single ``shuffle_receive`` RPC call to ``address``."""
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )

        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to the worker at ``address``, retrying on failure.

            Retries are governed by ``RETRY_COUNT``, ``RETRY_DELAY_MIN`` and
            ``RETRY_DELAY_MAX``.
            """
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards

            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)

            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )

        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func`` on this run's executor, metered under "offload"."""
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )

        def heartbeat(self) -> dict[str, Any]:
            """Return the heartbeat payload of this run.

            The payload has the keys ``"disk"`` and ``"comm"`` (heartbeats of the
            two buffers, the latter extended by ``"read"`` = ``total_recvd``) and
            ``"start"`` (the start time of this run).
            """
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }

        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Hand ``data`` to the comm buffer."""
            self.raise_if_closed()
            await self._comm_buffer.write(data)

        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            """Hand ``data`` to the disk buffer, keyed by "_"-joined indices."""
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )

        def raise_if_closed(self) -> None:
            """Raise if this run is closed.

            Raises the exception recorded in ``_exception`` if there is one,
            otherwise ``ShuffleClosedError``.
            """
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")

        async def inputs_done(self) -> None:
            """Mark the run as transferred and flush the comm buffer.

            An exception reported by the comm buffer is recorded in
            ``_exception`` and re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise

        async def _flush_comm(self) -> None:
            """Flush the comm buffer."""
            self.raise_if_closed()
            await self._comm_buffer.flush()

        async def flush_receive(self) -> None:
            """Flush the disk buffer."""
            self.raise_if_closed()
            await self._disk_buffer.flush()

        async def close(self) -> None:
            """Close both buffers; safe to call more than once.

            The first caller performs the shutdown. Later callers wait until that
            shutdown has finished and then return. An exception raised while
            closing a buffer propagates to the first caller.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            self.closed = True
            try:
                try:
                    await self._comm_buffer.close()
                finally:
                    # Close the receive buffer even if closing the comm buffer failed
                    await self._disk_buffer.close()
            finally:
                # Always release concurrent close() callers, even on failure;
                # otherwise they would wait on the event forever
                self._closed_event.set()

        def fail(self, exception: Exception) -> None:
            """Record ``exception`` as the failure of this run unless it is closed."""
            if not self.closed:
                self._exception = exception

        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            """Read the shards stored under the "_"-joined index ``id``."""
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))

        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Accept shards sent by a peer and pass them to ``_receive``.

            Returns an OK message, or an error message if ``_receive`` raises
            ``P2PConsistencyError``.
            """
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    # NOTE(review): this unpickles bytes received over the
                    # network; it is only safe if all peers are trusted.
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)

        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Make sure output partition ``i`` is computed on its assigned worker.

            If another worker is assigned, ask the scheduler to restrict ``key``
            to that worker and raise ``Reschedule``. An error reply from the
            scheduler is raised as ``RuntimeError``.
            """
            assigned_worker = self._get_assigned_worker(i)

            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()

        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""

        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""

        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            """Shard an input partition and queue the shards for sending.

            Raises ``RuntimeError`` if the run has already been marked as
            transferred. Returns this run's ``run_id``.
            """
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id

        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""

        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Return the output partition ``partition_id`` of this run.

            Raises ``Reschedule`` (via ``_ensure_output_worker``) if the partition
            belongs to another worker, and ``RuntimeError`` if the run has not
            been marked as transferred yet.
            """
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)

        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""

        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""

        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""

        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        """Return the "shuffle" plugin of the worker running the current task.

        Raises ``RuntimeError`` when not called from a worker, or when the
        worker has no plugin registered under the name "shuffle".
        """
        from distributed import get_worker

        try:
            current_worker = get_worker()
        except ValueError as not_on_worker:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from not_on_worker

        try:
            plugin = current_worker.plugins["shuffle"]
        except KeyError as missing_plugin:
            raise RuntimeError(
                f"The worker {current_worker.address} does not have a P2P shuffle plugin."
            ) from missing_plugin
        return plugin  # type: ignore
    
    
    # Prefix shared by the keys of all shuffle barrier tasks
    _BARRIER_PREFIX = "shuffle-barrier-"


    def barrier_key(shuffle_id: ShuffleId) -> str:
        """Return the key of the barrier task belonging to ``shuffle_id``."""
        return "".join((_BARRIER_PREFIX, shuffle_id))
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        """Extract the shuffle id from a barrier task key.

        Returns ``None`` when ``key`` is not a string starting with the barrier
        prefix.
        """
        if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
            return ShuffleId(key.removeprefix(_BARRIER_PREFIX))
        return None
    
    
    class ShuffleType(Enum):
        """Kinds of P2P shuffle operations named in this module."""

        # Shuffle of a dataframe
        DATAFRAME = "DataFrameShuffle"
        # Rechunking of an array
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        """Immutable description of one run of a shuffle.

        Combines the static ``ShuffleSpec`` with the data specific to a run: its
        run id, the worker assigned to each output partition and the span id.
        """

        # Not a constructor argument. Each new instance takes the next value of
        # a single counter that starts at 1 and is shared by all instances
        # created in this process.
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        # Maps each output partition id to a worker (str)
        worker_for: dict[_T_partition_id, str]
        span_id: str | None

        @property
        def id(self) -> ShuffleId:
            """Id of the shuffle this run belongs to (taken from ``spec``)."""
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        """Abstract, immutable description of a shuffle.

        Holds the shuffle id and the ``disk`` flag; subclasses define the output
        partitions, the worker assignment and how a run is created on a worker.
        """

        id: ShuffleId
        disk: bool

        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Ids of the output partitions of this shuffle"""

        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Choose one of ``workers`` for the given output partition"""

        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            """Create the scheduler-side state for a new run of this shuffle.

            The participating workers are initialized to the workers that appear
            as values in ``worker_for``.
            """
            new_run_spec = ShuffleRunSpec(
                spec=self, worker_for=worker_for, span_id=span_id
            )
            assigned_workers = set(worker_for.values())
            return SchedulerShuffleState(
                run_spec=new_run_spec,
                participating_workers=assigned_workers,
            )

        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Instantiate the run of this shuffle on a worker"""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        """Scheduler-side bookkeeping for one run of a shuffle.

        Instances are hashed by run id; ``eq=False`` means the dataclass does not
        generate a field-based ``__eq__``.
        """

        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False

        @property
        def id(self) -> ShuffleId:
            """Id of the shuffle, as stored in the run spec."""
            spec = self.run_spec
            return spec.id

        @property
        def run_id(self) -> int:
            """Id of this run, as stored in the run spec."""
            spec = self.run_spec
            return spec.run_id

        @property
        def archived(self) -> bool:
            """Whether ``_archived_by`` has been set."""
            if self._archived_by is None:
                return False
            return True

        def __str__(self) -> str:
            cls_name = self.__class__.__name__
            return f"{cls_name}<{self.id}[{self.run_id}]>"

        def __hash__(self) -> int:
            run_id = self.run_id
            return hash(run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P d9513d99181485ac4237b30c4b28d743 failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x1-chunks1] (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 0e8a9f04ceb745bcef4ba25a6e0f2f79 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Attach the plugin to ``scheduler``.

            Adds the ``shuffle_*`` RPC handlers to the scheduler, registers this
            object as the scheduler plugin named "shuffle" and initializes the
            plugin's bookkeeping containers.
            """
            self.scheduler = scheduler
            rpc_handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler.handlers.update(rpc_handlers)
            # shuffle id -> worker address -> merged heartbeat values
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a ``ShuffleWorkerPlugin`` under the name "shuffle".

            The registration goes through ``self.scheduler``; the ``scheduler``
            argument is not used.
            """
            serialized_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None,
                serialized_plugin,
                name="shuffle",
                idempotent=False,
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the ids of all active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Handle the barrier of a shuffle run.

            Broadcasts ``shuffle_inputs_done`` to all participating workers and
            repeats the broadcast for workers that answered with an ``OSError``.

            If ``consistent`` is False the shuffle is restarted instead and no
            broadcast takes place.

            Raises
            ------
            ValueError
                If ``run_id`` is not the run id of the active shuffle ``id``.
            P2PIllegalStateError
                If a worker to retry is unknown to the scheduler while the
                shuffle is not archived.
            RuntimeError
                If a worker answers with an error other than ``OSError``, if a
                worker to retry is unknown to the scheduler, or if the set of
                workers to retry stays the same size for three rounds.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Collect the workers to retry. A ``None`` result means success,
                # an OSError is retried and anything else aborts the barrier, so
                # after this loop ``workers`` holds every worker without success.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict the task ``key`` to ``worker`` on behalf of a shuffle run.

            The request is only honored when ``run_id`` equals the run id of the
            currently active shuffle ``id``. A mismatch in either direction is
            reported back as an error message instead of being raised.
            """
            try:
                state = self.active_shuffles[id]
                active_run_id = state.run_id
                if active_run_id > run_id:
                    # The request was issued by an older run of this shuffle
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {state}"
                    )
                if active_run_id < run_id:
                    # The request refers to a run the scheduler does not know
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {state}"
                    )
                task_state = self.scheduler.tasks[key]
                self._set_restriction(task_state, worker)
            except P2PConsistencyError as exc:
                return error_message(exc)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge per-shuffle heartbeat payloads sent by worker ``ws``.

            ``data`` maps shuffle ids to dictionaries of heartbeat values. Entries
            whose shuffle id is not reported by ``shuffle_ids()`` are ignored.
            """
            # The set of known ids does not change inside the loop, so compute
            # it once instead of once per reported shuffle.
            known_ids = self.shuffle_ids()
            for shuffle_id, d in data.items():
                if shuffle_id in known_ids:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of the active shuffle ``id`` to ``worker``.

            The reply is an OK message carrying the pickled run spec. A missing
            shuffle or any other consistency problem is returned as an error
            message instead of being raised.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    reply: RunSpecMessage = {
                        "status": "OK",
                        "run_spec": ToPickle(spec),
                    }
                    return reply
                except KeyError as missing:
                    # An unknown shuffle id is reported as a consistency error
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as inconsistency:
                return error_message(inconsistency)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            """Return the ``spec`` stored on the barrier task of ``shuffle_id``.

            The barrier task is looked up in ``self.scheduler.tasks`` under
            ``barrier_key(shuffle_id)``.

            Raises
            ------
            AssertionError
                If the barrier task's ``run_spec`` is not a ``P2PBarrierTask``.
            """
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
            if not isinstance(barrier_task_spec, P2PBarrierTask):
                # Raised explicitly instead of via ``assert`` so the check is not
                # stripped under ``python -O`` and the failure carries a message.
                # The exception type is kept for backward compatibility.
                raise AssertionError(
                    f"Expected run_spec of the barrier task of shuffle "
                    f"{shuffle_id!r} to be a P2PBarrierTask, got "
                    f"{type(barrier_task_spec).__name__}"
                )
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            """Create a new run of ``shuffle_id`` and make it the active one.

            The new state is stored in ``self.active_shuffles`` and added to
            ``self._shuffles[shuffle_id]``; ``worker`` is recorded as its first
            participating worker. Returns the ``run_spec`` of the new state.
            """
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)

            shuffle_spec = self._retrieve_spec(shuffle_id)
            assignment = self._calculate_worker_for(shuffle_spec)
            self._ensure_output_tasks_are_non_rootish(shuffle_spec)

            span_id = self.scheduler.tasks[key].group.span_id
            new_run = shuffle_spec.create_new_run(worker_for=assignment, span_id=span_id)

            self.active_shuffles[shuffle_id] = new_run
            self._shuffles[shuffle_id].add(new_run)
            new_run.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return new_run.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

            The ``scheduler`` argument is not used; registration goes through
            ``self.scheduler.register_worker_plugin``.
            """
            pickled_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, pickled_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Handle the barrier of a shuffle run and notify its participants.

            Broadcasts ``shuffle_inputs_done`` to all participating workers and
            retries for those workers whose reply was an ``OSError``.

            Parameters
            ----------
            id
                Id of the shuffle; looked up in ``self.active_shuffles``.
            run_id
                Run the caller refers to; must equal the active run's ``run_id``.
            consistent
                If false, the shuffle is restarted via ``_restart_shuffle``
                and nothing is broadcast.

            Raises
            ------
            ValueError
                If ``run_id`` does not match the active run.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler while the
                shuffle has not been archived.
            RuntimeError
                If a worker replies with anything other than ``None`` or an
                ``OSError``, if a failing worker is unknown to the scheduler,
                or if the retries stop making progress.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            # Number of rounds in which the set of failing workers did not
            # shrink. It is never reset, so it counts such rounds in total.
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Keep only the workers whose reply was an OSError; they are
                # retried in the next round. Any other non-None reply is fatal.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of the active shuffle ``id`` to ``worker``.

            On success the reply is ``{"status": "OK", "run_spec": ...}`` with
            the spec wrapped in ``ToPickle``. A ``KeyError`` raised by the lookup
            is re-raised as ``P2PConsistencyError``; any
            ``P2PConsistencyError`` is then returned in the form produced by
            ``error_message`` instead of propagating.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    reply = {"status": "OK", "run_spec": ToPickle(spec)}
                    return reply
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as exc:
                return error_message(exc)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: '0e8a9f04ceb745bcef4ba25a6e0f2f79'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        """``OKMessage`` that additionally carries a shuffle run spec."""

        # The spec itself, or the spec wrapped in ``ToPickle``.
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Set up the state, buffers and retry settings of one shuffle run.

            Parameters
            ----------
            id, run_id, span_id, local_address, executor, rpc, digest_metric, scheduler
                Stored unchanged on the instance under the same names.
            directory
                Passed to ``DiskShardsBuffer``; only used when ``disk`` is true.
            memory_limiter_disk
                Limiter handed to ``DiskShardsBuffer``; only used when ``disk``
                is true.
            memory_limiter_comms
                Limiter handed to ``CommShardsBuffer``.
            disk
                Selects ``DiskShardsBuffer`` (true) or ``MemoryShardsBuffer``
                (false) for ``self._disk_buffer``.
            loop
                Stored as ``self._loop``.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    # Both comm settings are read from the dask config.
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            # Retry settings used by ``send``; delays are parsed with seconds
            # as the default unit.
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            """Return ``run_id`` as the hash of this run."""
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Forward context_meter metrics to ``digest_metric`` inside the block.

            Metrics are recorded as

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            A leading ``"p2p-"`` on the first label element is stripped.

            **Note 1:** For metrics not logged by a background task
            (where='foreground'), the same metric also appears under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This duplication is by design, so that the whole shuffle process can
            be viewed in one place.

            **Note 2:** Metrics are written to Worker.digests immediately.
            Storing them on the ShuffleRun first would lose whatever is recorded
            between the last heartbeat and the deletion of the ShuffleRun object
            at the end of a run.
            """
            offload_allowed = "background" in where

            def callback(label: Hashable, value: float, unit: str) -> None:
                parts = label if isinstance(label, tuple) else (label,)
                head = parts[0]
                if isinstance(head, str) and head.startswith("p2p-"):
                    parts = (head[len("p2p-") :], *parts[1:])
                self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

            with context_meter.add_callback(callback, allow_offload=offload_allowed):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address`` through the ``shuffle_receive`` RPC.

            The payload is wrapped with ``to_serialize`` and tagged with this
            run's shuffle id and run id.
            """
            self.raise_if_closed()
            peer = self.rpc(address)
            payload = to_serialize(shards)
            return await peer.shuffle_receive(
                data=payload,
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address``, retrying via ``retry``.

            When the mean shard size is below 65536 the shards are pickled into
            one bytes blob before sending; otherwise they are sent as a list.
            Retry count and delays come from ``RETRY_COUNT``,
            ``RETRY_DELAY_MIN`` and ``RETRY_DELAY_MAX``.
            """
            # Don't send buffers individually over the tcp comms.
            # Instead, merge everything into an opaque bytes blob, send it all at once,
            # and unpickle it on the other side.
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
            merge_into_blob = _mean_shard_size(shards) < 65536
            shards_or_bytes: list | bytes = (
                pickle.dumps(shards) if merge_into_blob else shards
            )

            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)

            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

            The call is metered under the ``"offload"`` label.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                outcome = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
                return outcome
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Write ``data`` to the comm buffer; raises first if the run is closed."""
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark the input transfer as finished and flush the comm buffer.

            Sets ``self.transferred`` and awaits ``_flush_comm``. An exception
            raised by the comm buffer's ``raise_on_exception`` is stored in
            ``self._exception`` and re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                # Remember the failure so later ``raise_if_closed`` calls can
                # raise it once the run is closed.
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            """Flush the comm buffer; raises first if the run is closed."""
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush the disk buffer; raises first if the run is closed."""
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            """Close both buffers and mark the run as closed.

            A call made after closing has started waits for ``_closed_event``
            and returns without closing anything again.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            # Set the flag before awaiting so later calls take the branch above.
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Accept shards for this run and hand them to ``_receive``.

            ``data`` is either a list of ``(partition id, shard)`` tuples or a
            pickled blob of such a list (see send()). Returns
            ``{"status": "OK"}`` on success; a ``P2PConsistencyError`` raised
            while receiving is returned in the form produced by
            ``error_message`` instead of being raised.
            """
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    # NOTE(review): ``pickle.loads`` can execute arbitrary code
                    # from its input; this presumes the blob only ever comes
                    # from trusted peers -- confirm.
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition ``i``.

            Abstract; the result is compared with ``self.local_address`` in
            ``_ensure_output_worker``.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run.

            Abstract; called by ``receive`` after a bytes payload has been
            unpickled into a list.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…_worker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

            The ``scheduler`` argument is not used; registration goes through
            ``self.scheduler.register_worker_plugin``.
            """
            pickled_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, pickled_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Handle the barrier of a shuffle run and notify its participants.

            Broadcasts ``shuffle_inputs_done`` to all participating workers and
            retries for those workers whose reply was an ``OSError``.

            Parameters
            ----------
            id
                Id of the shuffle; looked up in ``self.active_shuffles``.
            run_id
                Run the caller refers to; must equal the active run's ``run_id``.
            consistent
                If false, the shuffle is restarted via ``_restart_shuffle``
                and nothing is broadcast.

            Raises
            ------
            ValueError
                If ``run_id`` does not match the active run.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler while the
                shuffle has not been archived.
            RuntimeError
                If a worker replies with anything other than ``None`` or an
                ``OSError``, if a failing worker is unknown to the scheduler,
                or if the retries stop making progress.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            # Number of rounds in which the set of failing workers did not
            # shrink. It is never reset, so it counts such rounds in total.
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Keep only the workers whose reply was an OSError; they are
                # retried in the next round. Any other non-None reply is fatal.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict the task ``key`` to ``worker`` for the active shuffle ``id``.

            The request is only honoured when ``run_id`` equals the run ID of the
            currently active shuffle state. A mismatch in either direction is
            reported back as an error message instead of being raised.

            Returns
            -------
            ``{"status": "OK"}`` on success, otherwise the result of
            ``error_message`` for the ``P2PConsistencyError``.
            """
            try:
                shuffle = self.active_shuffles[id]
                active_run_id = shuffle.run_id
                if active_run_id > run_id:
                    # The caller belongs to an older run than the active one
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if active_run_id < run_id:
                    # The caller claims a newer run than the one known here
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
            except P2PConsistencyError as exc:
                return error_message(exc)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge the per-shuffle heartbeat payloads sent by worker ``ws``.

            ``data`` maps shuffle IDs to payload dicts. Entries whose shuffle ID
            is not among ``self.shuffle_ids()`` are ignored; the others update
            ``self.heartbeats[shuffle_id][ws.address]`` in place.
            """
            for shuffle_id, payload in data.items():
                if shuffle_id not in self.shuffle_ids():
                    continue
                self.heartbeats[shuffle_id][ws.address].update(payload)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of the active shuffle ``id`` to ``worker``.

            Delegates to ``_get``. A ``KeyError`` raised there is translated into
            a ``P2PConsistencyError``, and any ``P2PConsistencyError`` is returned
            as an error message rather than raised.

            Returns
            -------
            ``{"status": "OK", "run_spec": ToPickle(...)}`` on success, otherwise
            the result of ``error_message``.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(spec)}
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as failure:
                return error_message(failure)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Register ``worker`` as a participant of shuffle ``id`` and return its spec.

            Raises
            ------
            P2PConsistencyError
                If ``worker`` is not in ``self.scheduler.workers``.
            KeyError
                If ``id`` is not in ``self.active_shuffles``.
            """
            if worker in self.scheduler.workers:
                state = self.active_shuffles[id]
                state.participating_workers.add(worker)
                return state.run_spec
            # This should never happen
            raise P2PConsistencyError(
                f"Scheduler is unaware of this worker {worker!r}"
            )  # pragma: nocover
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:35473', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(50, 10), dtype=float64, chunksize=(25, 10), chunktype=numpy.ndarray>
chunks = {1: 5}
ws = (<Worker 'tcp://127.0.0.1:44921', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:40681', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.12/site-packages/dask/dataframe/__init__.py'>

    @pytest.mark.parametrize(
        "x, chunks",
        [
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
            (da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
            (da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
        """
        pytest.importorskip("pandas")
        dd = pytest.importorskip("dask.dataframe")
        y = dd.from_array(x).values
        result = y.rechunk(chunks, method="p2p")
        expected = x.rechunk(chunks, method="p2p")
    
        assert_chunks_match(result.chunks, expected.chunks)
>       assert_eq(await c.compute(result), await c.compute(expected))

distributed/shuffle/tests/test_rechunk.py:757: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
    self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    # Distinct static type for shuffle identifiers; ``NewType`` makes this a plain
    # ``str`` at runtime
    ShuffleId = NewType("ShuffleId", str)
    # Tuple of integers of arbitrary length
    NDIndex: TypeAlias = tuple[int, ...]


    # Type variables for the partition identifier and the partition payload that
    # parametrize the generic classes in this module
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    # Generic return type, used by ``ShuffleRun.offload``
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        """``OKMessage`` that additionally carries a shuffle run spec."""

        # Either the spec itself or the spec wrapped in ``ToPickle``
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        """State and operations of one run of a P2P shuffle on a single worker.

        Outgoing shards are written to a ``CommShardsBuffer`` which delivers them
        through ``send``. Received shards are stored in a ``DiskShardsBuffer`` or,
        when ``disk`` is false, a ``MemoryShardsBuffer``.

        Subclasses implement the abstract methods that define how partitions are
        sharded, received, read back and assembled into output partitions.
        """

        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        transferred: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop

        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float

        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Set up the buffers and the bookkeeping state of the run.

            Parameters
            ----------
            id, run_id, span_id, local_address
                Stored as attributes; ``span_id`` becomes part of metric names.
            directory
                Passed to the ``DiskShardsBuffer``; unused when ``disk`` is false.
            executor
                Executor used by ``offload``.
            rpc
                Callable returning an RPC handle for a worker address.
            digest_metric
                Callback receiving ``(name, value)`` for every captured metric.
            scheduler
                RPC handle used for the scheduler-side shuffle handlers.
            memory_limiter_disk, memory_limiter_comms
                Limiters handed to the disk buffer and the comm buffer.
            disk
                Select ``DiskShardsBuffer`` (true) or ``MemoryShardsBuffer``.
            loop
                Event loop on which the synchronous entry points schedule their
                coroutines via ``sync``.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            # Retry policy for ``send``, read once from the dask config
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )

        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"

        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"

        def __hash__(self) -> int:
            return self.run_id

        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This is by design so that one can have a holistic view of the whole shuffle
            process.

            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """

            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                # Drop a leading "p2p-" from the label; "p2p" is already the first
                # element of the metric name
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)

                self.digest_metric(name, value)

            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield

        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Report to the scheduler that the barrier of this run was reached.

            ``consistent`` is true only if every entry of ``run_ids`` equals the
            run ID of this run. Returns the run ID of this run.
            """
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id

        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Make a single attempt to deliver ``shards`` to the worker at ``address``."""
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )

        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Deliver ``shards`` to the worker at ``address``, retrying on failure.

            Retries follow ``RETRY_COUNT``, ``RETRY_DELAY_MIN`` and
            ``RETRY_DELAY_MAX``. The payload is prepared once and reused by every
            attempt.
            """
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards

            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)

            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )

        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` in ``self.executor`` and return its result.

            The call is metered under the ``"offload"`` label.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )

        def heartbeat(self) -> dict[str, Any]:
            """Return the heartbeat payload of this run.

            The result has the keys ``"disk"`` and ``"comm"`` (heartbeats of the
            two buffers, the latter extended by ``"read"`` = ``total_recvd``) and
            ``"start"`` (the start time of the run).
            """
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }

        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Write outgoing shards to the comm buffer."""
            self.raise_if_closed()
            await self._comm_buffer.write(data)

        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            """Write received shards to the disk buffer.

            Each index tuple is turned into a string key by joining its elements
            with ``"_"``; ``_read_from_disk`` builds keys the same way.
            """
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )

        def raise_if_closed(self) -> None:
            """Raise if the run has been closed.

            Raises the exception recorded on the run if there is one, otherwise a
            ``ShuffleClosedError``.
            """
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")

        async def inputs_done(self) -> None:
            """Mark the run as transferred and flush all outgoing shards.

            An exception reported by the comm buffer is recorded on the run and
            re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise

        async def _flush_comm(self) -> None:
            """Flush the comm buffer."""
            self.raise_if_closed()
            await self._comm_buffer.flush()

        async def flush_receive(self) -> None:
            """Flush the disk buffer holding the received shards."""
            self.raise_if_closed()
            await self._disk_buffer.flush()

        async def close(self) -> None:
            """Close both buffers and mark the run as closed.

            If ``close`` was already called, wait until that call has finished
            instead of closing again.

            The disk buffer is closed and waiting callers are released even if
            closing a buffer raises; the exception still propagates to the caller.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            self.closed = True
            try:
                await self._comm_buffer.close()
            finally:
                try:
                    await self._disk_buffer.close()
                finally:
                    # Always release concurrent ``close`` callers waiting on the
                    # event; otherwise a failure above would block them forever
                    self._closed_event.set()

        def fail(self, exception: Exception) -> None:
            """Record ``exception`` for ``raise_if_closed``; ignored once closed."""
            if not self.closed:
                self._exception = exception

        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            """Read the shards stored under the index ``id`` from the disk buffer."""
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))

        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Handle shards sent by a peer.

            Returns ``{"status": "OK"}`` on success. A ``P2PConsistencyError`` is
            returned as an error message instead of being raised.
            """
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)

        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Make sure the output task ``key`` runs on the worker assigned to ``i``.

            Does nothing if that worker is the local one. Otherwise asks the
            scheduler to restrict the task to the assigned worker and raises
            ``Reschedule``.

            Raises
            ------
            RuntimeError
                If the scheduler answers with an error message.
            Reschedule
                After the task was successfully restricted to another worker.
            """
            assigned_worker = self._get_assigned_worker(i)

            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()

        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""

        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""

        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            """Shard the input partition ``data`` and hand the shards to the comm buffer.

            Blocks until the write has been scheduled and completed on the event
            loop. Returns the run ID of this run.

            Raises
            ------
            RuntimeError
                If ``inputs_done`` has already marked the run as transferred.
            """
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id

        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""

        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Return the output partition ``partition_id`` for the task ``key``.

            First makes sure the task runs on the assigned worker (which may raise
            ``Reschedule``), then flushes the received shards and delegates to
            ``_get_output_partition``.

            Raises
            ------
            RuntimeError
                If the run has not been marked as transferred yet.
            """
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)

        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""

        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""

        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""

        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        """Return the ``"shuffle"`` plugin of the worker executing the current task.

        Raises
        ------
        RuntimeError
            If ``get_worker`` raises ``ValueError`` (the code is not running on a
            worker), or if the worker has no plugin registered as ``"shuffle"``.
        """
        from distributed import get_worker

        try:
            worker = get_worker()
        except ValueError as exc:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from exc

        try:
            plugin = worker.plugins["shuffle"]
        except KeyError as exc:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from exc
        return plugin  # type: ignore
    
    
    # Prefix shared by the keys of all shuffle barrier tasks
    _BARRIER_PREFIX = "shuffle-barrier-"


    def barrier_key(shuffle_id: ShuffleId) -> str:
        """Return the barrier task key for ``shuffle_id`` (prefix + ID)."""
        return _BARRIER_PREFIX + shuffle_id


    def id_from_key(key: Key) -> ShuffleId | None:
        """Extract the shuffle ID from a barrier task key.

        Returns ``None`` if ``key`` is not a string starting with the barrier
        prefix; otherwise the part of the key after the prefix.
        """
        if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
            return ShuffleId(key.removeprefix(_BARRIER_PREFIX))
        return None
    
    
    class ShuffleType(Enum):
        """Kinds of P2P operation: dataframe shuffle and array rechunk."""

        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        """Frozen dataclass describing one run of a shuffle."""

        # Not an ``__init__`` argument: every new instance takes the next value of
        # a single counter shared by all instances, starting at 1
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        # Specification of the shuffle this run belongs to
        spec: ShuffleSpec
        # Maps a partition ID to a worker string (presumably the address of the
        # worker assigned to that output partition -- confirm against callers)
        worker_for: dict[_T_partition_id, str]
        span_id: str | None

        @property
        def id(self) -> ShuffleId:
            """ID of the shuffle, taken from ``spec``."""
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        """Abstract, frozen specification of a shuffle.

        Concrete subclasses define the output partitions, how a worker is picked
        for a partition and how a run is created on a worker.
        """

        id: ShuffleId
        # NOTE(review): presumably selects disk-backed vs. in-memory buffering of
        # received shards -- confirm against the subclasses
        disk: bool

        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""

        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""

        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            """Create the scheduler-side state for a new run of this shuffle.

            The participating workers are initialised to the set of workers
            appearing as values of ``worker_for``.
            """
            new_run_spec = ShuffleRunSpec(
                spec=self,
                worker_for=worker_for,
                span_id=span_id,
            )
            participants = set(worker_for.values())
            return SchedulerShuffleState(
                run_spec=new_run_spec,
                participating_workers=participants,
            )

        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        """Scheduler-side state of one run of a shuffle.

        ``eq=False`` keeps the default identity-based equality; hashing is based
        on the run ID.
        """

        run_spec: ShuffleRunSpec
        # Workers taking part in this run
        participating_workers: set[str]
        # ``None`` until the state is archived; see the ``archived`` property
        _archived_by: str | None = field(default=None, init=False)
        # NOTE(review): presumably set once the run has failed -- confirm where
        # this flag is written
        _failed: bool = False

        @property
        def id(self) -> ShuffleId:
            """ID of the shuffle, taken from ``run_spec``."""
            return self.run_spec.id

        @property
        def run_id(self) -> int:
            """ID of this run, taken from ``run_spec``."""
            return self.run_spec.run_id

        @property
        def archived(self) -> bool:
            """Whether ``_archived_by`` has been set."""
            return self._archived_by is not None

        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"

        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P 0e8a9f04ceb745bcef4ba25a6e0f2f79 failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x2-chunks2] (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P d9513d99181485ac4237b30c4b28d743 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a serialized ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

            Registration goes through ``self.scheduler``; the ``scheduler``
            argument itself is not used.
            """
            payload = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, payload, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set with the IDs of all currently active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Handle the barrier of shuffle ``id`` and notify participating workers.

            If ``consistent`` is false, the shuffle is restarted. Otherwise a
            ``shuffle_inputs_done`` message is broadcast to all participating
            workers. Workers whose broadcast result is an ``OSError`` are retried
            after a short sleep.

            Raises
            ------
            KeyError
                If ``id`` is not an active shuffle.
            ValueError
                If ``run_id`` does not match the active run of the shuffle.
            P2PIllegalStateError
                If a worker that needs a retry is unknown to the scheduler while
                the shuffle has not been archived.
            RuntimeError
                If a broadcast result is neither ``None`` nor an ``OSError``, if
                a worker that needs a retry has left an archived shuffle, or
                once three retry rounds have failed to shrink the retry list.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Rebuild the retry list from this round's results: ``None`` means
                # the worker is done, an ``OSError`` is retried, and anything else
                # aborts the barrier.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        # Counts every round without progress; it is never reset.
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge a worker's per-shuffle heartbeat ``data`` into ``self.heartbeats``.

            ``data`` maps shuffle IDs to dicts; each dict is merged into the entry
            stored for ``ws.address``. Entries for shuffles that are not
            currently active are ignored.
            """
            # Test membership on the dict directly instead of building a new set
            # of all active IDs for every entry.
            active = self.active_shuffles
            for shuffle_id, d in data.items():
                if shuffle_id in active:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of active shuffle ``id`` wrapped in an OK message.

            A missing shuffle (``KeyError``) is converted into a
            ``P2PConsistencyError``; every ``P2PConsistencyError`` is returned as
            an error message rather than raised.
            """
            try:
                try:
                    spec = self._get(id, worker)
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
                return {"status": "OK", "run_spec": ToPickle(spec)}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            """Return the ``spec`` stored on the run spec of the shuffle's barrier task.

            The barrier task is looked up in ``self.scheduler.tasks`` under
            ``barrier_key(shuffle_id)``; its run spec must be a ``P2PBarrierTask``.
            """
            barrier_ts = self.scheduler.tasks[barrier_key(shuffle_id)]
            run_spec = barrier_ts.run_spec
            assert isinstance(run_spec, P2PBarrierTask)
            return run_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            """Create a new run for ``shuffle_id`` and make it the active one.

            The new state is stored in ``active_shuffles`` and ``_shuffles``,
            ``worker`` is recorded as its first participant, and the run spec of
            the new state is returned. The two ``_raise_if_*`` checks run before
            anything is created.
            """
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            span_id = self.scheduler.tasks[key].group.span_id
            run_state = spec.create_new_run(worker_for=worker_for, span_id=span_id)
            self.active_shuffles[shuffle_id] = run_state
            self._shuffles[shuffle_id].add(run_state)
            run_state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return run_state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a serialized ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

            Registration goes through ``self.scheduler``; the ``scheduler``
            argument itself is not used.
            """
            payload = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, payload, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set with the IDs of all currently active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Handle the barrier of shuffle ``id`` and notify participating workers.

            If ``consistent`` is false, the shuffle is restarted. Otherwise a
            ``shuffle_inputs_done`` message is broadcast to all participating
            workers. Workers whose broadcast result is an ``OSError`` are retried
            after a short sleep.

            Raises
            ------
            KeyError
                If ``id`` is not an active shuffle.
            ValueError
                If ``run_id`` does not match the active run of the shuffle.
            P2PIllegalStateError
                If a worker that needs a retry is unknown to the scheduler while
                the shuffle has not been archived.
            RuntimeError
                If a broadcast result is neither ``None`` nor an ``OSError``, if
                a worker that needs a retry has left an archived shuffle, or
                once three retry rounds have failed to shrink the retry list.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Rebuild the retry list from this round's results: ``None`` means
                # the worker is done, an ``OSError`` is retried, and anything else
                # aborts the barrier.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        # Counts every round without progress; it is never reset.
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge a worker's per-shuffle heartbeat ``data`` into ``self.heartbeats``.

            ``data`` maps shuffle IDs to dicts; each dict is merged into the entry
            stored for ``ws.address``. Entries for shuffles that are not
            currently active are ignored.
            """
            # Test membership on the dict directly instead of building a new set
            # of all active IDs for every entry.
            active = self.active_shuffles
            for shuffle_id, d in data.items():
                if shuffle_id in active:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of active shuffle ``id`` wrapped in an OK message.

            A missing shuffle (``KeyError``) is converted into a
            ``P2PConsistencyError``; every ``P2PConsistencyError`` is returned as
            an error message rather than raised.
            """
            try:
                try:
                    spec = self._get(id, worker)
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
                return {"status": "OK", "run_spec": ToPickle(spec)}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: 'd9513d99181485ac4237b30c4b28d743'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        """``OKMessage`` extended with a ``run_spec`` entry.

        The run spec is carried either directly or wrapped in ``ToPickle``.
        """

        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Set up the state of one shuffle run and create its buffers.

            Parameters
            ----------
            id, run_id, span_id, local_address, executor, rpc, digest_metric, scheduler
                Stored unchanged as attributes of the same name.
            directory
                Passed to ``DiskShardsBuffer``; only used when ``disk`` is true.
            memory_limiter_disk
                Memory limiter handed to the ``DiskShardsBuffer``.
            memory_limiter_comms
                Memory limiter handed to the ``CommShardsBuffer``.
            disk
                If true, shards are buffered in a ``DiskShardsBuffer``; otherwise
                in a ``MemoryShardsBuffer``.
            loop
                Stored as ``self._loop``.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    # Message size and concurrency limits come from the dask config
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            # Retry settings used by ``send``, read from the dask config
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Forward context_meter metrics to ``digest_metric`` while active.

            Each metric is recorded as

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            A leading ``"p2p-"`` on the first label element is stripped.

            **Note 1:** For metrics not logged by a background task
            (where='foreground') this yields a duplicate of the metric stored under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This is intentional, so the whole shuffle process can be looked at in
            one place.

            **Note 2:** Metrics are written to Worker.digests right away. Keeping
            them on the ShuffleRun would lose whatever is recorded between the
            last heartbeat and the deletion of the ShuffleRun object at the end
            of a run.
            """
            prefix = "p2p-"

            def callback(label: Hashable, value: float, unit: str) -> None:
                parts = label if isinstance(label, tuple) else (label,)
                head = parts[0]
                if isinstance(head, str) and head.startswith(prefix):
                    parts = (head[len(prefix) :], *parts[1:])
                self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

            offloadable = "background" in where
            with context_meter.add_callback(callback, allow_offload=offloadable):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Report the barrier of this run to the scheduler and return the run id.

            The scheduler is told the run is consistent only if every entry of
            ``run_ids`` equals this run's id.
            """
            self.raise_if_closed()
            mismatches = [rid for rid in run_ids if not rid == self.run_id]
            consistent = not mismatches
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to the worker at ``address`` via ``shuffle_receive``.

            Raises through ``raise_if_closed`` if the run is already closed.
            """
            self.raise_if_closed()
            peer = self.rpc(address)
            return await peer.shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address``, retrying according to the ``RETRY_*`` settings.

            Shards with a small mean size are pickled into a single bytes blob
            before sending; larger ones are sent as they are.
            """
            # Don't send buffers individually over the tcp comms.
            # Instead, merge everything into an opaque bytes blob, send it all at once,
            # and unpickle it on the other side.
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
            small = _mean_shard_size(shards) < 65536
            shards_or_bytes: list | bytes = pickle.dumps(shards) if small else shards

            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)

            result = await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
            return result
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

            The call is metered under the ``"offload"`` label.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                result = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
            return result
    
        def heartbeat(self) -> dict[str, Any]:
            """Return the disk and comm buffer heartbeats plus the run's start time.

            The comm heartbeat is extended with a ``"read"`` entry holding
            ``self.total_recvd``.
            """
            comm_stats = self._comm_buffer.heartbeat()
            comm_stats["read"] = self.total_recvd
            disk_stats = self._disk_buffer.heartbeat()
            return {"disk": disk_stats, "comm": comm_stats, "start": self.start_time}
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            """Raise if the run is closed; do nothing otherwise.

            A stored ``_exception`` is re-raised in preference to a generic
            ``ShuffleClosedError``.
            """
            if not self.closed:
                return
            if self._exception:
                raise self._exception
            raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark the transfer as finished and flush the comm buffer.

            An exception raised by the comm buffer's ``raise_on_exception`` is
            stored on ``self._exception`` and re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            comm_buffer = self._comm_buffer
            try:
                comm_buffer.raise_on_exception()
            except Exception as err:
                self._exception = err
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush the disk buffer unless the run is closed."""
            self.raise_if_closed()
            buffer = self._disk_buffer
            await buffer.flush()
    
        async def close(self) -> None:
            """Close the comm buffer, then the disk buffer, and signal completion.

            If the run is already marked closed, only wait for ``_closed_event``.
            """
            if not self.closed:
                self.closed = True
                await self._comm_buffer.close()
                await self._disk_buffer.close()
                self._closed_event.set()
                return

            # Already closed (or closing): wait until that has finished
            await self._closed_event.wait()  # pragma: no cover
    
        def fail(self, exception: Exception) -> None:
            """Store ``exception`` on the run; ignored once the run is closed."""
            if self.closed:
                return
            self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Hand received shards to ``_receive`` and report the outcome.

            Returns ``{"status": "OK"}`` on success; a ``P2PConsistencyError`` is
            returned as an error message instead of being raised.
            """
            try:
                # Unpack opaque blob. See send()
                shards = (
                    cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                    if isinstance(data, bytes)
                    else data
                )
                await self._receive(shards)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition

            Abstract; implementations map the output partition id ``i`` to a
            worker address. ``_ensure_output_worker`` compares the result with
            ``self.local_address``.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run

            Abstract; called by ``receive`` with the shard list, after any bytes
            blob has been unpickled.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Attach the plugin to ``scheduler``.

            Adds the ``shuffle_*`` handlers to ``scheduler.handlers``, registers the
            plugin on the scheduler under the name ``"shuffle"`` and creates the
            empty bookkeeping containers.
            """
            self.scheduler = scheduler
            shuffle_handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler.handlers.update(shuffle_handlers)
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` as worker plugin.

            The registration goes through ``self.scheduler`` under the name
            ``"shuffle"`` with ``idempotent=False``; the ``scheduler`` argument
            itself is not used.
            """
            serialized_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, serialized_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the IDs of all currently active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Broadcast ``shuffle_inputs_done`` to the workers participating in a shuffle.

            Workers whose response is an ``OSError`` are retried after a short sleep
            until every worker has answered.

            Parameters
            ----------
            id
                ID of the shuffle; looked up in ``active_shuffles``.
            run_id
                Run the caller belongs to; must equal the active run's ``run_id``.
            consistent
                If false, a warning is logged and the shuffle is restarted instead
                of broadcasting.

            Raises
            ------
            ValueError
                If ``run_id`` does not match the active run.
            P2PIllegalStateError
                If a worker that still needs a retry is unknown to the scheduler
                while the shuffle is not archived.
            RuntimeError
                If a worker returned a failure other than ``OSError``, if such a
                worker left during an archived shuffle, or if the number of
                workers to retry did not shrink in three rounds.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Collect the workers to retry. A ``None`` response needs no retry,
                # an ``OSError`` is retried and anything else aborts the barrier, so
                # after this loop ``workers`` holds exactly the non-``None`` entries.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict the task ``key`` to ``worker`` on behalf of a shuffle run.

            The request is only honoured if ``run_id`` equals the ``run_id`` of
            the active shuffle ``id``. A mismatch in either direction is reported
            as a ``P2PConsistencyError`` converted with ``error_message``;
            otherwise ``{"status": "OK"}`` is returned.
            """
            try:
                shuffle = self.active_shuffles[id]
                active_run_id = shuffle.run_id
                if active_run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if active_run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
            except P2PConsistencyError as exc:
                return error_message(exc)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge the heartbeat payload of worker ``ws`` into ``self.heartbeats``.

            Parameters
            ----------
            ws
                Worker the heartbeat came from; its ``address`` keys the entry.
            data
                Mapping of shuffle ID to a dict of values for that shuffle.
                Entries for shuffles that are not active are ignored.
            """
            # Test membership on the dict directly instead of building a new set
            # of all active shuffle IDs for every entry of ``data``.
            active = self.active_shuffles
            for shuffle_id, d in data.items():
                if shuffle_id in active:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of the active shuffle ``id`` for ``worker``.

            On success the reply is ``{"status": "OK", "run_spec": ...}`` with the
            spec wrapped in ``ToPickle``. An unknown shuffle ID (``KeyError``) is
            turned into a ``P2PConsistencyError``; any ``P2PConsistencyError`` is
            converted with ``error_message`` and returned.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    reply = {"status": "OK", "run_spec": ToPickle(spec)}
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as exc:
                return error_message(exc)
            return reply
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Look up the active shuffle ``id`` and add ``worker`` as participant.

            Raises ``P2PConsistencyError`` if the scheduler does not know
            ``worker`` and ``KeyError`` if no shuffle with this ID is active.
            """
            known_workers = self.scheduler.workers
            if worker not in known_workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            shuffle_state = self.active_shuffles[id]
            shuffle_state.participating_workers.add(worker)
            return shuffle_state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:36599', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(50, 10), dtype=float64, chunksize=(25, 10), chunktype=numpy.ndarray>
chunks = (None, (5, 5))
ws = (<Worker 'tcp://127.0.0.1:44699', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:39473', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.12/site-packages/dask/dataframe/__init__.py'>

    @pytest.mark.parametrize(
        "x, chunks",
        [
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
            (da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
            (da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
        """
        pytest.importorskip("pandas")
        dd = pytest.importorskip("dask.dataframe")
        y = dd.from_array(x).values
        result = y.rechunk(chunks, method="p2p")
        expected = x.rechunk(chunks, method="p2p")
    
        assert_chunks_match(result.chunks, expected.chunks)
>       assert_eq(await c.compute(result), await c.compute(expected))

distributed/shuffle/tests/test_rechunk.py:757: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
    self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        # An OK message that additionally carries a shuffle run spec, either as
        # the spec itself or wrapped in ``ToPickle``.
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Set up the state of one shuffle run and create its two buffers.

            Parameters
            ----------
            id, run_id
                Identify the shuffle and this particular run of it.
            span_id
                Becomes part of the metric names built in ``_capture_metrics``.
            local_address
                Address compared against the assigned worker in
                ``_ensure_output_worker``.
            directory
                Passed to ``DiskShardsBuffer``; only used if ``disk`` is true.
            executor
                Executor used by ``offload``.
            rpc
                Called with a worker address in ``_send`` to reach that worker.
            digest_metric
                Callback receiving ``(name, value)`` for every captured metric.
            scheduler
                RPC handle used for ``shuffle_barrier`` and
                ``shuffle_restrict_task``.
            memory_limiter_disk, memory_limiter_comms
                Limiters handed to the disk and the comm buffer respectively.
            disk
                Selects ``DiskShardsBuffer`` (true) or ``MemoryShardsBuffer``.
            loop
                Loop on which ``add_partition`` and ``get_output_partition`` run
                their coroutines via ``sync``.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            # Retry settings used by ``send``; the delays are parsed with seconds
            # as the default unit.
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            """Return a debug representation with identity and lifecycle flags."""
            shown = ("id", "run_id", "local_address", "closed", "transferred")
            details = ", ".join(f"{attr}={getattr(self, attr)!r}" for attr in shown)
            return f"<{self.__class__.__name__}: {details}>"
    
        def __str__(self) -> str:
            """Return the short form ``<class name><<id>[<run_id>]> on <address>``."""
            cls_name = self.__class__.__name__
            return f"{cls_name}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            """Hash a run by its ``run_id``."""
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Record context_meter metrics emitted inside the ``with`` block as

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            **Note 1:** Metrics that are not logged by a background task
            (where='foreground') additionally show up under

                {('execute', <span id>, <task prefix>, label, unit): value}

            The duplication is by design, so that one can have a holistic view of
            the whole shuffle process.

            **Note 2:** Metrics are written to Worker.digests immediately. They are
            not stored temporarily under ShuffleRun, because those recorded between
            the heartbeat and the deletion of the ShuffleRun object at the end of a
            run would be lost.
            """
            prefix = "p2p-"

            def callback(label: Hashable, value: float, unit: str) -> None:
                parts = label if isinstance(label, tuple) else (label,)
                head = parts[0]
                if isinstance(head, str) and head.startswith(prefix):
                    # The metric name starts with "p2p" already; drop the prefix
                    parts = (head[len(prefix) :], *parts[1:])
                self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

            offloadable = "background" in where
            with context_meter.add_callback(callback, allow_offload=offloadable):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Report to the scheduler that the barrier of this run was reached.

            The run counts as consistent only if every entry of ``run_ids`` equals
            this run's ``run_id``; that flag is passed on to the scheduler's
            ``shuffle_barrier``. Returns this run's ``run_id``.
            """
            self.raise_if_closed()
            is_consistent = all(other == self.run_id for other in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id,
                run_id=self.run_id,
                consistent=is_consistent,
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Make a single attempt to deliver ``shards`` to the worker at ``address``.

            Calls ``shuffle_receive`` on that worker with the payload wrapped in
            ``to_serialize`` and returns the worker's reply. Raises if this run
            has been closed.
            """
            self.raise_if_closed()
            peer = self.rpc(address)
            return await peer.shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Deliver ``shards`` to the worker at ``address``, retrying on failure.

            If the mean shard size is below 65536 the whole list is pickled into a
            single bytes object first; ``receive`` unpickles it on the other side.
            Retries are governed by ``RETRY_COUNT``, ``RETRY_DELAY_MIN`` and
            ``RETRY_DELAY_MAX``.
            """
            # Don't send buffers individually over the tcp comms.
            # Instead, merge everything into an opaque bytes blob, send it all at once,
            # and unpickle it on the other side.
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
            shards_or_bytes: list | bytes = (
                pickle.dumps(shards) if _mean_shard_size(shards) < 65536 else shards
            )

            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)

            retry_options = {
                "count": self.RETRY_COUNT,
                "delay_min": self.RETRY_DELAY_MIN,
                "delay_max": self.RETRY_DELAY_MAX,
            }
            return await retry(_send, **retry_options)
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

            The call is metered under the label ``"offload"``. Raises if this run
            has been closed.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                outcome = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
            return outcome
    
        def heartbeat(self) -> dict[str, Any]:
            """Return a snapshot of this run with the keys ``disk``, ``comm`` and ``start``.

            ``comm`` is the comm buffer's heartbeat extended by a ``"read"`` entry
            holding ``total_recvd``; ``start`` is the run's start time.
            """
            comm_stats = self._comm_buffer.heartbeat()
            comm_stats["read"] = self.total_recvd
            disk_stats = self._disk_buffer.heartbeat()
            return {"disk": disk_stats, "comm": comm_stats, "start": self.start_time}
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Write ``data`` to the comm buffer; raises if the run is closed.

            ``add_partition`` passes the result of ``_shard_partition`` here.
            """
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            """Write ``data`` to the disk buffer; raises if the run is closed.

            Each N-dimensional index is turned into a string key by joining its
            components with underscores.
            """
            self.raise_if_closed()
            keyed = {"_".join(map(str, index)): shard for index, shard in data.items()}
            await self._disk_buffer.write(keyed)
    
        def raise_if_closed(self) -> None:
            """Raise if this run has been closed; do nothing otherwise.

            The stored ``_exception`` is re-raised if there is one, else a
            ``ShuffleClosedError``.
            """
            if not self.closed:
                return
            if self._exception:
                raise self._exception
            raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark the transfer as finished and flush the comm buffer.

            After this call ``add_partition`` refuses further input. An exception
            reported by the comm buffer is stored in ``_exception`` and re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                # Remember the failure so that raise_if_closed() can re-raise it
                # once the run has been closed.
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            """Flush the comm buffer; raises if the run is closed."""
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush the disk buffer; raises if the run is closed."""
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            """Close both buffers of this run; may be called more than once.

            The first caller marks the run as closed and closes the comm buffer
            followed by the disk buffer. Every later caller only waits until that
            first call has finished.

            The disk buffer is closed and the closed-event is set even if closing
            a buffer raises. Otherwise a failure would leave the second buffer
            open and later callers waiting forever. The exception still
            propagates to the first caller.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            self.closed = True
            try:
                try:
                    await self._comm_buffer.close()
                finally:
                    await self._disk_buffer.close()
            finally:
                # Always release concurrent closers, also on failure.
                self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            """Store ``exception`` as the failure of this run unless it is closed.

            ``raise_if_closed`` re-raises the stored exception once the run has
            been closed.
            """
            if self.closed:
                return
            self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            """Read the shards stored in the disk buffer under index ``id``.

            The buffer key is the index components joined with underscores, the
            same encoding ``_write_to_disk`` uses. Raises if the run is closed.
            """
            self.raise_if_closed()
            buffer_key = "_".join(map(str, id))
            return self._disk_buffer.read(buffer_key)
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Accept shards sent to this run.

            ``data`` is either a list of ``(partition id, shard)`` tuples or the
            pickled form of such a list as produced by ``send``, which is
            unpickled before being handed to ``_receive``.

            Returns ``{"status": "OK"}`` on success. A ``P2PConsistencyError``
            raised while receiving is converted with ``error_message`` and
            returned instead of propagating.
            """
            try:
                # Unpack opaque blob. See send()
                # NOTE(review): pickle.loads can execute arbitrary code; this
                # presumably only ever sees payloads from cluster peers - confirm.
                shards = (
                    cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                    if isinstance(data, bytes)
                    else data
                )
                await self._receive(shards)
            except P2PConsistencyError as exc:
                return error_message(exc)
            return {"status": "OK"}
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Make sure the task ``key`` runs on the worker assigned to partition ``i``.

            Returns normally when this worker is the assigned one. Otherwise the
            scheduler is asked to restrict the task to the assigned worker and
            ``Reschedule`` is raised.

            Raises
            ------
            RuntimeError
                If the scheduler answers with an error status.
            Reschedule
                After the restriction has been set successfully.
            """
            target = self._get_assigned_worker(i)
            if target == self.local_address:
                return

            response = await self.scheduler.shuffle_restrict_task(
                id=self.id, run_id=self.run_id, key=key, worker=target
            )
            status = response["status"]
            if status == "error":
                raise RuntimeError(response["message"])
            assert status == "OK"
            raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition

            Parameters
            ----------
            i
                ID of the output partition.

            Returns
            -------
            str
                Worker address; ``_ensure_output_worker`` compares it with
                ``self.local_address``.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run

            Parameters
            ----------
            data
                List of ``(partition id, shard)`` tuples. ``receive`` awaits this
                method after unpickling the payload if necessary.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            """Shard an input partition and write the shards to the comm buffer.

            The data is validated with ``validate_data``, split with
            ``_shard_partition`` and then written via ``_write_to_comm`` on the
            run's event loop, blocking until that has finished.

            Returns
            -------
            int
                The ``run_id`` of this run.

            Raises
            ------
            RuntimeError
                If ``inputs_done`` has already marked the transfer as finished.
            """
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            self.validate_data(data)
            # Log metrics both in the "execute" and in the "p2p" contexts
            with self._capture_metrics("foreground"):
                with context_meter.meter("p2p-shard-partition-noncpu"):
                    with context_meter.meter("p2p-shard-partition-cpu", func=thread_time):
                        shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers

            ``add_partition`` hands the returned mapping to ``_write_to_comm``.
            The keys are presumably worker addresses - confirm against the
            implementations.
            """
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Return the output partition ``partition_id`` for the task ``key``.

            First makes sure the task runs on the assigned worker (see
            ``_ensure_output_worker``, which raises ``Reschedule`` otherwise), then
            flushes the disk buffer and delegates to ``_get_output_partition``.

            Raises
            ------
            RuntimeError
                If the transfer has not been marked as finished yet.
            """
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            # Log metrics both in the "execute" and in the "p2p" contexts
            with self._capture_metrics("foreground"):
                with context_meter.meter("p2p-get-output-noncpu"):
                    with context_meter.meter("p2p-get-output-cpu", func=thread_time):
                        return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run

            Called by ``get_output_partition`` with the same arguments, after the
            disk buffer has been flushed.
            """
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk

            Passed as ``read`` callback to ``DiskShardsBuffer`` in ``__init__``.
            The second element of the result is presumably the number of bytes
            read - confirm against the implementations.
            """
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards

            Passed as ``deserialize`` callback to ``MemoryShardsBuffer`` in
            ``__init__``.
            """
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling

            Called by ``add_partition`` before sharding. This base implementation
            accepts everything.
            """
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        """Return the ``"shuffle"`` plugin of the worker running the current task.

        Raises
        ------
        RuntimeError
            If ``get_worker`` raises ``ValueError``, or if the worker has no
            plugin registered under ``"shuffle"``.
        """
        from distributed import get_worker

        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e

        try:
            plugin = worker.plugins["shuffle"]
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
        return plugin  # type: ignore
    
    
    _BARRIER_PREFIX = "shuffle-barrier-"


    def barrier_key(shuffle_id: ShuffleId) -> str:
        """Return the barrier key for ``shuffle_id``: the prefix followed by the ID."""
        return "".join((_BARRIER_PREFIX, shuffle_id))


    def id_from_key(key: Key) -> ShuffleId | None:
        """Extract the shuffle ID from a barrier key.

        Returns ``None`` if ``key`` is not a string or does not start with the
        barrier prefix; otherwise the part after the prefix as ``ShuffleId``.
        """
        if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
            return ShuffleId(key[len(_BARRIER_PREFIX) :])
        return None
    
    
    class ShuffleType(Enum):
        # The two kinds of P2P operation distinguished here.
        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        """Immutable description of a single run of a shuffle."""

        # Not an ``__init__`` argument: each new instance takes the next value of
        # one counter that starts at 1 and is created once, when the class body
        # is evaluated.
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        # Partition ID -> worker; ``ShuffleSpec.create_new_run`` uses the values
        # as the initial set of participating workers.
        worker_for: dict[_T_partition_id, str]
        span_id: str | None

        @property
        def id(self) -> ShuffleId:
            """ID of the shuffle, taken from ``spec``."""
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        """Abstract, immutable specification of a shuffle.

        Subclasses define the output partitions, how a worker is picked for a
        partition and how a run is created on a worker.
        """

        id: ShuffleId
        # ``ShuffleRun.__init__`` takes a ``disk`` flag that selects between a disk
        # and a memory buffer; presumably this is its source - confirm.
        disk: bool

        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""

        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""

        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            """Create the scheduler-side state for a new run of this shuffle.

            A new ``ShuffleRunSpec`` is built from this spec, ``worker_for`` and
            ``span_id``; the values of ``worker_for`` become the initial set of
            participating workers.
            """
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )

        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        """Scheduler-side state of one shuffle run.

        ``eq=False`` keeps the default identity comparison; instances hash by
        their ``run_id``.
        """

        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        # ``None`` until the run is archived; see the ``archived`` property.
        # Presumably holds the stimulus ID that archived the run - confirm.
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False

        @property
        def id(self) -> ShuffleId:
            """ID of the shuffle, taken from ``run_spec``."""
            return self.run_spec.id

        @property
        def run_id(self) -> int:
            """ID of this run, taken from ``run_spec``."""
            return self.run_spec.run_id

        @property
        def archived(self) -> bool:
            """Whether ``_archived_by`` has been set."""
            return self._archived_by is not None

        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"

        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P d9513d99181485ac4237b30c4b28d743 failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x3-chunks3] (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 871d68ad5c66483062239bd5af22c9b8 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Attach the plugin to ``scheduler`` and create its empty state.

            Installs the four ``shuffle_*`` handlers on the scheduler and adds
            this plugin to it under the name ``"shuffle"``.
            """
            self.scheduler = scheduler
            shuffle_handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler.handlers.update(shuffle_handlers)
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            # The plugin is added before the remaining containers exist; the
            # statement order is kept exactly as it was.
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a serialized ``ShuffleWorkerPlugin`` named ``"shuffle"``.

            The registration goes through ``self.scheduler``; the ``scheduler``
            argument itself is not used.
            """
            serialized_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, serialized_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set with the ids of all currently active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Handler for the ``shuffle_barrier`` operation of shuffle ``id``.

            If ``consistent`` is false the shuffle is restarted. Otherwise a
            ``shuffle_inputs_done`` message is broadcast to every participating
            worker and re-sent to the workers whose result was an ``OSError``
            until no such worker remains.

            Raises
            ------
            KeyError
                If ``id`` is not an active shuffle.
            ValueError
                If ``run_id`` differs from the run id of the active shuffle.
            P2PIllegalStateError
                If a worker to retry is unknown to the scheduler while the
                shuffle is not archived.
            RuntimeError
                If a broadcast result is neither ``None`` nor an ``OSError``, if
                a worker to retry is unknown to the scheduler, or if the number
                of workers to retry failed to shrink in three rounds.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Build the retry list in a single pass: ``None`` counts as
                # success, an ``OSError`` is retried, anything else aborts.
                # (A second comprehension used to recompute this same list.)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if not isinstance(r, OSError):
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                    workers.append(w)
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    # The counter is never reset, so the three stalled rounds
                    # do not have to be consecutive.
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` if ``run_id`` is current.

            Returns ``{"status": "OK"}`` on success. A run id that differs from
            the active run of shuffle ``id`` is reported as an error message
            built from a ``P2PConsistencyError``; other exceptions propagate.
            """
            try:
                shuffle = self.active_shuffles[id]
                active_run_id = shuffle.run_id
                if active_run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if active_run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                task_state = self.scheduler.tasks[key]
                self._set_restriction(task_state, worker)
            except P2PConsistencyError as exc:
                return error_message(exc)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge heartbeat ``data`` sent by worker ``ws`` into ``self.heartbeats``.

            ``data`` maps shuffle ids to update dicts; entries for shuffles
            that are not currently active are ignored.
            """
            # Test membership on the dict itself rather than calling
            # ``shuffle_ids()``, which copies all ids into a new set each time.
            active = self.active_shuffles
            for shuffle_id, d in data.items():
                if shuffle_id in active:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of active shuffle ``id`` for ``worker``.

            On success the reply has status ``"OK"`` and carries the run spec
            wrapped in ``ToPickle``. A missing shuffle (``KeyError``) is turned
            into a ``P2PConsistencyError``, and any ``P2PConsistencyError`` is
            returned as an error message instead of being raised.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    reply = {"status": "OK", "run_spec": ToPickle(spec)}
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
                return reply
            except P2PConsistencyError as exc:
                return error_message(exc)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Add ``worker`` to the participants of shuffle ``id``; return its run spec.

            Raises ``P2PConsistencyError`` if the scheduler does not know
            ``worker`` and ``KeyError`` if ``id`` is not an active shuffle.
            """
            known_workers = self.scheduler.workers
            if worker not in known_workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            active = self.active_shuffles[id]
            active.participating_workers.add(worker)
            return active.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            """Return the ``ShuffleSpec`` stored on the barrier task of ``shuffle_id``.

            The barrier task is looked up by ``barrier_key(shuffle_id)`` and
            its ``run_spec`` is asserted to be a ``P2PBarrierTask``.
            """
            barrier_task_state = self.scheduler.tasks[barrier_key(shuffle_id)]
            barrier_task = barrier_task_state.run_spec
            assert isinstance(barrier_task, P2PBarrierTask)
            return barrier_task.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            """Create a new run for ``shuffle_id``, mark it active and return its spec.

            ``worker`` is added to the participating workers of the new state,
            and the state is recorded in ``active_shuffles`` and ``_shuffles``.
            """
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            shuffle_spec = self._retrieve_spec(shuffle_id)
            assignment = self._calculate_worker_for(shuffle_spec)
            self._ensure_output_tasks_are_non_rootish(shuffle_spec)
            new_state = shuffle_spec.create_new_run(
                worker_for=assignment, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = new_state
            self._shuffles[shuffle_id].add(new_state)
            new_state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return new_state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Attach the plugin to ``scheduler`` and create its empty state.

            Installs the four ``shuffle_*`` handlers on the scheduler and adds
            this plugin to it under the name ``"shuffle"``.
            """
            self.scheduler = scheduler
            shuffle_handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler.handlers.update(shuffle_handlers)
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            # The plugin is added before the remaining containers exist; the
            # statement order is kept exactly as it was.
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a serialized ``ShuffleWorkerPlugin`` named ``"shuffle"``.

            The registration goes through ``self.scheduler``; the ``scheduler``
            argument itself is not used.
            """
            serialized_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, serialized_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set with the ids of all currently active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Handler for the ``shuffle_barrier`` operation of shuffle ``id``.

            If ``consistent`` is false the shuffle is restarted. Otherwise a
            ``shuffle_inputs_done`` message is broadcast to every participating
            worker and re-sent to the workers whose result was an ``OSError``
            until no such worker remains.

            Raises
            ------
            KeyError
                If ``id`` is not an active shuffle.
            ValueError
                If ``run_id`` differs from the run id of the active shuffle.
            P2PIllegalStateError
                If a worker to retry is unknown to the scheduler while the
                shuffle is not archived.
            RuntimeError
                If a broadcast result is neither ``None`` nor an ``OSError``, if
                a worker to retry is unknown to the scheduler, or if the number
                of workers to retry failed to shrink in three rounds.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Build the retry list in a single pass: ``None`` counts as
                # success, an ``OSError`` is retried, anything else aborts.
                # (A second comprehension used to recompute this same list.)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if not isinstance(r, OSError):
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                    workers.append(w)
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    # The counter is never reset, so the three stalled rounds
                    # do not have to be consecutive.
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` if ``run_id`` is current.

            Returns ``{"status": "OK"}`` on success. A run id that differs from
            the active run of shuffle ``id`` is reported as an error message
            built from a ``P2PConsistencyError``; other exceptions propagate.
            """
            try:
                shuffle = self.active_shuffles[id]
                active_run_id = shuffle.run_id
                if active_run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if active_run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                task_state = self.scheduler.tasks[key]
                self._set_restriction(task_state, worker)
            except P2PConsistencyError as exc:
                return error_message(exc)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge heartbeat ``data`` sent by worker ``ws`` into ``self.heartbeats``.

            ``data`` maps shuffle ids to update dicts; entries for shuffles
            that are not currently active are ignored.
            """
            # Test membership on the dict itself rather than calling
            # ``shuffle_ids()``, which copies all ids into a new set each time.
            active = self.active_shuffles
            for shuffle_id, d in data.items():
                if shuffle_id in active:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of active shuffle ``id`` for ``worker``.

            On success the reply has status ``"OK"`` and carries the run spec
            wrapped in ``ToPickle``. A missing shuffle (``KeyError``) is turned
            into a ``P2PConsistencyError``, and any ``P2PConsistencyError`` is
            returned as an error message instead of being raised.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    reply = {"status": "OK", "run_spec": ToPickle(spec)}
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
                return reply
            except P2PConsistencyError as exc:
                return error_message(exc)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: '871d68ad5c66483062239bd5af22c9b8'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        # Message type extending ``OKMessage`` with a ``run_spec`` entry that
        # holds a shuffle run spec, either bare or wrapped in ``ToPickle``.
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Store the run's identity and build its disk and comm buffers.

            With ``disk`` true the receive buffer is a ``DiskShardsBuffer``
            created with ``directory`` and ``memory_limiter_disk``; otherwise a
            ``MemoryShardsBuffer`` is used. The comm buffer and the retry
            settings are configured from ``distributed.p2p.comm.*``.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            config = dask.config.get

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    self._disk_buffer = (
                        DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                        if disk
                        else MemoryShardsBuffer(deserialize=self.deserialize)
                    )

                with self._capture_metrics("background-comms"):
                    message_limit = parse_bytes(
                        config("distributed.p2p.comm.message-bytes-limit")
                    )
                    max_concurrency = config("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=message_limit,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=max_concurrency,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            self.RETRY_COUNT = config("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                config("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                config("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            """Return a debug representation listing the key attributes."""
            details = ", ".join(
                (
                    f"id={self.id!r}",
                    f"run_id={self.run_id!r}",
                    f"local_address={self.local_address!r}",
                    f"closed={self.closed!r}",
                    f"transferred={self.transferred!r}",
                )
            )
            return f"<{self.__class__.__name__}: {details}>"
    
        def __str__(self) -> str:
            """Return ``ClassName<id[run_id]> on local_address``."""
            run_label = f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
            return f"{run_label} on {self.local_address}"
    
        def __hash__(self) -> int:
            """Hash the run by its ``run_id``."""
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Forward ``context_meter`` metrics to ``digest_metric`` while active.

            Each metric is recorded under the key

                ('p2p', <span id>, <where>, *label, unit)

            where a leading ``"p2p-"`` is stripped from the first label element.

            **Note 1:** Metrics not logged by a background task
            (where='foreground') are additionally reported under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This duplication is by design, to give a holistic view of the
            whole shuffle process.

            **Note 2:** Metrics are written to Worker.digests immediately.
            Storing them temporarily on the ShuffleRun would lose whatever is
            recorded between the last heartbeat and the deletion of the
            ShuffleRun object at the end of a run.
            """
            prefix = "p2p-"

            def callback(label: Hashable, value: float, unit: str) -> None:
                parts = label if isinstance(label, tuple) else (label,)
                head = parts[0]
                if isinstance(head, str) and head.startswith(prefix):
                    parts = (head[len(prefix) :], *parts[1:])
                self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

            offloadable = "background" in where
            with context_meter.add_callback(callback, allow_offload=offloadable):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Report the barrier to the scheduler and return this run's id.

            ``consistent`` is sent as true only when every entry of ``run_ids``
            equals ``self.run_id``.
            """
            self.raise_if_closed()
            all_match = not any(run_id != self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=all_match
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address`` via its ``shuffle_receive`` handler."""
            self.raise_if_closed()
            peer = self.rpc(address)
            return await peer.shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address``, retrying per the ``RETRY_*`` settings.

            Shards whose mean size is below 65536 are pickled into one bytes
            blob first; larger ones are passed on unchanged.
            """
            # Don't send buffers individually over the tcp comms.
            # Instead, merge everything into an opaque bytes blob, send it all at once,
            # and unpickle it on the other side.
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
            payload: list | bytes = (
                pickle.dumps(shards) if _mean_shard_size(shards) < 65536 else shards
            )

            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, payload)

            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

            The call is metered under the ``"offload"`` label.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                outcome = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
            return outcome
    
        def heartbeat(self) -> dict[str, Any]:
            """Return the disk and comm buffer heartbeats plus the start time.

            The comm heartbeat is extended with a ``"read"`` entry holding
            ``self.total_recvd``.
            """
            comm_stats = self._comm_buffer.heartbeat()
            comm_stats["read"] = self.total_recvd
            disk_stats = self._disk_buffer.heartbeat()
            return {"disk": disk_stats, "comm": comm_stats, "start": self.start_time}
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Write ``data`` to the comm buffer; fails if the run is closed."""
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            """Write ``data`` to the disk buffer under string keys.

            Each index tuple becomes its elements joined by ``"_"``.
            """
            self.raise_if_closed()
            keyed = {}
            for index, shards in data.items():
                keyed["_".join(map(str, index))] = shards
            await self._disk_buffer.write(keyed)
    
        def raise_if_closed(self) -> None:
            """Raise if the run is closed; do nothing otherwise.

            A stored exception takes precedence over the generic
            ``ShuffleClosedError``.
            """
            if not self.closed:
                return
            if self._exception:
                raise self._exception
            raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark the run as transferred and flush the comm buffer.

            An exception surfaced by the comm buffer afterwards is stored on
            the run and re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as comm_error:
                # Keep the error so that it can be raised again later.
                self._exception = comm_error
                raise
    
        async def _flush_comm(self) -> None:
            """Flush the comm buffer; fails if the run is closed."""
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush the disk buffer; fails if the run is closed."""
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            """Close the comm buffer, then the disk buffer, and signal completion.

            If the run is already marked closed, only wait for the closed event.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            self.closed = True
            for buffer in (self._comm_buffer, self._disk_buffer):
                await buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            """Store ``exception`` on the run unless the run is already closed."""
            if self.closed:
                return
            self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            """Read the disk buffer entry whose key is ``id`` joined by ``"_"``."""
            self.raise_if_closed()
            buffer_key = "_".join(map(str, id))
            return self._disk_buffer.read(buffer_key)
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Hand incoming shards to ``_receive`` and report the outcome.

            Returns ``{"status": "OK"}`` on success, or an error message if a
            ``P2PConsistencyError`` was raised.
            """
            try:
                # Unpack opaque blob. See send()
                # NOTE(review): ``pickle.loads`` runs on bytes received from
                # another process -- confirm the peers are trusted.
                shards = (
                    cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                    if isinstance(data, bytes)
                    else data
                )
                await self._receive(shards)
            except P2PConsistencyError as exc:
                return error_message(exc)
            return {"status": "OK"}
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Return normally only if partition ``i`` is assigned to this worker.

            Otherwise the scheduler is asked to restrict ``key`` to the
            assigned worker; an error reply raises ``RuntimeError`` and an OK
            reply raises ``Reschedule``.
            """
            assigned_worker = self._get_assigned_worker(i)
            if assigned_worker == self.local_address:
                return

            reply = await self.scheduler.shuffle_restrict_task(
                id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
            )
            if reply["status"] == "error":
                raise RuntimeError(reply["message"])
            assert reply["status"] == "OK"
            raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition

            ``i`` is the id of the output partition. The returned address is
            compared against ``local_address`` by ``_ensure_output_worker``.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run

            Called by ``receive`` with the already unpickled list of
            ``(partition id, shard)`` tuples.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…rker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Attach the plugin to ``scheduler``.

            Installs the ``shuffle_*`` handlers on the scheduler, registers this
            object as the scheduler plugin named ``"shuffle"`` and initializes the
            plugin's bookkeeping containers.
            """
            self.scheduler = scheduler
            # Expose the plugin's entry points as scheduler handlers.
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            # Two-level mapping; heartbeat() indexes it by shuffle id, then by
            # worker address.
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` named ``"shuffle"`` via the scheduler."""
            serialized_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, serialized_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the ids of all currently active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Handle the barrier of a shuffle run and notify participating workers.

            If ``consistent`` is false, the shuffle is restarted. Otherwise a
            ``shuffle_inputs_done`` message is broadcast to all participating
            workers. Workers whose broadcast returned an ``OSError`` are retried
            until every worker has answered or the broadcast stops making progress.

            Raises
            ------
            KeyError
                If there is no active shuffle with the given ``id``.
            ValueError
                If ``run_id`` does not match the run id of the active shuffle.
            P2PIllegalStateError
                If a worker that needs a retry is unknown to the scheduler while
                the shuffle has not been archived.
            RuntimeError
                If a broadcast returns anything other than ``None`` or an
                ``OSError``, if a worker that needs a retry has left the scheduler,
                or if three retry rounds did not reduce the number of failing
                workers.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Rebuild the retry list from this round's responses: a ``None``
                # response needs no retry, an ``OSError`` is retried and anything
                # else aborts the barrier. No second pass over ``res`` is needed
                # because every non-None response is either collected or raised.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    # ``no_progress`` is never reset, so it counts all stalled
                    # rounds of this barrier, not only consecutive ones.
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict the task ``key`` to ``worker`` for the given shuffle run.

            Returns an OK message on success. If ``run_id`` differs from the run
            id of the active shuffle, or a ``P2PConsistencyError`` is raised while
            applying the restriction, an error message is returned instead.
            """
            try:
                shuffle = self.active_shuffles[id]
                active_run_id = shuffle.run_id
                if active_run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if active_run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge heartbeat data sent by worker ``ws`` into ``self.heartbeats``.

            ``data`` maps shuffle ids to heartbeat dicts. Entries for shuffles
            that are not currently active are ignored.
            """
            # shuffle_ids() builds a new set on every call, so compute it once
            # instead of once per entry in ``data``.
            active_ids = self.shuffle_ids()
            for shuffle_id, d in data.items():
                if shuffle_id in active_ids:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of the active shuffle ``id`` for ``worker``.

            The run spec is wrapped in ``ToPickle`` inside an OK message. A
            missing shuffle (``KeyError``) is turned into a
            ``P2PConsistencyError``, and every ``P2PConsistencyError`` is returned
            as an error message rather than raised.
            """
            try:
                try:
                    message = {
                        "status": "OK",
                        "run_spec": ToPickle(self._get(id, worker)),
                    }
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
            return message
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Look up the active shuffle ``id`` and add ``worker`` as a participant.

            Raises ``KeyError`` if no such shuffle is active and
            ``P2PConsistencyError`` if the scheduler does not know ``worker``.
            """
            known_workers = self.scheduler.workers
            if worker not in known_workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            shuffle_state = self.active_shuffles[id]
            shuffle_state.participating_workers.add(worker)
            return shuffle_state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:33749', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(100, 10), dtype=float64, chunksize=(5, 10), chunktype=numpy.ndarray>
chunks = (None, 5)
ws = (<Worker 'tcp://127.0.0.1:34851', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:33885', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.12/site-packages/dask/dataframe/__init__.py'>

    @pytest.mark.parametrize(
        "x, chunks",
        [
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
            (da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
            (da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
        """
        pytest.importorskip("pandas")
        dd = pytest.importorskip("dask.dataframe")
        y = dd.from_array(x).values
        result = y.rechunk(chunks, method="p2p")
        expected = x.rechunk(chunks, method="p2p")
    
        assert_chunks_match(result.chunks, expected.chunks)
>       assert_eq(await c.compute(result), await c.compute(expected))

distributed/shuffle/tests/test_rechunk.py:757: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
    self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        """``OKMessage`` that additionally carries the run spec of a shuffle."""

        # The run spec itself, or the run spec wrapped in ``ToPickle``.
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Set up the state and the buffers of a single shuffle run.

            Parameters
            ----------
            id, run_id, span_id, local_address, executor, rpc, digest_metric, scheduler, loop
                Stored on the instance as-is.
            directory
                Passed to the ``DiskShardsBuffer``; only used when ``disk`` is true.
            memory_limiter_disk
                Memory limiter handed to the ``DiskShardsBuffer``.
            memory_limiter_comms
                Memory limiter handed to the ``CommShardsBuffer``.
            disk
                If true, received shards are buffered in a ``DiskShardsBuffer``,
                otherwise in a ``MemoryShardsBuffer``.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    # Both settings of the comm buffer come from the dask config.
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            # Retry settings used by send(); delays are parsed with seconds as
            # the default unit.
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            """Return a debug representation listing the run's identifying attributes."""
            shown = ("id", "run_id", "local_address", "closed", "transferred")
            details = ", ".join(f"{attr}={getattr(self, attr)!r}" for attr in shown)
            return f"<{self.__class__.__name__}: {details}>"
    
        def __str__(self) -> str:
            """Return ``ClassName<id[run_id]> on local_address``."""
            cls_name = self.__class__.__name__
            return f"{cls_name}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            """Hash a shuffle run by its ``run_id``."""
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Forward context_meter metrics to ``digest_metric`` under the key

                ('p2p', <span id>, 'foreground|background...', label, unit)

            **Note 1:** Metrics that are not logged by a background task
            (where='foreground') are, by design, duplicated under

                ('execute', <span id>, <task prefix>, label, unit)

            so that one can get a holistic view of the whole shuffle process.

            **Note 2:** Metrics are written to Worker.digests immediately.
            Storing them temporarily on the ShuffleRun would lose everything
            recorded between the last heartbeat and the deletion of the ShuffleRun
            object at the end of a run.
            """

            def callback(label: Hashable, value: float, unit: str) -> None:
                # Normalize the label to a tuple and drop a leading "p2p-" prefix
                # from its first element, as the key already starts with "p2p".
                parts = label if isinstance(label, tuple) else (label,)
                first = parts[0]
                if isinstance(first, str) and first.startswith("p2p-"):
                    parts = (first[len("p2p-") :], *parts[1:])

                self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

            offloadable = "background" in where
            with context_meter.add_callback(callback, allow_offload=offloadable):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Report the barrier of this run to the scheduler.

            The run is reported as consistent only if every entry of ``run_ids``
            equals this run's ``run_id``. Returns this run's ``run_id``.
            """
            self.raise_if_closed()
            mismatch = any(run_id != self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=not mismatch
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to the worker at ``address`` through ``shuffle_receive``."""
            self.raise_if_closed()
            peer = self.rpc(address)
            return await peer.shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address``, retrying as configured on the run.

            Shards with a mean size below 65536 are pickled into one bytes blob
            before sending; larger shards are passed on unchanged.
            """
            # Don't send small buffers individually over the tcp comms.
            # Instead, merge everything into an opaque bytes blob, send it all at once,
            # and unpickle it on the other side.
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
            merge_into_blob = _mean_shard_size(shards) < 65536
            shards_or_bytes: list | bytes = (
                pickle.dumps(shards) if merge_into_blob else shards
            )

            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)

            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on the run's executor, metered as "offload"."""
            self.raise_if_closed()
            with context_meter.meter("offload"):
                outcome = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
            return outcome
    
        def heartbeat(self) -> dict[str, Any]:
            """Collect the heartbeats of both buffers plus the start time of the run.

            The comm heartbeat is extended with a ``"read"`` entry holding
            ``total_recvd``.
            """
            comm_stats = self._comm_buffer.heartbeat()
            comm_stats["read"] = self.total_recvd
            disk_stats = self._disk_buffer.heartbeat()
            return dict(disk=disk_stats, comm=comm_stats, start=self.start_time)
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Write ``data`` to the comm buffer; raises first if the run is closed."""
            self.raise_if_closed()
            comm_buffer = self._comm_buffer
            await comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            """Write ``data`` to the disk buffer.

            Each index key is converted to a string by joining its elements with
            underscores, matching the key format used by ``_read_from_disk``.
            """
            self.raise_if_closed()
            keyed_by_name = {
                "_".join(map(str, index)): shard for index, shard in data.items()
            }
            await self._disk_buffer.write(keyed_by_name)
    
        def raise_if_closed(self) -> None:
            """Raise if this run has been closed.

            Re-raises the stored exception if there is one, otherwise raises
            ``ShuffleClosedError``. Does nothing while the run is open.
            """
            if not self.closed:
                return
            if self._exception:
                raise self._exception
            raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark the run as transferred and flush the comm buffer.

            An exception reported by the comm buffer after the flush is stored on
            the run and re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as exc:
                # Remember the failure so that raise_if_closed() can surface it.
                self._exception = exc
                raise
    
        async def _flush_comm(self) -> None:
            """Flush the comm buffer; raises first if this run is already closed."""
            self.raise_if_closed()
            comm_buffer = self._comm_buffer
            await comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush the disk buffer; raises first if this run is already closed."""
            self.raise_if_closed()
            disk_buffer = self._disk_buffer
            await disk_buffer.flush()
    
        async def close(self) -> None:
            """Close the comm buffer and then the disk buffer.

            The first caller performs the shutdown; any later caller only waits
            until that shutdown has signalled completion.
            """
            if not self.closed:
                # Flip the flag before the first await so that later callers
                # take the waiting branch below.
                self.closed = True
                for buffer in (self._comm_buffer, self._disk_buffer):
                    await buffer.close()
                self._closed_event.set()
            else:  # pragma: no cover
                await self._closed_event.wait()
    
        def fail(self, exception: Exception) -> None:
            """Record ``exception`` on this run unless the run is already closed."""
            if self.closed:
                return
            self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            """Read the shards stored in the disk buffer under the index ``id``.

            The buffer is keyed by the index elements joined with underscores,
            matching the key format written by ``_write_to_disk``.
            """
            self.raise_if_closed()
            buffer_key = "_".join(map(str, id))
            return self._disk_buffer.read(buffer_key)
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Hand received shards to ``_receive`` and report the outcome.

            Returns an OK message on success, or an error message when a
            ``P2PConsistencyError`` is raised while handling the data.
            """
            try:
                # A bytes payload is the opaque pickled blob produced by send().
                # NOTE(review): this unpickles data coming from a peer - presumably
                # trusted cluster members only; confirm.
                shards = (
                    cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                    if isinstance(data, bytes)
                    else data
                )
                await self._receive(shards)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Make sure the task ``key`` runs on the worker assigned to partition ``i``.

            Does nothing when the assigned worker is the local one. Otherwise the
            scheduler is asked to restrict the task to the assigned worker and
            ``Reschedule`` is raised; an error reply from the scheduler is raised
            as ``RuntimeError``.
            """
            target = self._get_assigned_worker(i)
            if target == self.local_address:
                return

            response = await self.scheduler.shuffle_restrict_task(
                id=self.id, run_id=self.run_id, key=key, worker=target
            )
            status = response["status"]
            if status == "error":
                raise RuntimeError(response["message"])
            assert status == "OK"
            raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition

            ``i`` is the id of the output partition. The returned address is
            compared against ``local_address`` by ``_ensure_output_worker``.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run

            Called by ``receive`` with the already unpickled list of
            ``(partition id, shard)`` tuples.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            """Shard an input partition and write the shards to the comm buffer.

            Returns this run's ``run_id``. Raises ``RuntimeError`` if the run has
            already been marked as transferred, and whatever ``raise_if_closed``
            raises if the run is closed.
            """
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            self.validate_data(data)
            # Log metrics both in the "execute" and in the "p2p" contexts
            with self._capture_metrics("foreground"):
                with context_meter.meter("p2p-shard-partition-noncpu"):
                    with context_meter.meter(
                        "p2p-shard-partition-cpu", func=thread_time
                    ):
                        sharded = self._shard_partition(data, partition_id)
                # Only the sharding is metered above; the write runs on the
                # run's event loop.
                sync(self._loop, self._write_to_comm, sharded)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers

            The returned mapping is handed to ``_write_to_comm`` by
            ``add_partition``.
            """
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Return the output partition ``partition_id`` for the task ``key``.

            First ensures that the task runs on the assigned worker, then flushes
            the receive buffer and delegates to ``_get_output_partition``. Raises
            ``RuntimeError`` if the run has not been marked as transferred yet.
            """
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            # Log metrics both in the "execute" and in the "p2p" contexts
            with self._capture_metrics("foreground"):
                with context_meter.meter("p2p-get-output-noncpu"):
                    with context_meter.meter("p2p-get-output-cpu", func=thread_time):
                        return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run

            Called by ``get_output_partition`` after the receive buffer has been
            flushed; ``kwargs`` are forwarded unchanged from that call.
            """
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk

            Passed as the ``read`` callback to ``DiskShardsBuffer`` in ``__init__``.
            NOTE(review): the second tuple element is presumably the number of
            bytes read - confirm against implementations.
            """
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards

            Passed as the ``deserialize`` callback to ``MemoryShardsBuffer`` in
            ``__init__``.
            """
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling

            Called by ``add_partition`` before the data is sharded. This base
            implementation accepts everything.
            """
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        """Return the ``"shuffle"`` plugin of the worker running the current task.

        Raises ``RuntimeError`` if the caller is not running on a worker or if the
        worker has no plugin registered under ``"shuffle"``.
        """
        from distributed import get_worker

        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e

        registered = worker.plugins
        try:
            plugin = registered["shuffle"]
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
        return plugin  # type: ignore
    
    
    # Prefix shared by the keys of all shuffle barrier tasks.
    _BARRIER_PREFIX = "shuffle-barrier-"


    def barrier_key(shuffle_id: ShuffleId) -> str:
        """Return the key of the barrier task belonging to ``shuffle_id``."""
        return _BARRIER_PREFIX + shuffle_id


    def id_from_key(key: Key) -> ShuffleId | None:
        """Extract the shuffle id from a barrier task key.

        Returns ``None`` if ``key`` is not a string starting with the barrier
        prefix.
        """
        if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
            return ShuffleId(key[len(_BARRIER_PREFIX) :])
        return None
    
    
    class ShuffleType(Enum):
        """The kinds of P2P shuffle operations."""

        # Shuffle of a dataframe
        DATAFRAME = "DataFrameShuffle"
        # Rechunk of an array
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        """Immutable description of a single run of a shuffle."""

        # Not an __init__ argument: every new instance draws the next value from
        # one counter that is shared by all instances and starts at 1.
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        # The shuffle specification this run belongs to
        spec: ShuffleSpec
        # Mapping from output partition id to a worker address
        worker_for: dict[_T_partition_id, str]
        span_id: str | None

        @property
        def id(self) -> ShuffleId:
            """Id of the shuffle, taken from ``spec``."""
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        """Abstract, immutable specification of a shuffle.

        Subclasses provide the output partitions, the worker assignment and the
        construction of the worker-side run.
        """

        id: ShuffleId
        # NOTE(review): presumably selects disk vs. in-memory buffering on the
        # worker (cf. the ``disk`` argument of ShuffleRun) - confirm.
        disk: bool

        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""

        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""

        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            """Create the scheduler-side state for a new run of this shuffle.

            The new state wraps a fresh ``ShuffleRunSpec`` and starts out with the
            workers named in ``worker_for`` as participating workers.
            """
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )

        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        """Scheduler-side state of a single shuffle run.

        ``eq=False`` keeps the default identity-based equality; hashing is based
        on the run id (see ``__hash__``).
        """

        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        # ``None`` until the state is archived (see ``archived``); not an
        # __init__ argument.
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False

        @property
        def id(self) -> ShuffleId:
            """Id of the shuffle, taken from ``run_spec``."""
            return self.run_spec.id

        @property
        def run_id(self) -> int:
            """Run id, taken from ``run_spec``."""
            return self.run_spec.run_id

        @property
        def archived(self) -> bool:
            """Whether ``_archived_by`` has been set."""
            return self._archived_by is not None

        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"

        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P 871d68ad5c66483062239bd5af22c9b8 failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x4-chunks4] (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 871d68ad5c66483062239bd5af22c9b8 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` as worker plugin ``"shuffle"``.

            The registration goes through ``self.scheduler``; the ``scheduler``
            argument is not used.
            """
            pickled_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, pickled_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Handler for ``shuffle_restrict_task``: restrict task ``key`` to ``worker``.

            The restriction is only applied when ``run_id`` equals the run id
            of the active shuffle; a mismatch in either direction is returned
            as an error message.
            """
            try:
                shuffle = self.active_shuffles[id]
                active_run_id = shuffle.run_id
                if active_run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if active_run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                task_state = self.scheduler.tasks[key]
                self._set_restriction(task_state, worker)
                return {"status": "OK"}
            except P2PConsistencyError as inconsistency:
                return error_message(inconsistency)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Handler for ``shuffle_get``: return the run spec of an active shuffle.

            A missing shuffle (``KeyError``) is converted into a
            ``P2PConsistencyError``; every ``P2PConsistencyError`` is returned
            as an error message instead of being raised.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    reply = {"status": "OK", "run_spec": ToPickle(spec)}
                    return reply
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as inconsistency:
                return error_message(inconsistency)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Return the run spec of active shuffle ``id`` and add ``worker`` to its participants.

            Raises ``P2PConsistencyError`` when the scheduler does not know
            ``worker`` and ``KeyError`` when ``id`` is not an active shuffle.
            """
            known_workers = self.scheduler.workers
            if worker not in known_workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            shuffle_state = self.active_shuffles[id]
            shuffle_state.participating_workers.add(worker)
            return shuffle_state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            """Return the shuffle spec stored on the barrier task of ``shuffle_id``."""
            barrier_ts = self.scheduler.tasks[barrier_key(shuffle_id)]
            run_spec = barrier_ts.run_spec
            # Narrow the type; the barrier task is expected to be a P2PBarrierTask.
            assert isinstance(run_spec, P2PBarrierTask)
            return run_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            """Create a new run for ``shuffle_id``, make it the active one and return its spec.

            ``key`` is the task that triggered the creation and ``worker`` the
            worker it runs on; that worker is recorded as a participant.
            """
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            shuffle_spec = self._retrieve_spec(shuffle_id)
            assignment = self._calculate_worker_for(shuffle_spec)
            self._ensure_output_tasks_are_non_rootish(shuffle_spec)
            span_id = self.scheduler.tasks[key].group.span_id
            new_state = shuffle_spec.create_new_run(
                worker_for=assignment, span_id=span_id
            )
            self.active_shuffles[shuffle_id] = new_state
            self._shuffles[shuffle_id].add(new_state)
            new_state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return new_state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` as worker plugin ``"shuffle"``.

            The registration goes through ``self.scheduler``; the ``scheduler``
            argument is not used.
            """
            pickled_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, pickled_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Handler for ``shuffle_restrict_task``: restrict task ``key`` to ``worker``.

            The restriction is only applied when ``run_id`` equals the run id
            of the active shuffle; a mismatch in either direction is returned
            as an error message.
            """
            try:
                shuffle = self.active_shuffles[id]
                active_run_id = shuffle.run_id
                if active_run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if active_run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                task_state = self.scheduler.tasks[key]
                self._set_restriction(task_state, worker)
                return {"status": "OK"}
            except P2PConsistencyError as inconsistency:
                return error_message(inconsistency)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Handler for ``shuffle_get``: return the run spec of an active shuffle.

            A missing shuffle (``KeyError``) is converted into a
            ``P2PConsistencyError``; every ``P2PConsistencyError`` is returned
            as an error message instead of being raised.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    reply = {"status": "OK", "run_spec": ToPickle(spec)}
                    return reply
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as inconsistency:
                return error_message(inconsistency)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: '871d68ad5c66483062239bd5af22c9b8'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        """``OKMessage`` that additionally carries the run spec of a shuffle."""

        # Either the spec itself or the spec wrapped in ``ToPickle``.
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Store the run's identity and collaborators and create its buffers.

            ``disk`` selects what is kept in ``_disk_buffer``: a
            ``DiskShardsBuffer`` using ``directory`` and ``memory_limiter_disk``
            when true, otherwise a ``MemoryShardsBuffer``. The limits of the
            comm buffer and the retry settings are read from ``dask.config``
            (``distributed.p2p.comm.*``).
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            # A new run starts open; ``close()`` sets this flag.
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            # Set by ``inputs_done()``; ``add_partition()`` refuses new input afterwards.
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            # Set at the end of ``close()`` so concurrent closers can wait for it.
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            # Retry settings used by ``send()``; delays are parsed with seconds as
            # the default unit.
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            """Hash a run by its ``run_id``."""
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Forward context_meter metrics recorded inside the ``with`` block as

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            **Note 1:** A metric that is not logged by a background task
            (where='foreground') is additionally reported under

                {('execute', <span id>, <task prefix>, label, unit): value}

            The duplication is intentional: it gives a holistic view of the whole
            shuffle process.

            **Note 2:** Metrics are written to Worker.digests immediately.
            They are not stored temporarily under ShuffleRun, as those recorded
            between the heartbeat and the deletion of the ShuffleRun object at the
            end of a run would be lost.
            """

            def callback(label: Hashable, value: float, unit: str) -> None:
                # Normalise the label to a tuple and drop a leading "p2p-" prefix.
                if not isinstance(label, tuple):
                    label = (label,)
                head = label[0]
                if isinstance(head, str) and head.startswith("p2p-"):
                    label = (head.removeprefix("p2p-"), *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)

                self.digest_metric(name, value)

            is_background = "background" in where
            with context_meter.add_callback(callback, allow_offload=is_background):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to the worker at ``address`` through ``shuffle_receive``."""
            self.raise_if_closed()
            peer = self.rpc(address)
            return await peer.shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address``, retrying according to the ``RETRY_*`` settings.

            Shards with a mean size below 65536 are pickled into one bytes blob
            before they are sent.
            """
            # Don't send buffers individually over the tcp comms.
            # Instead, merge everything into an opaque bytes blob, send it all at once,
            # and unpickle it on the other side.
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
            payload: list | bytes = (
                pickle.dumps(shards) if _mean_shard_size(shards) < 65536 else shards
            )

            def _attempt() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, payload)

            return await retry(
                _attempt,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor``, metered as ``"offload"``."""
            self.raise_if_closed()
            with context_meter.meter("offload"):
                result = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
                return result
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Write ``data`` to the comm buffer; raises if the run is closed."""
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            """Flush the comm buffer; raises if the run is closed."""
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush the disk buffer; raises if the run is closed."""
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Make sure output partition ``i`` is handled on its assigned worker.

            Returns normally when the assigned worker is this worker. Otherwise the
            scheduler is asked (``shuffle_restrict_task``) to restrict ``key`` to the
            assigned worker and the call ends by raising.

            Raises
            ------
            RuntimeError
                If the scheduler answers with status ``"error"``.
            Reschedule
                After the scheduler confirmed the restriction with status ``"OK"``.
            """
            assigned_worker = self._get_assigned_worker(i)
            if assigned_worker == self.local_address:
                return
    
            result = await self.scheduler.shuffle_restrict_task(
                id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
            )
            status = result["status"]
            if status == "error":
                raise RuntimeError(result["message"])
            assert status == "OK"
            raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition

            Parameters
            ----------
            i
                Id of the output partition.

            Returns
            -------
            str
                Worker address; ``_ensure_output_worker`` compares it against
                ``self.local_address``.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run

            Parameters
            ----------
            data
                List of ``(partition id, shard)`` tuples. ``receive`` awaits this hook
                (after unpickling ``bytes`` payloads) and converts a
                ``P2PConsistencyError`` raised here into an error message.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…_worker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Attach the plugin to ``scheduler``.

            Registers the ``shuffle_*`` handlers on the scheduler, adds this object
            as the scheduler plugin named ``"shuffle"`` and creates the plugin's
            bookkeeping containers.
            """
            self.scheduler = scheduler
            rpc_handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler.handlers.update(rpc_handlers)
            # Two-level mapping: shuffle id -> worker address -> heartbeat data
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a serialized ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

            The ``scheduler`` argument is not used; the registration goes through
            ``self.scheduler``.
            """
            serialized_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, serialized_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the ids of all active shuffles."""
            return set(self.active_shuffles.keys())
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Broadcast ``shuffle_inputs_done`` to all workers participating in a shuffle.

            Parameters
            ----------
            id
                Id of the active shuffle.
            run_id
                Run id the caller believes to be current.
            consistent
                If false, the shuffle is restarted instead of running the barrier.

            Raises
            ------
            KeyError
                If ``id`` is not an active shuffle.
            ValueError
                If ``run_id`` differs from the active shuffle's run id.
            P2PIllegalStateError
                If a worker that still has to be notified is unknown to the scheduler
                while the shuffle is not archived.
            RuntimeError
                If a broadcast result is an unexpected (non-``OSError``) error, if a
                worker left during the shuffle, or if the broadcast stalled three
                times.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Classify the results: ``None`` counts as done, an ``OSError`` puts
                # the worker on the retry list, anything else aborts the barrier.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    # NOTE(review): ``no_progress`` is never reset, so the three
                    # stalled rounds need not be consecutive -- confirm intended.
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict the task ``key`` to ``worker`` on behalf of a shuffle run.

            Returns
            -------
            ``{"status": "OK"}`` once the restriction was set, or ``error_message(e)``
            when ``run_id`` does not equal the active shuffle's run id.

            Raises
            ------
            KeyError
                If ``id`` is not an active shuffle or ``key`` is not a known task;
                only ``P2PConsistencyError`` is converted into an error message.
            """
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge heartbeat data sent by worker ``ws`` into ``self.heartbeats``.

            Parameters
            ----------
            ws
                State of the reporting worker; its ``address`` keys the stored data.
            data
                Mapping of shuffle id to a dict of heartbeat values. Entries for
                shuffles that are not active are ignored.
            """
            # Build the set of active ids once instead of once per entry in ``data``.
            active_ids = self.shuffle_ids()
            for shuffle_id, d in data.items():
                if shuffle_id in active_ids:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of the active shuffle ``id`` as a status message.

            A ``KeyError`` raised while building the response is re-raised as a
            ``P2PConsistencyError``; any ``P2PConsistencyError`` is then returned as
            ``error_message(e)`` instead of propagating.
            """
            try:
                try:
                    response: RunSpecMessage = {
                        "status": "OK",
                        "run_spec": ToPickle(self._get(id, worker)),
                    }
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
            return response
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Look up the run spec of shuffle ``id`` and record ``worker`` as participant.

            Raises
            ------
            P2PConsistencyError
                If ``worker`` is not among ``self.scheduler.workers``.
            KeyError
                If ``id`` is not an active shuffle.
            """
            known_workers = self.scheduler.workers
            if worker not in known_workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            shuffle_state = self.active_shuffles[id]
            shuffle_state.participating_workers.add(worker)
            return shuffle_state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:38033', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(100, 10), dtype=float64, chunksize=(5, 10), chunktype=numpy.ndarray>
chunks = {1: 5}
ws = (<Worker 'tcp://127.0.0.1:44575', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:36277', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.12/site-packages/dask/dataframe/__init__.py'>

    @pytest.mark.parametrize(
        "x, chunks",
        [
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
            (da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
            (da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
        """
        pytest.importorskip("pandas")
        dd = pytest.importorskip("dask.dataframe")
        y = dd.from_array(x).values
        result = y.rechunk(chunks, method="p2p")
        expected = x.rechunk(chunks, method="p2p")
    
        assert_chunks_match(result.chunks, expected.chunks)
>       assert_eq(await c.compute(result), await c.compute(expected))

distributed/shuffle/tests/test_rechunk.py:757: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
    self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    # Distinct static type for shuffle identifiers, built on ``str``.
    ShuffleId = NewType("ShuffleId", str)
    # Tuple-of-ints index; ShuffleRun joins its elements with "_" to form the keys
    # used by ``_write_to_disk`` and ``_read_from_disk``.
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    # Type variables for the partition id and the partition payload of a shuffle.
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    # Return type of the callable passed to ``ShuffleRun.offload``.
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        # An ``OKMessage`` that additionally carries a shuffle run spec, either
        # directly or wrapped in ``ToPickle``.
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        """A single run of a P2P shuffle as seen from one worker.
    
        The run owns two shard buffers: ``_comm_buffer`` (a ``CommShardsBuffer``
        that sends through :meth:`send`) and ``_disk_buffer`` (a
        ``DiskShardsBuffer`` or, when ``disk`` is false, a ``MemoryShardsBuffer``).
        Subclasses supply the abstract hooks ``_shard_partition``, ``_receive``,
        ``_get_assigned_worker``, ``_get_output_partition``, ``read`` and
        ``deserialize``.
        """
    
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        # Retry settings for :meth:`send`, read from the dask config in ``__init__``
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Set up the buffers and bookkeeping for this run.
    
            Parameters
            ----------
            id, run_id, span_id, local_address, executor, rpc, digest_metric, scheduler
                Stored unchanged on the instance under the same names.
            directory
                Passed to ``DiskShardsBuffer``; only used when ``disk`` is true.
            memory_limiter_disk
                Limiter handed to ``DiskShardsBuffer``; only used when ``disk`` is true.
            memory_limiter_comms
                Limiter handed to ``CommShardsBuffer``.
            disk
                If true, ``_disk_buffer`` is a ``DiskShardsBuffer``, otherwise a
                ``MemoryShardsBuffer``.
            loop
                Loop used by the synchronous entry points (``add_partition``,
                ``get_output_partition``) to run coroutines through ``sync``.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            """Hash by ``run_id``."""
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                # Normalise the label to a tuple and drop a leading "p2p-" prefix,
                # since "p2p" already is the first element of the metric name.
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Report to the scheduler that the barrier of this run was reached.
    
            ``consistent`` is sent as true only if every element of ``run_ids``
            equals ``self.run_id``. Returns ``self.run_id``.
            """
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Make one ``shuffle_receive`` RPC call to ``address`` with ``shards``."""
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address`` through ``retry``.
    
            When ``_mean_shard_size(shards)`` is below 65536 the shards are pickled
            into a single ``bytes`` blob first. ``retry`` is configured with
            ``RETRY_COUNT``, ``RETRY_DELAY_MIN`` and ``RETRY_DELAY_MAX``.
            """
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.
    
            The call is metered under the ``"offload"`` label.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            """Return the heartbeats of both buffers plus the start time.
    
            The comm heartbeat additionally gets a ``"read"`` entry set to
            ``self.total_recvd``.
            """
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Write ``data`` to the comm buffer."""
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            """Write ``data`` to the disk buffer.
    
            Each key is turned into a string by joining its elements with ``"_"``.
            """
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            """Raise if the run is closed; do nothing otherwise.
    
            Raises the stored ``_exception`` when there is one, else a
            ``ShuffleClosedError``.
            """
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark the run as transferred and flush the comm buffer.
    
            An exception reported by the comm buffer afterwards is stored in
            ``_exception`` and re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            """Flush the comm buffer."""
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush the disk buffer."""
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            """Close both buffers and set ``_closed_event``.
    
            A call made after the run was already marked closed only waits for
            ``_closed_event``.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            """Store ``exception`` on the run unless the run is already closed."""
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            """Read from the disk buffer using the ``"_"``-joined elements of ``id`` as key."""
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Pass received shards to ``_receive`` and report the outcome.
    
            Returns ``{"status": "OK"}`` on success and ``error_message(e)`` when a
            ``P2PConsistencyError`` is raised; other exceptions propagate.
            """
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    # NOTE(review): ``pickle.loads`` on caller-supplied bytes;
                    # presumably only trusted cluster peers reach this -- confirm.
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Return if partition ``i`` is assigned to this worker, else raise.
    
            For a different assigned worker the scheduler is asked to restrict
            ``key`` to it; a ``RuntimeError`` is raised on an ``"error"`` answer and
            ``Reschedule`` after an ``"OK"`` answer.
            """
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            """Shard an input partition and write the shards to the comm buffer.
    
            Synchronous entry point: the write is run on ``self._loop`` via ``sync``.
            Returns ``self.run_id``.
    
            Raises
            ------
            RuntimeError
                If the run is already marked as transferred.
            """
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Return the output partition ``partition_id``.
    
            Synchronous entry point: first runs ``_ensure_output_worker`` (which may
            raise ``Reschedule``), then flushes the disk buffer and delegates to
            ``_get_output_partition``.
    
            Raises
            ------
            RuntimeError
                If the run is not yet marked as transferred.
            """
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        """Return the ``"shuffle"`` plugin of the worker returned by ``get_worker()``.

        Raises
        ------
        RuntimeError
            If ``get_worker()`` raises ``ValueError`` (the task is not running on a
            worker), or if that worker has no plugin named ``"shuffle"``.
        """
        from distributed import get_worker
    
        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e
        try:
            plugin = worker.plugins["shuffle"]  # type: ignore
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
        return plugin
    
    
    # Key prefix added by ``barrier_key`` and stripped again by ``id_from_key``.
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        """Return the barrier key for ``shuffle_id``: the barrier prefix plus the id."""
        key = _BARRIER_PREFIX + shuffle_id
        return key
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        """Extract the shuffle id from a barrier key.

        Returns ``None`` when ``key`` is not a string starting with the barrier
        prefix; otherwise the remainder of the key after the prefix.
        """
        if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
            return ShuffleId(key[len(_BARRIER_PREFIX) :])
        return None
    
    
    class ShuffleType(Enum):
        """Kinds of P2P shuffle; each value is a descriptive string name."""

        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        """Immutable description of one run of a shuffle.

        ``run_id`` is not an ``__init__`` argument; each new instance draws the next
        value from a single counter that starts at 1 and is shared by all instances.
        """

        # The ``itertools.count`` object is created once, when the class is defined.
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        # Mapping of partition id to a worker (``str``)
        worker_for: dict[_T_partition_id, str]
        span_id: str | None
    
        @property
        def id(self) -> ShuffleId:
            """Id of the shuffle, taken from ``spec``."""
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        """Abstract, immutable specification of a shuffle.

        Holds the shuffle ``id`` and the ``disk`` flag; concrete subclasses
        implement the abstract members below.
        """

        id: ShuffleId
        disk: bool
    
        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""
    
        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""
    
        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            """Create the scheduler-side state for a new run of this shuffle.

            The participating workers start out as the set of values of
            ``worker_for``.
            """
            new_run_spec = ShuffleRunSpec(
                spec=self, worker_for=worker_for, span_id=span_id
            )
            participants = set(worker_for.values())
            return SchedulerShuffleState(
                run_spec=new_run_spec,
                participating_workers=participants,
            )
    
        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        """State of one shuffle run, wrapping its run spec and participating workers.

        Declared with ``eq=False``, so no ``__eq__`` is generated by the dataclass;
        hashing uses ``run_id``.
        """

        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        # ``None`` until the state is archived; see the ``archived`` property.
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False
    
        @property
        def id(self) -> ShuffleId:
            """Id of the shuffle, taken from ``run_spec``."""
            return self.run_spec.id
    
        @property
        def run_id(self) -> int:
            """Run id, taken from ``run_spec``."""
            return self.run_spec.run_id
    
        @property
        def archived(self) -> bool:
            """True once ``_archived_by`` has been set to a non-``None`` value."""
            return self._archived_by is not None
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"
    
        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P 871d68ad5c66483062239bd5af22c9b8 failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x5-chunks5] (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 871d68ad5c66483062239bd5af22c9b8 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Attach the plugin to ``scheduler``.

            Registers the four ``shuffle_*`` handlers on the scheduler, adds
            this object as the scheduler plugin named ``"shuffle"`` and creates
            the plugin's initially empty bookkeeping containers.
            """
            self.scheduler = scheduler
            rpc_handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler.handlers.update(rpc_handlers)
            # Two-level mapping; both levels are created on first access.
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a new ``ShuffleWorkerPlugin`` through ``self.scheduler``.

            The plugin is serialized with ``dumps`` and registered under the
            name ``"shuffle"`` with ``idempotent=False``. The ``scheduler``
            argument itself is not used here.
            """
            serialized_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, serialized_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the ids of all active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Handle the ``shuffle_barrier`` request for shuffle ``id``.

            When ``consistent`` is false the shuffle is restarted. Otherwise a
            ``shuffle_inputs_done`` message is broadcast to all participating
            workers and re-sent to those whose reply was an ``OSError`` until
            none is left.

            Raises
            ------
            KeyError
                If ``id`` is not an active shuffle.
            ValueError
                If ``run_id`` differs from the run id of the active shuffle.
            P2PIllegalStateError
                If a worker that has to be retried is unknown to the scheduler
                while the shuffle is not archived.
            RuntimeError
                If a worker replies with something other than ``None`` or an
                ``OSError``, if a worker to retry is no longer known to the
                scheduler, or once three retry rounds did not reduce the
                number of pending workers.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Collect the workers to retry: a ``None`` reply counts as
                # success, an ``OSError`` is retried, anything else aborts.
                # Every non-``None`` reply that survives this loop is an
                # ``OSError``, so the list needs no further filtering.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    # The counter is never reset, so the three stalled rounds
                    # need not be consecutive.
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` for the active run of ``id``.

            Returns ``{"status": "OK"}`` on success. If ``run_id`` does not
            match the active run, the resulting ``P2PConsistencyError`` is
            returned through ``error_message`` rather than raised.
            """
            try:
                shuffle = self.active_shuffles[id]
                active_run_id = shuffle.run_id
                if active_run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if active_run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                task_state = self.scheduler.tasks[key]
                self._set_restriction(task_state, worker)
                return {"status": "OK"}
            except P2PConsistencyError as err:
                return error_message(err)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge a worker's heartbeat payload into ``self.heartbeats``.

            ``data`` maps shuffle ids to dicts. Each dict is merged into
            ``self.heartbeats[shuffle_id][ws.address]``; entries whose id is
            not an active shuffle are ignored.
            """
            active = self.active_shuffles
            for shuffle_id, payload in data.items():
                # Plain dict membership: no need to build a fresh set of all
                # active ids for every entry of the payload.
                if shuffle_id in active:
                    self.heartbeats[shuffle_id][ws.address].update(payload)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of the active shuffle ``id`` as a message.

            On success the reply is ``{"status": "OK", "run_spec": ...}`` with
            the spec wrapped in ``ToPickle``. A ``KeyError`` is converted into
            a ``P2PConsistencyError``, and any ``P2PConsistencyError`` is
            returned through ``error_message`` instead of being raised.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    reply = {"status": "OK", "run_spec": ToPickle(spec)}
                    return reply
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as err:
                return error_message(err)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            """Return the ``spec`` stored on the barrier task of ``shuffle_id``.

            The barrier task is looked up in ``self.scheduler.tasks`` under
            ``barrier_key(shuffle_id)``; its ``run_spec`` is asserted to be a
            ``P2PBarrierTask``.
            """
            barrier_task = self.scheduler.tasks[barrier_key(shuffle_id)]
            run_spec = barrier_task.run_spec
            assert isinstance(run_spec, P2PBarrierTask)
            return run_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            """Start a new run of ``shuffle_id`` and make it the active one.

            The new state is stored in ``active_shuffles`` and ``_shuffles``,
            ``worker`` is added to its participating workers, and the run spec
            of the new state is returned.
            """
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            # The span id comes from the group of the task that triggered creation.
            span_id = self.scheduler.tasks[key].group.span_id
            new_state = spec.create_new_run(worker_for=worker_for, span_id=span_id)
            self.active_shuffles[shuffle_id] = new_state
            self._shuffles[shuffle_id].add(new_state)
            new_state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return new_state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Attach the plugin to ``scheduler``.

            Registers the four ``shuffle_*`` handlers on the scheduler, adds
            this object as the scheduler plugin named ``"shuffle"`` and creates
            the plugin's initially empty bookkeeping containers.
            """
            self.scheduler = scheduler
            rpc_handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler.handlers.update(rpc_handlers)
            # Two-level mapping; both levels are created on first access.
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a new ``ShuffleWorkerPlugin`` through ``self.scheduler``.

            The plugin is serialized with ``dumps`` and registered under the
            name ``"shuffle"`` with ``idempotent=False``. The ``scheduler``
            argument itself is not used here.
            """
            serialized_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, serialized_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the ids of all active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Handle the ``shuffle_barrier`` request for shuffle ``id``.

            When ``consistent`` is false the shuffle is restarted. Otherwise a
            ``shuffle_inputs_done`` message is broadcast to all participating
            workers and re-sent to those whose reply was an ``OSError`` until
            none is left.

            Raises
            ------
            KeyError
                If ``id`` is not an active shuffle.
            ValueError
                If ``run_id`` differs from the run id of the active shuffle.
            P2PIllegalStateError
                If a worker that has to be retried is unknown to the scheduler
                while the shuffle is not archived.
            RuntimeError
                If a worker replies with something other than ``None`` or an
                ``OSError``, if a worker to retry is no longer known to the
                scheduler, or once three retry rounds did not reduce the
                number of pending workers.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Collect the workers to retry: a ``None`` reply counts as
                # success, an ``OSError`` is retried, anything else aborts.
                # Every non-``None`` reply that survives this loop is an
                # ``OSError``, so the list needs no further filtering.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    # The counter is never reset, so the three stalled rounds
                    # need not be consecutive.
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` for the active run of ``id``.

            Returns ``{"status": "OK"}`` on success. If ``run_id`` does not
            match the active run, the resulting ``P2PConsistencyError`` is
            returned through ``error_message`` rather than raised.
            """
            try:
                shuffle = self.active_shuffles[id]
                active_run_id = shuffle.run_id
                if active_run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if active_run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                task_state = self.scheduler.tasks[key]
                self._set_restriction(task_state, worker)
                return {"status": "OK"}
            except P2PConsistencyError as err:
                return error_message(err)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge a worker's heartbeat payload into ``self.heartbeats``.

            ``data`` maps shuffle ids to dicts. Each dict is merged into
            ``self.heartbeats[shuffle_id][ws.address]``; entries whose id is
            not an active shuffle are ignored.
            """
            active = self.active_shuffles
            for shuffle_id, payload in data.items():
                # Plain dict membership: no need to build a fresh set of all
                # active ids for every entry of the payload.
                if shuffle_id in active:
                    self.heartbeats[shuffle_id][ws.address].update(payload)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of the active shuffle ``id`` as a message.

            On success the reply is ``{"status": "OK", "run_spec": ...}`` with
            the spec wrapped in ``ToPickle``. A ``KeyError`` is converted into
            a ``P2PConsistencyError``, and any ``P2PConsistencyError`` is
            returned through ``error_message`` instead of being raised.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    reply = {"status": "OK", "run_spec": ToPickle(spec)}
                    return reply
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as err:
                return error_message(err)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: '871d68ad5c66483062239bd5af22c9b8'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    # Distinct static type for shuffle identifiers, based on ``str``.
    ShuffleId = NewType("ShuffleId", str)
    # Alias for an index expressed as a tuple of ints of any length.
    NDIndex: TypeAlias = tuple[int, ...]


    # Type variables used to parametrize the generic classes of this module.
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        """``OKMessage`` that additionally carries a shuffle run spec."""

        # Either the bare run spec or the run spec wrapped in ``ToPickle``.
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Set up the state, buffers and retry settings of a shuffle run.

            Parameters
            ----------
            id, run_id, span_id, local_address, executor, rpc, digest_metric, scheduler
                Stored unchanged as attributes of the same name.
            directory
                Passed to ``DiskShardsBuffer``; only used when ``disk`` is true.
            memory_limiter_disk
                Passed to ``DiskShardsBuffer``; only used when ``disk`` is true.
            memory_limiter_comms
                Passed to ``CommShardsBuffer``.
            disk
                Selects ``DiskShardsBuffer`` (true) or ``MemoryShardsBuffer``
                (false) for ``self._disk_buffer``.
            loop
                Stored as ``self._loop``.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    # Both comm settings are read from the dask config.
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            # Progress and lifecycle bookkeeping, all starting out empty.
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            # Retry settings for sending shards, read from the dask config;
            # the delays are parsed with a default unit of seconds.
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            """Use the run id itself as the hash value."""
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Forward context_meter metrics to ``digest_metric`` under the key

                ('p2p', <span id>, <where>, *label, unit)

            A leading ``"p2p-"`` on the first label element is stripped, and a
            label that is not a tuple is treated as a one-element tuple.

            **Note 1:** For metrics that are not logged by a background task
            (where='foreground') this yields a duplicate of the metric under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This is intentional; it gives a holistic view of the whole shuffle
            process.

            **Note 2:** Metrics are written to Worker.digests immediately.
            Storing them on the ShuffleRun first would lose whatever is
            recorded between the heartbeat and the deletion of the ShuffleRun
            object at the end of a run.
            """

            def callback(label: Hashable, value: float, unit: str) -> None:
                parts = label if isinstance(label, tuple) else (label,)
                head = parts[0]
                if isinstance(head, str) and head.startswith("p2p-"):
                    parts = (head[len("p2p-") :], *parts[1:])
                self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

            # Offloading is only allowed for the background variants.
            offloadable = "background" in where
            with context_meter.add_callback(callback, allow_offload=offloadable):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Report to the scheduler whether all ``run_ids`` match this run.

            Returns this run's ``run_id`` once the scheduler's
            ``shuffle_barrier`` call has completed.
            """
            self.raise_if_closed()
            mismatch = any(other != self.run_id for other in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=not mismatch
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to the worker at ``address`` in a single attempt.

            Calls ``shuffle_receive`` on the peer with the serialized shards
            and this run's shuffle id and run id, and returns its reply.
            """
            self.raise_if_closed()
            peer = self.rpc(address)
            return await peer.shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address``, retrying via ``retry``.

            Shards whose mean size is below 65536 are pickled into one bytes
            blob first; larger ones are sent as they are.
            """
            # Don't send buffers individually over the tcp comms.
            # Instead, merge everything into an opaque bytes blob, send it all at once,
            # and unpickle it on the other side.
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
            is_small = _mean_shard_size(shards) < 65536
            shards_or_bytes: list | bytes = pickle.dumps(shards) if is_small else shards

            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)

            retry_settings = {
                "count": self.RETRY_COUNT,
                "delay_min": self.RETRY_DELAY_MIN,
                "delay_max": self.RETRY_DELAY_MAX,
            }
            return await retry(_send, **retry_settings)
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

            The call is metered under the ``"offload"`` label.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                outcome = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
            return outcome
    
        def heartbeat(self) -> dict[str, Any]:
            """Return the heartbeat payload of this run.

            The payload holds the disk buffer heartbeat (``"disk"``), the comm
            buffer heartbeat extended by ``"read": total_recvd`` (``"comm"``)
            and the start time (``"start"``).
            """
            comm_stats = self._comm_buffer.heartbeat()
            comm_stats["read"] = self.total_recvd
            disk_stats = self._disk_buffer.heartbeat()
            return {"disk": disk_stats, "comm": comm_stats, "start": self.start_time}
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Write ``data`` to the comm buffer; raises first if the run is closed."""
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            """Raise if this run has been closed; otherwise do nothing.

            Raises
            ------
            Exception
                The stored ``_exception`` if the run is closed and one is set.
            ShuffleClosedError
                If the run is closed and no exception is stored.
            """
            if not self.closed:
                return
            stored = self._exception
            if stored:
                raise stored
            raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark the transfer as finished and flush the comm buffer.

            An exception raised by the comm buffer's ``raise_on_exception`` is
            stored in ``self._exception`` and re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            comm_buffer = self._comm_buffer
            try:
                comm_buffer.raise_on_exception()
            except Exception as exc:
                self._exception = exc
                raise
    
        async def _flush_comm(self) -> None:
            """Flush the comm buffer; raises first if the run is closed."""
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush the disk buffer; raises first if the run is closed."""
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            """Close the comm and disk buffers of this run.

            If the run is already marked closed, this only waits for
            ``_closed_event`` and returns.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            # The flag is set before the first await, so a call that arrives
            # while the buffers are still closing takes the waiting branch above.
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            """Store ``exception`` on the run unless the run is already closed."""
            if self.closed:
                return
            self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Accept incoming shards and hand them to ``_receive``.

            ``data`` is either the list of shards or a pickled bytes blob of
            that list. Returns ``{"status": "OK"}`` on success, or the
            ``error_message`` of a ``P2PConsistencyError`` raised while
            processing.
            """
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    # NOTE(review): pickle.loads can execute arbitrary code from
                    # the payload; this presumes every sender is trusted -- confirm.
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition

            Abstract hook with no default behavior. ``_ensure_output_worker``
            compares the returned address against ``self.local_address``.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run

            Abstract hook with no default behavior. ``receive`` awaits it with
            the list of shards, unpickled first if they arrived as ``bytes``.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Register the shuffle handlers and this plugin on ``scheduler``.

            Also creates the empty bookkeeping containers used by the plugin.
            """
            self.scheduler = scheduler
            shuffle_handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler.handlers.update(shuffle_handlers)
            # Filled by ``heartbeat`` as heartbeats[shuffle_id][worker address]
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

            The registration goes through ``self.scheduler``; the ``scheduler``
            argument itself is not used here.
            """
            payload = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, payload, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of an active shuffle wrapped in an OK message.

            A ``KeyError`` from the lookup is re-raised as ``P2PConsistencyError``,
            and any ``P2PConsistencyError`` is returned as an error message
            instead of propagating.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(spec)}
                except KeyError as lookup_error:
                    message = f"No active shuffle with {id=!r} found"
                    raise P2PConsistencyError(message) from lookup_error
            except P2PConsistencyError as err:
                return error_message(err)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:40915', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(100, 10), dtype=float64, chunksize=(5, 10), chunktype=numpy.ndarray>
chunks = (None, (5, 5))
ws = (<Worker 'tcp://127.0.0.1:45271', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:32875', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.12/site-packages/dask/dataframe/__init__.py'>

    @pytest.mark.parametrize(
        "x, chunks",
        [
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
            (da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
            (da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
        """
        pytest.importorskip("pandas")
        dd = pytest.importorskip("dask.dataframe")
        y = dd.from_array(x).values
        result = y.rechunk(chunks, method="p2p")
        expected = x.rechunk(chunks, method="p2p")
    
        assert_chunks_match(result.chunks, expected.chunks)
>       assert_eq(await c.compute(result), await c.compute(expected))

distributed/shuffle/tests/test_rechunk.py:757: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
    self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        """``OKMessage`` that additionally carries a shuffle run spec."""

        # The spec itself, or the spec wrapped in ``ToPickle``
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Store the run's identity and collaborators and create its buffers.

            Parameters
            ----------
            id, run_id, span_id, local_address, executor, rpc, digest_metric, scheduler, loop
                Stored unchanged on the instance.
            directory
                Passed to ``DiskShardsBuffer``; only used when ``disk`` is true.
            memory_limiter_disk
                Passed to ``DiskShardsBuffer``; only used when ``disk`` is true.
            memory_limiter_comms
                Passed to ``CommShardsBuffer``.
            disk
                Selects ``DiskShardsBuffer`` (true) or ``MemoryShardsBuffer``
                (false) as the receiving buffer.

            The comm message size limit, comm concurrency and retry settings are
            read from the ``distributed.p2p.comm.*`` dask config keys.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            # Set to True by ``inputs_done``; ``add_partition`` refuses new input afterwards
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            # Retry policy applied by ``send``
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Forward ``context_meter`` metrics to ``digest_metric`` while active.

            Each metric is recorded under the name

                ('p2p', <span id>, where, *label, unit)

            where a leading ``"p2p-"`` on the first label element is stripped.

            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This is by design so that one can have a holistic view of the whole
            shuffle process.

            **Note 2:** Metrics are written to Worker.digests immediately rather
            than stored on the ShuffleRun, as those recorded between the heartbeat
            and the deletion of the ShuffleRun at the end of a run would be lost.
            """
            prefix = "p2p-"

            def callback(label: Hashable, value: float, unit: str) -> None:
                parts = label if isinstance(label, tuple) else (label,)
                head = parts[0]
                if isinstance(head, str) and head.startswith(prefix):
                    parts = (head[len(prefix) :], *parts[1:])
                self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

            offloadable = "background" in where
            with context_meter.add_callback(callback, allow_offload=offloadable):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Call ``shuffle_receive`` on the peer at ``address`` with ``shards``.

            The shards are wrapped with ``to_serialize`` and sent together with
            this run's ``id`` and ``run_id``. Fails first if the run is closed.
            """
            self.raise_if_closed()
            peer = self.rpc(address)
            return await peer.shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address`` via ``_send``, retrying per the run's policy.

            The retry count and delays come from ``RETRY_COUNT``,
            ``RETRY_DELAY_MIN`` and ``RETRY_DELAY_MAX``.
            """
            # Below the size threshold, don't send buffers individually over the
            # tcp comms. Instead, merge everything into an opaque bytes blob, send
            # it all at once, and unpickle it on the other side.
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
            payload: list | bytes = (
                pickle.dumps(shards) if _mean_shard_size(shards) < 65536 else shards
            )

            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, payload)

            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on the run's executor and return its result.

            The call is metered under the ``"offload"`` label. Fails first if the
            run is closed.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                outcome = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
            return outcome
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition

            Abstract hook with no default behavior. ``_ensure_output_worker``
            compares the returned address against ``self.local_address``.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run

            Abstract hook with no default behavior. ``receive`` awaits it with
            the list of shards, unpickled first if they arrived as ``bytes``.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            """Shard an input partition and hand the shards to the comm buffer.

            ``data`` is validated with ``validate_data``, split with
            ``_shard_partition`` and the result is passed to ``_write_to_comm``.

            Returns
            -------
            ``self.run_id``

            Raises
            ------
            RuntimeError
                If ``self.transferred`` is already set.

            Also raises whatever ``raise_if_closed`` raises for a closed run.
            """
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                # NOTE(review): presumably ``sync`` blocks until the coroutine has
                # run on ``self._loop`` -- confirm against distributed.utils.sync
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers

            Abstract hook with no default behavior. ``add_partition`` passes the
            returned mapping on to ``_write_to_comm``.
            """
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Return an output partition after the transfer phase has finished.

            First runs ``_ensure_output_worker``, which raises ``Reschedule``
            when the partition is assigned to a different worker. Then flushes the
            receiving buffer and delegates to ``_get_output_partition``.

            Raises
            ------
            RuntimeError
                If ``self.transferred`` is not set yet.

            Also raises whatever ``raise_if_closed`` raises for a closed run.
            """
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run

            Abstract hook with no default behavior. ``get_output_partition``
            calls it after ``flush_receive`` and returns its result.
            """
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk

            Abstract hook with no default behavior. ``__init__`` passes it to
            ``DiskShardsBuffer`` as its ``read`` callable.
            """
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards

            Abstract hook with no default behavior. ``__init__`` passes it to
            ``MemoryShardsBuffer`` as its ``deserialize`` callable.
            """
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling

            ``add_partition`` calls this before sharding. This default
            implementation accepts everything and does nothing.
            """
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        """Return the ``"shuffle"`` plugin of the worker running the current task.

        Raises ``RuntimeError`` when ``get_worker()`` fails with ``ValueError``
        or when the worker has no plugin registered as ``"shuffle"``.
        """
        from distributed import get_worker

        try:
            current_worker = get_worker()
        except ValueError as err:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from err

        try:
            plugin = current_worker.plugins["shuffle"]
        except KeyError as err:
            raise RuntimeError(
                f"The worker {current_worker.address} does not have a P2P shuffle plugin."
            ) from err
        return plugin  # type: ignore
    
    
    # Prefix shared by the keys of all shuffle barrier tasks
    _BARRIER_PREFIX = "shuffle-barrier-"


    def barrier_key(shuffle_id: ShuffleId) -> str:
        """Return the barrier task key for ``shuffle_id``."""
        key = _BARRIER_PREFIX + shuffle_id
        return key


    def id_from_key(key: Key) -> ShuffleId | None:
        """Return the shuffle id encoded in a barrier task key.

        Returns ``None`` when ``key`` is not a string starting with the
        barrier prefix.
        """
        if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
            return ShuffleId(key[len(_BARRIER_PREFIX) :])
        return None
    
    
    class ShuffleType(Enum):
        """Kinds of P2P operation, with a string name as each member's value."""

        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        """Immutable description of one run of a shuffle."""

        # Not an ``__init__`` argument: every new instance takes the next value
        # from a single counter that starts at 1 and is shared by all instances.
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        # The shuffle this is a run of
        spec: ShuffleSpec
        # Maps each output partition id to a worker address
        worker_for: dict[_T_partition_id, str]
        span_id: str | None

        @property
        def id(self) -> ShuffleId:
            """The id of the underlying ``spec``."""
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        """Abstract, immutable description of a shuffle.

        Subclasses provide the output partitions, the worker assignment and the
        construction of the worker-side run.
        """

        id: ShuffleId
        # Forwarded to the worker-side run; NOTE(review): presumably selects
        # disk vs. in-memory buffering as in ``ShuffleRun.__init__`` -- confirm
        disk: bool

        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""

        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""

        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            """Create the scheduler-side state for a new run of this shuffle.

            The new state wraps a fresh ``ShuffleRunSpec`` and starts with the
            workers named in ``worker_for`` as its participating workers.
            """
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )

        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        """Scheduler-side bookkeeping for one shuffle run.

        ``eq=False`` keeps identity-based equality; the hash is derived from
        the run id.
        """

        run_spec: ShuffleRunSpec
        # Addresses of the workers taking part in this run
        participating_workers: set[str]
        # ``None`` until the run is archived; see ``archived``
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False

        @property
        def id(self) -> ShuffleId:
            """The shuffle id of the wrapped run spec."""
            return self.run_spec.id

        @property
        def run_id(self) -> int:
            """The run id of the wrapped run spec."""
            return self.run_spec.run_id

        @property
        def archived(self) -> bool:
            """Whether ``_archived_by`` has been set."""
            return self._archived_by is not None

        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"

        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P 871d68ad5c66483062239bd5af22c9b8 failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x6-chunks6] (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 01992e7a5ca3de072c0c2984a1548c6e failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Attach the plugin to ``scheduler``.

            Installs the ``shuffle_*`` handlers, registers this object as the
            scheduler plugin named ``"shuffle"`` and creates the bookkeeping
            containers.
            """
            self.scheduler = scheduler
            shuffle_handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler.handlers.update(shuffle_handlers)
            # Missing shuffle ids yield a nested defaultdict whose values are dicts.
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            # Registration happens before the remaining attributes exist; the
            # original statement order is preserved.
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

            The ``scheduler`` argument is not used; registration goes through
            ``self.scheduler``.
            """
            pickled_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, pickled_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the active run spec for ``id`` as a message.

            On success the spec is wrapped in ``ToPickle``. A ``KeyError`` from
            the lookup is re-raised as ``P2PConsistencyError``, and every
            ``P2PConsistencyError`` is returned as an error message.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    reply = {"status": "OK", "run_spec": ToPickle(spec)}
                    return reply
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            """Return the ``spec`` stored on the barrier task of ``shuffle_id``.

            The barrier task is looked up in ``scheduler.tasks`` via
            ``barrier_key(shuffle_id)``; its ``run_spec`` is asserted to be a
            ``P2PBarrierTask``.
            """
            barrier_task = self.scheduler.tasks[barrier_key(shuffle_id)]
            barrier_run_spec = barrier_task.run_spec
            assert isinstance(barrier_run_spec, P2PBarrierTask)
            return barrier_run_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            """Create a new run for ``shuffle_id`` and make it the active one.

            The new state is stored in ``active_shuffles`` and ``_shuffles``,
            ``worker`` is added to its participating workers, and the run spec
            is returned.
            """
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            shuffle_spec = self._retrieve_spec(shuffle_id)
            assignment = self._calculate_worker_for(shuffle_spec)
            self._ensure_output_tasks_are_non_rootish(shuffle_spec)
            new_state = shuffle_spec.create_new_run(
                worker_for=assignment,
                span_id=self.scheduler.tasks[key].group.span_id,
            )
            self.active_shuffles[shuffle_id] = new_state
            self._shuffles[shuffle_id].add(new_state)
            new_state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return new_state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Attach the plugin to ``scheduler``.

            Installs the ``shuffle_*`` handlers, registers this object as the
            scheduler plugin named ``"shuffle"`` and creates the bookkeeping
            containers.
            """
            self.scheduler = scheduler
            shuffle_handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler.handlers.update(shuffle_handlers)
            # Missing shuffle ids yield a nested defaultdict whose values are dicts.
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            # Registration happens before the remaining attributes exist; the
            # original statement order is preserved.
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

            The ``scheduler`` argument is not used; registration goes through
            ``self.scheduler``.
            """
            pickled_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, pickled_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the active run spec for ``id`` as a message.

            On success the spec is wrapped in ``ToPickle``. A ``KeyError`` from
            the lookup is re-raised as ``P2PConsistencyError``, and every
            ``P2PConsistencyError`` is returned as an error message.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    reply = {"status": "OK", "run_spec": ToPickle(spec)}
                    return reply
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: '01992e7a5ca3de072c0c2984a1548c6e'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        # ``OKMessage`` extended with the run spec, carried either as the plain
        # object or wrapped in ``ToPickle``.
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Set up the run's state, its two shard buffers and retry settings.

            Parameters
            ----------
            id, run_id, span_id, local_address, executor, rpc, digest_metric, scheduler
                Stored unchanged on the instance under the same names.
            directory
                Passed to ``DiskShardsBuffer``; only used when ``disk`` is true.
            memory_limiter_disk
                Limiter handed to the disk buffer (only when ``disk`` is true).
            memory_limiter_comms
                Limiter handed to the comm buffer.
            disk
                Selects ``DiskShardsBuffer`` (true) or ``MemoryShardsBuffer``
                (false) for ``_disk_buffer``.
            loop
                Stored as ``_loop``.

            The comm message size limit, comm concurrency and retry
            count/delays are read from ``dask.config`` under
            ``distributed.p2p.comm``.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            # Retry settings consumed by ``send``.
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            """Hash the run by its ``run_id``."""
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Forward context_meter metrics to ``digest_metric`` while active.

            Metrics are reported under

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            A leading ``"p2p-"`` on the first label element is stripped.

            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This is by design so that one can have a holistic view of the whole
            shuffle process.

            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose
            those recorded between the heartbeat and when the ShuffleRun object is
            deleted at the end of a run.
            """

            def callback(label: Hashable, value: float, unit: str) -> None:
                parts = label if isinstance(label, tuple) else (label,)
                head = parts[0]
                if isinstance(head, str) and head.startswith("p2p-"):
                    parts = (head.removeprefix("p2p-"), *parts[1:])
                self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

            offloadable = "background" in where
            with context_meter.add_callback(callback, allow_offload=offloadable):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address`` via ``shuffle_receive`` (one attempt)."""
            self.raise_if_closed()
            peer = self.rpc(address)
            return await peer.shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address``, retrying per the ``RETRY_*`` settings.

            Shards whose mean size is below 65536 are pickled into one bytes
            blob first; larger ones are sent as the original list.
            """
            # Don't send small buffers individually over the tcp comms.
            # Instead, merge everything into an opaque bytes blob, send it all at once,
            # and unpickle it on the other side.
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
            merge_into_blob = _mean_shard_size(shards) < 65536
            shards_or_bytes: list | bytes = (
                pickle.dumps(shards) if merge_into_blob else shards
            )

            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)

            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

            The call is timed under the ``"offload"`` meter.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                pending = run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
                return await pending
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Write ``data`` to the comm buffer; raises first if the run is closed."""
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            """Flush the comm buffer; raises first if the run is closed."""
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush the disk buffer; raises first if the run is closed."""
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Make sure output partition ``i`` is handled by its assigned worker.

            Does nothing when the assigned worker is this worker. Otherwise asks
            the scheduler to restrict task ``key`` to the assigned worker and then
            raises ``Reschedule``.

            Raises
            ------
            RuntimeError
                If the scheduler answers with status ``"error"``.
            Reschedule
                After the scheduler accepted the restriction.
            """
            target = self._get_assigned_worker(i)
            if target == self.local_address:
                return

            response = await self.scheduler.shuffle_restrict_task(
                id=self.id, run_id=self.run_id, key=key, worker=target
            )
            status = response["status"]
            if status == "error":
                raise RuntimeError(response["message"])
            assert status == "OK"
            raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Return the address of the worker assigned to output partition ``i``.

            Abstract; the result is compared against ``local_address`` by
            ``_ensure_output_worker``.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run.

            Abstract; called by ``receive`` with the already unpickled shards.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…rker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Attach the plugin to ``scheduler``.

            Registers the ``shuffle_*`` handlers on the scheduler, initializes the
            bookkeeping containers and adds this plugin to the scheduler under
            the name ``"shuffle"``.
            """
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            # shuffle id -> worker address -> heartbeat data (filled by heartbeat())
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a ``ShuffleWorkerPlugin`` through the scheduler.

            The plugin is pickled with ``dumps`` and registered under the name
            ``"shuffle"``. The ``scheduler`` argument is not used; the scheduler
            stored on ``self`` is.
            """
            worker_plugin = ShuffleWorkerPlugin()
            # NOTE(review): the meaning of the leading ``None`` argument is not
            # visible here -- confirm against Scheduler.register_worker_plugin.
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return the ids of all active shuffles as a new set."""
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Handle the barrier of the active shuffle ``id``.

            If ``consistent`` is false, the shuffle is restarted. Otherwise a
            ``shuffle_inputs_done`` message is broadcast to all participating
            workers. Workers for which the broadcast returned an ``OSError`` are
            retried after a short sleep until none are left.

            Raises
            ------
            ValueError
                If ``run_id`` does not match the active run of the shuffle.
            P2PIllegalStateError
                If a worker to retry is not known to the scheduler while the
                shuffle is not archived.
            RuntimeError
                If a worker returned an unexpected (non-``OSError``) result, if a
                worker to retry has left, or if the number of failing workers
                did not shrink in three retry rounds.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Collect the workers to retry: only OSError results are
                # retried, any other non-None result aborts the barrier.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` on behalf of shuffle ``id``.

            The request is only honoured if ``run_id`` equals the run id of the
            active shuffle. A mismatch in either direction is reported as an
            error message rather than raised.
            """
            try:
                shuffle = self.active_shuffles[id]
                active_run_id = shuffle.run_id
                if active_run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if active_run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
            except P2PConsistencyError as err:
                return error_message(err)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Store heartbeat data reported by worker ``ws``.

            ``data`` is iterated as ``shuffle id -> payload``. The payload is
            merged into the entry kept for that shuffle and the worker's address.
            Entries for shuffles that are not active are ignored.
            """
            # Check membership on the dict itself instead of calling
            # shuffle_ids(), which would build a fresh set for every entry.
            active = self.active_shuffles
            for shuffle_id, d in data.items():
                if shuffle_id in active:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of active shuffle ``id`` for ``worker``.

            On success the spec is wrapped in ``ToPickle`` inside an OK message.
            A missing shuffle (``KeyError``) is converted into a
            ``P2PConsistencyError``; every ``P2PConsistencyError`` is returned as
            an error message instead of being raised.
            """
            try:
                try:
                    spec = self._get(id, worker)
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
                return {"status": "OK", "run_spec": ToPickle(spec)}
            except P2PConsistencyError as err:
                return error_message(err)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Return the run spec of active shuffle ``id`` and register ``worker``.

            ``worker`` is added to the participating workers of the shuffle.

            Raises
            ------
            P2PConsistencyError
                If ``worker`` is not known to the scheduler.
            KeyError
                If there is no active shuffle with this id.
            """
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:43175', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(10, 10), dtype=float64, chunksize=(10, 10), chunktype=numpy.ndarray>
chunks = (None, 5)
ws = (<Worker 'tcp://127.0.0.1:36981', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:42401', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.12/site-packages/dask/dataframe/__init__.py'>

    @pytest.mark.parametrize(
        "x, chunks",
        [
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
            (da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
            (da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
        """
        pytest.importorskip("pandas")
        dd = pytest.importorskip("dask.dataframe")
        y = dd.from_array(x).values
        result = y.rechunk(chunks, method="p2p")
        expected = x.rechunk(chunks, method="p2p")
    
        assert_chunks_match(result.chunks, expected.chunks)
>       assert_eq(await c.compute(result), await c.compute(expected))

distributed/shuffle/tests/test_rechunk.py:757: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
    self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        """``OKMessage`` that additionally carries a shuffle run spec."""

        # Either the spec itself or the spec wrapped in ToPickle.
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Set up the state and the buffers of a shuffle run.

            Parameters
            ----------
            id, run_id, span_id, local_address, executor, rpc, digest_metric, scheduler
                Stored unchanged as attributes of the same name.
            directory
                Passed to ``DiskShardsBuffer``; only used when ``disk`` is true.
            memory_limiter_disk, memory_limiter_comms
                Limiters handed to the disk buffer and the comm buffer.
            disk
                If true a ``DiskShardsBuffer`` is used for received data,
                otherwise a ``MemoryShardsBuffer``.
            loop
                Stored as ``_loop``; used with ``sync`` by the blocking methods.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            # Retry settings used by send(); delays are parsed as timedeltas
            # with seconds as the default unit.
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            """Return a debug representation listing id, run id, address and state flags."""
            cls_name = self.__class__.__name__
            return (
                f"<{cls_name}: id={self.id!r}, run_id={self.run_id!r}, "
                f"local_address={self.local_address!r}, closed={self.closed!r}, "
                f"transferred={self.transferred!r}>"
            )
    
        def __str__(self) -> str:
            """Return a short label: class name, shuffle id, run id and local address."""
            cls_name = self.__class__.__name__
            return f"{cls_name}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            """Hash the run by its ``run_id``."""
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Forward context_meter metrics recorded inside the block as

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            A leading ``"p2p-"`` on the first label element is stripped.

            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This is by design so that one can have a holistic view of the whole
            shuffle process.

            **Note 2:** Metrics are handed to ``digest_metric`` immediately.
            They are not temporarily stored under ShuffleRun, as those recorded
            between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run would be lost.
            """
            prefix = "p2p-"

            def callback(label: Hashable, value: float, unit: str) -> None:
                parts = label if isinstance(label, tuple) else (label,)
                head = parts[0]
                if isinstance(head, str) and head.startswith(prefix):
                    parts = (head[len(prefix) :], *parts[1:])

                self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

            offloadable = "background" in where
            with context_meter.add_callback(callback, allow_offload=offloadable):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Report the barrier of this run to the scheduler.

            ``consistent`` is true only if every entry of ``run_ids`` equals this
            run's ``run_id``; the flag is forwarded to the scheduler's
            ``shuffle_barrier`` call. Returns this run's ``run_id``. Raises first
            if this run is already closed.
            """
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address`` with a single ``shuffle_receive`` call.

            ``shards`` is wrapped with ``to_serialize`` and sent together with
            this run's shuffle id and run id. Raises first if this run is already
            closed.
            """
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address``, retrying according to the retry settings.

            When the mean shard size is below 65536 the shards are pickled into
            one opaque bytes blob that is unpickled by ``receive`` on the other
            side; otherwise the list is sent as is.
            """
            # Don't send small buffers individually over the tcp comms.
            # Instead, merge everything into an opaque bytes blob, send it all at
            # once, and unpickle it on the other side.
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
            shards_or_bytes: list | bytes = (
                pickle.dumps(shards) if _mean_shard_size(shards) < 65536 else shards
            )

            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)

            retry_options = {
                "count": self.RETRY_COUNT,
                "delay_min": self.RETRY_DELAY_MIN,
                "delay_max": self.RETRY_DELAY_MAX,
            }
            return await retry(_send, **retry_options)
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on this run's executor and return its result.

            The call is metered under the ``"offload"`` label. Raises first if
            this run is already closed.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            """Return the heartbeat of this run.

            The result has the keys ``"disk"`` and ``"comm"`` (the heartbeats of
            the two buffers, the latter extended by ``"read"`` = ``total_recvd``)
            and ``"start"`` (the start time of the run).
            """
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Write ``data`` to the comm buffer; raises first if this run is closed."""
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            """Write ``data`` to the disk buffer.

            Each index key is converted to a string by joining its components
            with ``"_"``, the same scheme ``_read_from_disk`` uses. Raises first
            if this run is already closed.
            """
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            """Raise if this run has been closed; do nothing otherwise.

            Raises
            ------
            Exception
                The exception stored on the run, if there is one.
            ShuffleClosedError
                If the run is closed and no exception is stored.
            """
            if not self.closed:
                return
            if self._exception:
                raise self._exception
            raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark the transfer as finished and flush the comm buffer.

            After this call ``add_partition`` refuses further partitions. If the
            comm buffer reports an exception after flushing, it is stored on the
            run and re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                # Remember the failure so raise_if_closed() can surface it later.
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            """Flush the comm buffer; raises first if this run is already closed."""
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush the disk buffer; raises first if this run is already closed."""
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            """Close this run together with its comm and disk buffers.

            The first caller marks the run as closed and closes the comm buffer
            followed by the disk buffer. Any later caller only waits on
            ``_closed_event`` until that first call has finished.

            The disk buffer is closed and ``_closed_event`` is set even if
            closing a buffer raises, so a failing buffer neither leaks the other
            buffer nor leaves concurrent callers waiting forever. The exception
            still propagates to the first caller.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            self.closed = True
            try:
                try:
                    await self._comm_buffer.close()
                finally:
                    await self._disk_buffer.close()
            finally:
                # Always wake up callers blocked in the branch above.
                self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            """Record ``exception`` on this run unless the run is already closed.

            A stored exception is what ``raise_if_closed`` raises once the run is
            closed.
            """
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            """Read from the disk buffer under the key built from ``id``.

            The key is the components of ``id`` joined with ``"_"``, matching
            ``_write_to_disk``. Raises first if this run is already closed.
            """
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Accept shards and hand them to ``_receive``.

            ``data`` is either the list of shards itself or a pickled blob of that
            list (see ``send``), which is unpickled first.

            Returns ``{"status": "OK"}`` on success. A ``P2PConsistencyError``
            raised along the way is turned into an error message instead of
            propagating.
            """
            # NOTE(review): pickle.loads runs on bytes handed in by the caller;
            # presumably these only ever come from trusted peers -- confirm.
            try:
                if isinstance(data, bytes):
                    shards = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                else:
                    shards = data
                await self._receive(shards)
            except P2PConsistencyError as err:
                return error_message(err)
            return {"status": "OK"}
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Make sure output partition ``i`` is handled by its assigned worker.

            Does nothing when the assigned worker is this worker. Otherwise asks
            the scheduler to restrict task ``key`` to the assigned worker and then
            raises ``Reschedule``.

            Raises
            ------
            RuntimeError
                If the scheduler answers with status ``"error"``.
            Reschedule
                After the scheduler accepted the restriction.
            """
            target = self._get_assigned_worker(i)
            if target == self.local_address:
                return

            response = await self.scheduler.shuffle_restrict_task(
                id=self.id, run_id=self.run_id, key=key, worker=target
            )
            status = response["status"]
            if status == "error":
                raise RuntimeError(response["message"])
            assert status == "OK"
            raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Return the address of the worker assigned to output partition ``i``.

            Abstract; the result is compared against ``local_address`` by
            ``_ensure_output_worker``.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run.

            Abstract; called by ``receive`` with the already unpickled shards.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            """Shard an input partition and write the shards to the comm buffer.

            ``data`` is validated with ``validate_data``, split with
            ``_shard_partition`` and written through ``_write_to_comm`` on the
            run's event loop (blocking via ``sync``). Returns this run's
            ``run_id``.

            Raises
            ------
            RuntimeError
                If the transfer has already been marked as finished.
            """
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers.

            Abstract; the returned mapping is what ``add_partition`` writes to
            the comm buffer.
            """
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Return output partition ``partition_id`` of this run.

            First makes sure the partition is handled by its assigned worker
            (``_ensure_output_worker``), then flushes the disk buffer and
            delegates to ``_get_output_partition``. Both async steps run on the
            run's event loop, blocking via ``sync``.

            Raises
            ------
            RuntimeError
                If the transfer has not been marked as finished yet.
            """
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run.

            Abstract; called by ``get_output_partition`` after the disk buffer
            has been flushed.
            """
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk.

            Abstract; handed to ``DiskShardsBuffer`` as its ``read`` callable.
            """
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards.

            Abstract; handed to ``MemoryShardsBuffer`` as its ``deserialize``
            callable.
            """
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling.

            The default implementation accepts everything; it is called by
            ``add_partition`` before the data is sharded.
            """
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        """Return the ``"shuffle"`` plugin of the worker running this task.

        Raises
        ------
        RuntimeError
            If the code is not running on a worker, or if the worker has no
            plugin registered under the name ``"shuffle"``.
        """
        from distributed import get_worker

        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e

        try:
            plugin = worker.plugins["shuffle"]
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
        return plugin  # type: ignore
    
    
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        """Return the barrier key for ``shuffle_id``: the barrier prefix plus the id."""
        return _BARRIER_PREFIX + shuffle_id
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        """Extract the shuffle id from a barrier key.

        Returns ``None`` when ``key`` is not a string starting with the barrier
        prefix; this is the inverse of ``barrier_key``.
        """
        if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
            prefix_length = len(_BARRIER_PREFIX)
            return ShuffleId(key[prefix_length:])
        return None
    
    
    class ShuffleType(Enum):
        """Kinds of P2P operations, identified by a string value."""

        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        """Immutable description of a single run of a shuffle."""

        # Not an __init__ argument: every new instance takes the next value of
        # one counter created here, starting at 1.
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        # output partition id -> worker address
        worker_for: dict[_T_partition_id, str]
        span_id: str | None

        @property
        def id(self) -> ShuffleId:
            """The id of the underlying ``ShuffleSpec``."""
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        """Abstract, immutable specification of a shuffle.

        Subclasses define the output partitions, how workers are picked for
        them and how a run is created on a worker.
        """

        id: ShuffleId
        # Passed on as the ``disk`` flag of a run -- presumably by
        # create_run_on_worker implementations; confirm in subclasses.
        disk: bool

        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""

        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""

        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            """Create the scheduler-side state for a new run of this spec.

            A fresh ``ShuffleRunSpec`` is built from this spec, ``worker_for``
            and ``span_id``. The participating workers start out as the set of
            workers appearing as values in ``worker_for``.
            """
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )

        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        """Scheduler-side state of one shuffle run.

        Declared with ``eq=False``, so no ``__eq__`` is generated; hashing is
        by ``run_id``.
        """

        run_spec: ShuffleRunSpec
        # Addresses of the workers taking part in this run.
        participating_workers: set[str]
        # None until the run is archived; see the ``archived`` property.
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False

        @property
        def id(self) -> ShuffleId:
            """The shuffle id of the run spec."""
            return self.run_spec.id

        @property
        def run_id(self) -> int:
            """The run id of the run spec."""
            return self.run_spec.run_id

        @property
        def archived(self) -> bool:
            """True once ``_archived_by`` has been set."""
            return self._archived_by is not None

        def __str__(self) -> str:
            """Return a short label: class name, shuffle id and run id."""
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"

        def __hash__(self) -> int:
            """Hash the state by its ``run_id``."""
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P 01992e7a5ca3de072c0c2984a1548c6e failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x7-chunks7] (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 01992e7a5ca3de072c0c2984a1548c6e failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a fresh ``ShuffleWorkerPlugin`` on the workers.

            The plugin is serialized with ``dumps`` and registered through
            ``self.scheduler`` under the name ``"shuffle"``.
            """
            serialized_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, serialized_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Handle a ``shuffle_barrier`` request for the active shuffle ``id``.

            If ``consistent`` is false the shuffle is restarted. Otherwise a
            ``shuffle_inputs_done`` message is broadcast to every participating
            worker, and workers whose broadcast result is an ``OSError`` are
            retried.

            Raises
            ------
            KeyError
                If ``id`` is not an active shuffle.
            ValueError
                If ``run_id`` differs from the active run of the shuffle.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler while the
                shuffle is not archived.
            RuntimeError
                If a worker returns an error that is not an ``OSError``, if a
                failing worker is unknown to the scheduler, or if three retry
                rounds did not reduce the number of failing workers.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            # Number of retry rounds that did not shrink the list of failing workers.
            no_progress = 0
            while workers:
                # on_error="return": failures come back as values in ``res``
                # instead of being raised by the broadcast.
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    # ``None`` is treated as success for that worker.
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                # NOTE(review): this rebuilds the list the loop above just
                # produced (any non-None, non-OSError result has already
                # raised), so it looks redundant.
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    # Short pause before the next broadcast attempt.
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        # The counter is never reset, so the three rounds need
                        # not be consecutive.
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` on behalf of shuffle ``id``.

            The request must carry the run id of the currently active run;
            otherwise the ``P2PConsistencyError`` is converted into an error
            message and returned instead of being raised.
            """
            try:
                active = self.active_shuffles[id]
                if active.run_id > run_id:
                    # The request belongs to an older run.
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {active}"
                    )
                if active.run_id < run_id:
                    # The request refers to a run newer than the active one.
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {active}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of the active shuffle ``id`` as a message.

            A missing shuffle (``KeyError``) is reported as a
            ``P2PConsistencyError``; every ``P2PConsistencyError`` is returned
            as an error message rather than raised.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    response = {"status": "OK", "run_spec": ToPickle(spec)}
                    return response
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as inconsistency:
                return error_message(inconsistency)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            """Return the ``ShuffleSpec`` stored on the barrier task of ``shuffle_id``."""
            barrier_task = self.scheduler.tasks[barrier_key(shuffle_id)]
            run_spec = barrier_task.run_spec
            assert isinstance(run_spec, P2PBarrierTask)
            return run_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            """Create a new run for ``shuffle_id``, make it active and return its spec.

            The new state is stored in ``active_shuffles`` and ``_shuffles``, and
            ``worker`` is recorded as a participant.
            """
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            span_id = self.scheduler.tasks[key].group.span_id
            new_state = spec.create_new_run(worker_for=worker_for, span_id=span_id)
            self.active_shuffles[shuffle_id] = new_state
            self._shuffles[shuffle_id].add(new_state)
            new_state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return new_state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a fresh ``ShuffleWorkerPlugin`` on the workers.

            The plugin is serialized with ``dumps`` and registered through
            ``self.scheduler`` under the name ``"shuffle"``.
            """
            serialized_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, serialized_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Handle a ``shuffle_barrier`` request for the active shuffle ``id``.

            If ``consistent`` is false the shuffle is restarted. Otherwise a
            ``shuffle_inputs_done`` message is broadcast to every participating
            worker, and workers whose broadcast result is an ``OSError`` are
            retried.

            Raises
            ------
            KeyError
                If ``id`` is not an active shuffle.
            ValueError
                If ``run_id`` differs from the active run of the shuffle.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler while the
                shuffle is not archived.
            RuntimeError
                If a worker returns an error that is not an ``OSError``, if a
                failing worker is unknown to the scheduler, or if three retry
                rounds did not reduce the number of failing workers.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            # Number of retry rounds that did not shrink the list of failing workers.
            no_progress = 0
            while workers:
                # on_error="return": failures come back as values in ``res``
                # instead of being raised by the broadcast.
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    # ``None`` is treated as success for that worker.
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                # NOTE(review): this rebuilds the list the loop above just
                # produced (any non-None, non-OSError result has already
                # raised), so it looks redundant.
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    # Short pause before the next broadcast attempt.
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        # The counter is never reset, so the three rounds need
                        # not be consecutive.
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` on behalf of shuffle ``id``.

            The request must carry the run id of the currently active run;
            otherwise the ``P2PConsistencyError`` is converted into an error
            message and returned instead of being raised.
            """
            try:
                active = self.active_shuffles[id]
                if active.run_id > run_id:
                    # The request belongs to an older run.
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {active}"
                    )
                if active.run_id < run_id:
                    # The request refers to a run newer than the active one.
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {active}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of the active shuffle ``id`` as a message.

            A missing shuffle (``KeyError``) is reported as a
            ``P2PConsistencyError``; every ``P2PConsistencyError`` is returned
            as an error message rather than raised.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    response = {"status": "OK", "run_spec": ToPickle(spec)}
                    return response
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as inconsistency:
                return error_message(inconsistency)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: '01992e7a5ca3de072c0c2984a1548c6e'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        """``OKMessage`` that additionally carries a shuffle run spec."""

        # Either the bare spec or the spec wrapped in ``ToPickle``.
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Set up the run: store the handles, create both buffers, read config.

            Parameters
            ----------
            id, run_id, span_id, local_address, executor, rpc, digest_metric, scheduler, loop
                Stored unchanged on the instance.
            directory
                Passed to ``DiskShardsBuffer``; only used when ``disk`` is true.
            memory_limiter_disk
                Limiter handed to the disk buffer; only used when ``disk`` is true.
            memory_limiter_comms
                Limiter handed to the comm buffer.
            disk
                If true the received shards are buffered in a
                ``DiskShardsBuffer``, otherwise in a ``MemoryShardsBuffer``.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    # Both comm settings are read from the dask config.
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            # Retry settings used by ``send``; the delays are parsed as
            # timedeltas with seconds as the default unit.
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            """Use the run id as the hash value."""
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Forward context_meter metrics to ``digest_metric`` while active.

            Each metric is reported as

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            **Note 1:** A metric that is not logged by a background task
            (where='foreground') is additionally recorded under

                {('execute', <span id>, <task prefix>, label, unit): value}

            The duplication is intentional: it gives a holistic view of the
            whole shuffle process.

            **Note 2:** Metrics are written to Worker.digests right away.
            Keeping them on the ShuffleRun first would lose everything recorded
            between the last heartbeat and the deletion of the ShuffleRun object
            at the end of a run.
            """

            def callback(label: Hashable, value: float, unit: str) -> None:
                # Normalise the label to a tuple and drop a leading "p2p-"
                # prefix from its first element.
                parts = label if isinstance(label, tuple) else (label,)
                first = parts[0]
                if isinstance(first, str) and first.startswith("p2p-"):
                    parts = (first[len("p2p-") :], *parts[1:])
                self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

            offloadable = "background" in where
            with context_meter.add_callback(callback, allow_offload=offloadable):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Make one ``shuffle_receive`` call to ``address`` with ``shards``."""
            self.raise_if_closed()
            peer = self.rpc(address)
            response = await peer.shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
            return response
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address``, retrying per the RETRY_* settings."""
            # Don't send buffers individually over the tcp comms when the shards
            # are small. Instead, merge everything into an opaque bytes blob,
            # send it all at once, and unpickle it on the other side.
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
            small = _mean_shard_size(shards) < 65536
            payload: list | bytes = pickle.dumps(shards) if small else shards

            def attempt() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, payload)

            return await retry(
                attempt,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

            The call is metered under the label ``"offload"``.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                outcome = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
            return outcome
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Write ``data`` to the comm buffer; raises if the run is closed."""
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            """Flush the comm buffer; raises if the run is closed."""
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush the disk buffer; raises if the run is closed."""
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            """Close the run together with its comm and disk buffers.

            A call made after ``closed`` has been set only waits for
            ``_closed_event`` and then returns.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            # The flag is set before the first await, so a call arriving while
            # the buffers are still closing takes the waiting branch above.
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition

            ``_ensure_output_worker`` compares the returned address with
            ``self.local_address``.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run

            Called by ``receive`` with the (already unpickled) list of shards.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…_worker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Attach the plugin to ``scheduler``.

            Registers the ``shuffle_*`` handlers on the scheduler, adds this
            object as the scheduler plugin named ``"shuffle"`` and initializes
            the plugin's bookkeeping containers.
            """
            self.scheduler = scheduler
            # Expose the plugin's methods as scheduler handlers.
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            # Nested mapping; ``heartbeat`` fills it as heartbeats[shuffle_id][address].
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a new ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

            The registration goes through ``self.scheduler``; the ``scheduler``
            argument is not used by this method.
            """
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return the keys of ``active_shuffles`` as a newly built set."""
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Handle the barrier of shuffle ``id`` and notify its workers.

            If ``consistent`` is false the shuffle is restarted. Otherwise a
            ``shuffle_inputs_done`` message is broadcast to all participating
            workers, and workers whose broadcast returned an ``OSError`` are
            retried.

            Raises
            ------
            KeyError
                If ``id`` is not an active shuffle.
            ValueError
                If ``run_id`` does not match the active run of the shuffle.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler although the
                shuffle has not been archived.
            RuntimeError
                If a worker returns an unexpected error, a failing worker is
                unknown to the scheduler, or three retries leave the number of
                failing workers unchanged.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Collect the workers to retry: a ``None`` result is a success,
                # an ``OSError`` is retried, anything else aborts the barrier.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` for the active run of shuffle ``id``.

            Returns an OK message after calling ``self._set_restriction``. If
            ``run_id`` differs from the active run's id, the resulting
            ``P2PConsistencyError`` is returned as an error message. Other
            exceptions (e.g. a ``KeyError`` for an unknown shuffle id) are not
            caught here.
            """
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    # The request refers to an older run than the active one.
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    # The request refers to a newer run than the active one.
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of active shuffle ``id`` wrapped in a message.

            A ``KeyError`` from the lookup is re-raised as
            ``P2PConsistencyError``; every ``P2PConsistencyError`` is then
            returned as an error message instead of being raised.
            """
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Look up active shuffle ``id`` and record ``worker`` as a participant.

            Raises ``P2PConsistencyError`` if ``worker`` is not in
            ``self.scheduler.workers`` and ``KeyError`` if ``id`` is not an
            active shuffle.
            """
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            # Side effect: the requesting worker joins the participating set.
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:45757', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(10, 10), dtype=float64, chunksize=(10, 10), chunktype=numpy.ndarray>
chunks = {1: 5}
ws = (<Worker 'tcp://127.0.0.1:40155', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:45267', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.12/site-packages/dask/dataframe/__init__.py'>

    @pytest.mark.parametrize(
        "x, chunks",
        [
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
            (da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
            (da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
        """
        pytest.importorskip("pandas")
        dd = pytest.importorskip("dask.dataframe")
        y = dd.from_array(x).values
        result = y.rechunk(chunks, method="p2p")
        expected = x.rechunk(chunks, method="p2p")
    
        assert_chunks_match(result.chunks, expected.chunks)
>       assert_eq(await c.compute(result), await c.compute(expected))

distributed/shuffle/tests/test_rechunk.py:757: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
    self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        """``OKMessage`` extended with a ``run_spec`` entry."""

        # Either the spec itself or the spec wrapped in ``ToPickle``.
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Set up the state of one shuffle run.

            Creates ``_disk_buffer`` (a ``DiskShardsBuffer`` when ``disk`` is
            true, otherwise a ``MemoryShardsBuffer``) and ``_comm_buffer``,
            initializes the bookkeeping attributes and reads the
            ``distributed.p2p.comm.*`` settings from the dask config.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            # Retry settings used by ``send``.
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            """Use ``run_id`` as the hash of the run."""
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This is by design so that one can have a holistic view of the whole shuffle
            process.

            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """

            def callback(label: Hashable, value: float, unit: str) -> None:
                # Normalize a non-tuple label to a 1-tuple.
                if not isinstance(label, tuple):
                    label = (label,)
                # Drop a leading "p2p-" from the first label component.
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)

                self.digest_metric(name, value)

            # Offloading is only allowed when ``where`` contains "background".
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Call ``shuffle_receive`` on the worker at ``address`` with ``shards``.

            The shards are wrapped with ``to_serialize`` and sent together with
            this run's shuffle id and run id. Fails via ``raise_if_closed`` when
            the run is closed.
            """
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to the worker at ``address``, retrying on failure.

            Shards whose mean size is below 65536 are pickled into a single
            bytes blob first. The actual transfer is done by ``_send``, wrapped
            in ``retry`` with this run's ``RETRY_*`` settings.
            """
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards

            # Zero-argument factory so that ``retry`` can create a fresh coroutine.
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)

            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

            The call goes through ``run_in_executor_with_context`` and is metered
            under the ``"offload"`` label. Fails via ``raise_if_closed`` when the
            run is closed.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Check that the run is still open, then write ``data`` to the comm buffer."""
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark the transfer as finished and flush the comm buffer.

            After flushing, any exception reported by the comm buffer's
            ``raise_on_exception`` is stored on ``_exception`` and re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                # Keep the exception so that ``raise_if_closed`` can re-raise it later.
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            """Check that the run is still open, then await ``flush()`` on the comm buffer."""
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Check that the run is still open, then await ``flush()`` on ``_disk_buffer``."""
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            assigned_worker = self._get_assigned_worker(i)
    
            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition

            ``_ensure_output_worker`` compares the returned address with
            ``self.local_address``.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run

            Called by ``receive`` with the (already unpickled) list of shards.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            """Shard an input partition and write the shards to the comm buffer.

            ``data`` is validated with ``validate_data``, split by
            ``_shard_partition`` and handed to ``_write_to_comm`` on the run's
            event loop. Returns this run's ``run_id``.

            Raises ``RuntimeError`` if the transfer has already been marked as
            finished, and whatever ``raise_if_closed`` raises if the run is
            closed.
            """
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                # Blocks until the coroutine has finished on ``self._loop``.
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers

            The result is passed to ``_write_to_comm`` by ``add_partition``.
            """
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Return the output partition ``partition_id`` of this run.

            First makes sure the task runs on the assigned worker (see
            ``_ensure_output_worker``), then flushes ``_disk_buffer`` and
            delegates to ``_get_output_partition``.

            Raises ``RuntimeError`` if the transfer has not been marked as
            finished yet.
            """
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run

            Called by ``get_output_partition`` after ``_disk_buffer`` has been
            flushed.
            """
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk

            Passed as the ``read`` callback to ``DiskShardsBuffer`` in
            ``__init__``.
            """
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards

            Passed as the ``deserialize`` callback to ``MemoryShardsBuffer`` in
            ``__init__``.
            """
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling

            The base implementation does nothing; ``add_partition`` calls this
            hook before sharding.
            """
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        """Return the ``"shuffle"`` plugin of the worker running the current task.

        Raises ``RuntimeError`` if ``get_worker()`` fails with ``ValueError``
        or if the worker has no plugin registered under ``"shuffle"``.
        """
        # Imported here rather than at module level.
        from distributed import get_worker

        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e
        try:
            return worker.plugins["shuffle"]  # type: ignore
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
    
    
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        """Return ``shuffle_id`` prefixed with ``_BARRIER_PREFIX``."""
        return _BARRIER_PREFIX + shuffle_id
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        """Extract the shuffle id from a barrier key.

        Returns ``None`` when ``key`` is not a string starting with
        ``_BARRIER_PREFIX``.
        """
        if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
            return ShuffleId(key[len(_BARRIER_PREFIX) :])
        return None
    
    
    class ShuffleType(Enum):
        """Enumeration with one member for dataframe shuffles and one for array rechunks."""

        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        """Immutable description of one run of a shuffle."""

        # Not an ``__init__`` argument: every new instance draws the next value
        # from a single counter that starts at 1.
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        # Maps a partition id to a worker address string.
        worker_for: dict[_T_partition_id, str]
        span_id: str | None

        @property
        def id(self) -> ShuffleId:
            """Shuffle id taken from the wrapped ``spec``."""
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        """Abstract, immutable specification of a shuffle.

        Subclasses implement ``output_partitions``, ``pick_worker`` and
        ``create_run_on_worker``.
        """

        id: ShuffleId
        disk: bool

        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""

        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""

        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            """Create the scheduler-side state for a new run of this spec.

            The participating workers start out as the set of all values of
            ``worker_for``.
            """
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )

        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        """Mutable per-run state: the run spec plus the set of participating workers.

        ``eq=False`` keeps the default identity-based equality; hashing is
        based on ``run_id``.
        """

        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        # ``None`` until the run is archived; see the ``archived`` property.
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False

        @property
        def id(self) -> ShuffleId:
            """Shuffle id taken from ``run_spec``."""
            return self.run_spec.id

        @property
        def run_id(self) -> int:
            """Run id taken from ``run_spec``."""
            return self.run_spec.run_id

        @property
        def archived(self) -> bool:
            """Whether ``_archived_by`` has been set."""
            return self._archived_by is not None

        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"

        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P 01992e7a5ca3de072c0c2984a1548c6e failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x8-chunks8] (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 01992e7a5ca3de072c0c2984a1548c6e failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Install the shuffle handlers on ``scheduler`` and register this plugin.

            Parameters
            ----------
            scheduler
                Scheduler whose ``handlers`` table is extended and on which this
                object is added as the plugin named ``"shuffle"``.
            """
            self.scheduler = scheduler
            shuffle_handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler.handlers.update(shuffle_handlers)
            # shuffle id -> worker address -> heartbeat payload
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            # Statement order is kept as-is: the plugin is added before the
            # remaining bookkeeping containers are created.
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a new ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

            The ``scheduler`` argument is not used; registration goes through
            ``self.scheduler``.
            """
            serialized_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None,
                serialized_plugin,
                name="shuffle",
                idempotent=False,
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the ids of all currently active shuffles."""
            return set(self.active_shuffles.keys())
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Broadcast ``shuffle_inputs_done`` to all participating workers.

            Parameters
            ----------
            id
                Id of the active shuffle.
            run_id
                Run id the caller believes to be current.
            consistent
                If false, the shuffle is restarted instead of broadcasting and
                the result of ``_restart_shuffle`` is returned.

            Raises
            ------
            KeyError
                If ``id`` is not an active shuffle.
            ValueError
                If ``run_id`` differs from the active run.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler while the
                shuffle is not archived.
            RuntimeError
                On a non-``OSError`` broadcast result, when a failing worker has
                left the scheduler, or when three rounds made no progress.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            # Counts rounds without progress; it is cumulative and never reset.
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Collect the workers to retry: ``None`` is a success, an
                # ``OSError`` is retried, anything else aborts the barrier.
                # (A second comprehension that rebuilt this exact list from
                # ``res`` was redundant and has been removed.)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` if ``run_id`` is the active run.

            Returns ``{"status": "OK"}`` on success. A run-id mismatch is reported
            as an error message built from a ``P2PConsistencyError``. A
            ``KeyError`` from an unknown ``id`` or ``key`` is not caught here.
            """
            try:
                shuffle = self.active_shuffles[id]
                active_run_id = shuffle.run_id
                if active_run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if active_run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
                return {"status": "OK"}
            except P2PConsistencyError as exc:
                return error_message(exc)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge per-shuffle heartbeat data reported by worker ``ws``.

            Entries of ``data`` whose shuffle id is not currently active are
            ignored; the rest update ``self.heartbeats[shuffle_id][ws.address]``.
            """
            # Build the set of active ids once instead of once per entry.
            active_ids = self.shuffle_ids()
            for shuffle_id, d in data.items():
                if shuffle_id in active_ids:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of active shuffle ``id`` wrapped in ``ToPickle``.

            An unknown ``id`` (``KeyError``) is converted to a
            ``P2PConsistencyError``; any ``P2PConsistencyError`` is returned as an
            error message rather than raised.
            """
            try:
                try:
                    spec = self._get(id, worker)
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
                return {"status": "OK", "run_spec": ToPickle(spec)}
            except P2PConsistencyError as exc:
                return error_message(exc)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Look up active shuffle ``id`` and record ``worker`` as a participant.

            Raises ``P2PConsistencyError`` if the scheduler does not know
            ``worker`` and ``KeyError`` if ``id`` is not active.
            """
            known_workers = self.scheduler.workers
            if worker not in known_workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            shuffle_state = self.active_shuffles[id]
            shuffle_state.participating_workers.add(worker)
            return shuffle_state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            """Return the ``ShuffleSpec`` stored on the shuffle's barrier task.

            The barrier task is found via ``barrier_key(shuffle_id)``; its
            ``run_spec`` is asserted to be a ``P2PBarrierTask``.
            """
            barrier_ts = self.scheduler.tasks[barrier_key(shuffle_id)]
            task_spec = barrier_ts.run_spec
            assert isinstance(task_spec, P2PBarrierTask)
            return task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            """Start a new run for ``shuffle_id`` and make it the active one.

            The new state is stored in ``active_shuffles`` and ``_shuffles``,
            ``worker`` is added to its participants, and its run spec is returned.
            """
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            # The span id is taken from the group of the initiating task.
            span_id = self.scheduler.tasks[key].group.span_id
            new_state = spec.create_new_run(worker_for=worker_for, span_id=span_id)
            self.active_shuffles[shuffle_id] = new_state
            self._shuffles[shuffle_id].add(new_state)
            new_state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return new_state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Install the shuffle handlers on ``scheduler`` and register this plugin.

            Parameters
            ----------
            scheduler
                Scheduler whose ``handlers`` table is extended and on which this
                object is added as the plugin named ``"shuffle"``.
            """
            self.scheduler = scheduler
            shuffle_handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler.handlers.update(shuffle_handlers)
            # shuffle id -> worker address -> heartbeat payload
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            # Statement order is kept as-is: the plugin is added before the
            # remaining bookkeeping containers are created.
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a new ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

            The ``scheduler`` argument is not used; registration goes through
            ``self.scheduler``.
            """
            serialized_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None,
                serialized_plugin,
                name="shuffle",
                idempotent=False,
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the ids of all currently active shuffles."""
            return set(self.active_shuffles.keys())
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Broadcast ``shuffle_inputs_done`` to all participating workers.

            Parameters
            ----------
            id
                Id of the active shuffle.
            run_id
                Run id the caller believes to be current.
            consistent
                If false, the shuffle is restarted instead of broadcasting and
                the result of ``_restart_shuffle`` is returned.

            Raises
            ------
            KeyError
                If ``id`` is not an active shuffle.
            ValueError
                If ``run_id`` differs from the active run.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler while the
                shuffle is not archived.
            RuntimeError
                On a non-``OSError`` broadcast result, when a failing worker has
                left the scheduler, or when three rounds made no progress.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            # Counts rounds without progress; it is cumulative and never reset.
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Collect the workers to retry: ``None`` is a success, an
                # ``OSError`` is retried, anything else aborts the barrier.
                # (A second comprehension that rebuilt this exact list from
                # ``res`` was redundant and has been removed.)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` if ``run_id`` is the active run.

            Returns ``{"status": "OK"}`` on success. A run-id mismatch is reported
            as an error message built from a ``P2PConsistencyError``. A
            ``KeyError`` from an unknown ``id`` or ``key`` is not caught here.
            """
            try:
                shuffle = self.active_shuffles[id]
                active_run_id = shuffle.run_id
                if active_run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if active_run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
                return {"status": "OK"}
            except P2PConsistencyError as exc:
                return error_message(exc)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge per-shuffle heartbeat data reported by worker ``ws``.

            Entries of ``data`` whose shuffle id is not currently active are
            ignored; the rest update ``self.heartbeats[shuffle_id][ws.address]``.
            """
            # Build the set of active ids once instead of once per entry.
            active_ids = self.shuffle_ids()
            for shuffle_id, d in data.items():
                if shuffle_id in active_ids:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of active shuffle ``id`` wrapped in ``ToPickle``.

            An unknown ``id`` (``KeyError``) is converted to a
            ``P2PConsistencyError``; any ``P2PConsistencyError`` is returned as an
            error message rather than raised.
            """
            try:
                try:
                    spec = self._get(id, worker)
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
                return {"status": "OK", "run_spec": ToPickle(spec)}
            except P2PConsistencyError as exc:
                return error_message(exc)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: '01992e7a5ca3de072c0c2984a1548c6e'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        # ``OKMessage`` extended with the shuffle run spec, carried either as a
        # plain ``ShuffleRunSpec`` or wrapped in ``ToPickle``.
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Store the run's identity and collaborators and create its buffers.

            Parameters
            ----------
            id, run_id, span_id
                Identifiers stored on the instance; ``span_id`` is used when
                naming captured metrics.
            local_address
                Stored as ``self.local_address``.
            directory
                Passed to ``DiskShardsBuffer``; only used when ``disk`` is true.
            executor
                Thread pool used by ``offload``.
            rpc
                Callable returning an RPC handle for a given address.
            digest_metric
                Callback receiving ``(name, value)`` for captured metrics.
            scheduler
                RPC handle stored as ``self.scheduler``.
            memory_limiter_disk, memory_limiter_comms
                Limiters handed to the disk and comm buffers respectively.
            disk
                If true use ``DiskShardsBuffer``, otherwise ``MemoryShardsBuffer``.
            loop
                Stored as ``self._loop``.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    # Comm buffer limits are read from the dask configuration.
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            # Retry settings used by ``send``; also read from the dask configuration.
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            """Debug representation listing id, run id, address and state flags."""
            cls_name = self.__class__.__name__
            return (
                f"<{cls_name}: id={self.id!r}, run_id={self.run_id!r}, "
                f"local_address={self.local_address!r}, closed={self.closed!r}, "
                f"transferred={self.transferred!r}>"
            )
    
        def __str__(self) -> str:
            """Short form: ``ClassName<id[run_id]> on local_address``."""
            cls_name = self.__class__.__name__
            return f"{cls_name}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            """Hash the run by its integer ``run_id``."""
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Forward ``context_meter`` metrics to ``digest_metric`` while active.

            Each metric is reported under the name

                ('p2p', <span id>, where, *label, unit)

            with a leading ``"p2p-"`` stripped from the first label element.

            **Note 1:** For metrics that are not logged by a background task
            (where='foreground') the same value is also recorded under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This duplication is intentional: it gives a holistic view of the
            whole shuffle process.

            **Note 2:** Metrics are written to Worker.digests immediately rather
            than buffered on the ShuffleRun, since anything recorded between the
            last heartbeat and the deletion of the ShuffleRun at the end of a run
            would otherwise be lost.
            """
            prefix = "p2p-"

            def callback(label: Hashable, value: float, unit: str) -> None:
                parts = label if isinstance(label, tuple) else (label,)
                head = parts[0]
                if isinstance(head, str) and head.startswith(prefix):
                    parts = (head[len(prefix) :], *parts[1:])
                self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

            # Only "background" captures are allowed to be offloaded.
            offloadable = "background" in where
            with context_meter.add_callback(callback, allow_offload=offloadable):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Report the barrier to the scheduler and return this run's id.

            ``consistent`` is sent as true only if every entry of ``run_ids``
            equals ``self.run_id``. Raises if the run is already closed.
            """
            self.raise_if_closed()
            mismatched = any(other != self.run_id for other in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id,
                run_id=self.run_id,
                consistent=not mismatched,
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Make one ``shuffle_receive`` call to ``address`` carrying ``shards``.

            Raises if the run is already closed.
            """
            self.raise_if_closed()
            peer = self.rpc(address)
            payload = to_serialize(shards)
            return await peer.shuffle_receive(
                data=payload,
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address``, retrying per the configured policy.

            When the mean shard size is below 65536 the shards are pickled into
            one bytes blob first; otherwise they are sent as they are.
            """
            # Don't send buffers individually over the tcp comms.
            # Instead, merge everything into an opaque bytes blob, send it all at once,
            # and unpickle it on the other side.
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
            shards_or_bytes: list | bytes = (
                pickle.dumps(shards) if _mean_shard_size(shards) < 65536 else shards
            )

            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)

            retry_policy = {
                "count": self.RETRY_COUNT,
                "delay_min": self.RETRY_DELAY_MIN,
                "delay_max": self.RETRY_DELAY_MAX,
            }
            return await retry(_send, **retry_policy)
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

            The call is metered under the label ``"offload"``. Raises if the run
            is already closed.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                outcome = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
            return outcome
    
        def heartbeat(self) -> dict[str, Any]:
            """Return the disk and comm buffer heartbeats plus the run's start time.

            The comm heartbeat additionally gets a ``"read"`` entry holding
            ``self.total_recvd``.
            """
            comm_stats = self._comm_buffer.heartbeat()
            comm_stats["read"] = self.total_recvd
            disk_stats = self._disk_buffer.heartbeat()
            return {"disk": disk_stats, "comm": comm_stats, "start": self.start_time}
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Write ``data`` to the comm buffer; raises if the run is closed."""
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            """Write ``data`` to the disk buffer, keyed by underscore-joined indices.

            Each index tuple ``k`` becomes the string ``"_".join(str(i) for i in k)``.
            Raises if the run is already closed.
            """
            self.raise_if_closed()
            keyed = {"_".join(map(str, index)): shard for index, shard in data.items()}
            await self._disk_buffer.write(keyed)
    
        def raise_if_closed(self) -> None:
            """Raise if the run is closed; do nothing otherwise.

            A stored ``_exception`` is re-raised in preference to a generic
            ``ShuffleClosedError``.
            """
            if not self.closed:
                return
            stored = self._exception
            if stored:
                raise stored
            raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark the run as transferred and flush the comm buffer.

            Any exception the comm buffer raises afterwards is stored on
            ``_exception`` and re-raised. Raises if the run is already closed.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as exc:
                # Remember the failure so ``raise_if_closed`` can surface it later.
                self._exception = exc
                raise
    
        async def _flush_comm(self) -> None:
            """Flush the comm buffer; raises if the run is closed."""
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush the disk buffer; raises if the run is closed."""
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            """Close the comm buffer, then the disk buffer, and signal completion.

            If the run is already marked closed, this only waits for the closed
            event to be set.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            self.closed = True
            for buffer in (self._comm_buffer, self._disk_buffer):
                await buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            """Store ``exception`` on the run unless the run is already closed."""
            if self.closed:
                return
            self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            """Read from the disk buffer the entry keyed by underscore-joined ``id``.

            Raises if the run is already closed.
            """
            self.raise_if_closed()
            shard_key = "_".join(map(str, id))
            return self._disk_buffer.read(shard_key)
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Pass received shards to ``_receive`` and report the outcome.

            Returns ``{"status": "OK"}`` on success; a ``P2PConsistencyError`` is
            returned as an error message instead of being raised.
            """
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    # NOTE(review): ``pickle.loads`` runs on bytes received from a
                    # peer -- confirm all senders are trusted.
                    shards = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                else:
                    shards = data
                await self._receive(shards)
            except P2PConsistencyError as exc:
                return error_message(exc)
            return {"status": "OK"}
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Return only if this worker is the one assigned to partition ``i``.

            Otherwise the scheduler is asked to restrict ``key`` to the assigned
            worker. An error reply raises ``RuntimeError``; a successful reply
            raises ``Reschedule``.
            """
            assigned_worker = self._get_assigned_worker(i)
            if assigned_worker == self.local_address:
                return

            result = await self.scheduler.shuffle_restrict_task(
                id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
            )
            if result["status"] == "error":
                raise RuntimeError(result["message"])
            assert result["status"] == "OK"
            raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Return the address of the worker assigned to output partition ``i``.

            Abstract; the returned address is compared against
            ``self.local_address`` by ``_ensure_output_worker``.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run.

            Abstract; called by ``receive`` with the (already unpickled) list of
            ``(partition id, shard)`` tuples.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Attach the plugin to ``scheduler``.

            Registers the ``shuffle_*`` handlers on the scheduler, adds this
            object as the scheduler plugin named ``"shuffle"`` and initialises
            the plugin's (initially empty) bookkeeping containers.
            """
            self.scheduler = scheduler
            shuffle_handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler.handlers.update(shuffle_handlers)

            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            # Statement order is kept as in the original: the plugin is added
            # before the remaining private containers are created.
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a fresh ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

            The plugin is passed through ``dumps`` before registration. Note that
            the registration goes through ``self.scheduler``; the ``scheduler``
            argument itself is not used.
            """
            serialized_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, serialized_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the ids (keys) of ``active_shuffles``."""
            ids: set[ShuffleId] = set()
            ids.update(self.active_shuffles)
            return ids
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Handle a ``shuffle_barrier`` request for the active shuffle ``id``.

            Parameters
            ----------
            id
                Id of the shuffle; must be a key of ``active_shuffles``.
            run_id
                Run id the caller refers to; must equal the active run's id.
            consistent
                If false, a warning is logged and the shuffle is restarted via
                ``_restart_shuffle`` instead of broadcasting.

            Otherwise ``shuffle_inputs_done`` is broadcast to all participating
            workers. Workers whose response is an ``OSError`` are retried (after
            a 0.1s sleep) until every worker answered with ``None``.

            Raises
            ------
            ValueError
                If ``run_id`` does not match the active run.
            P2PIllegalStateError
                If a failing worker is no longer known to the scheduler while the
                shuffle is not archived.
            RuntimeError
                If a worker responds with anything other than ``None`` or an
                ``OSError``, if a failing worker left the scheduler, or once
                three retry rounds did not change the number of failing workers.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Collect the workers to retry: a ``None`` response means success,
                # an ``OSError`` is retried, anything else aborts the barrier.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    # The counter is never reset, so the limit applies to the
                    # total number of rounds without a change in failing workers.
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict the task ``key`` to ``worker`` for the active shuffle ``id``.

            The request is only honoured when ``run_id`` equals the run id of the
            active shuffle. On a mismatch a ``P2PConsistencyError`` is raised
            internally and returned as an error message (``error_message``);
            otherwise ``_set_restriction`` is applied to the scheduler's task and
            ``{"status": "OK"}`` is returned.
            """
            try:
                shuffle = self.active_shuffles[id]
                active_run_id = shuffle.run_id
                if active_run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if active_run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                task_state = self.scheduler.tasks[key]
                self._set_restriction(task_state, worker)
            except P2PConsistencyError as err:
                return error_message(err)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge per-shuffle heartbeat data sent by worker ``ws``.

            Parameters
            ----------
            ws
                State of the reporting worker; its ``address`` keys the stored data.
            data
                Mapping of shuffle id to a dict of heartbeat values.

            Entries for shuffles that are not in ``active_shuffles`` are ignored.
            """
            # Test membership on the dict directly instead of rebuilding the id
            # set (``shuffle_ids()``) for every entry of ``data``.
            active = self.active_shuffles
            for shuffle_id, d in data.items():
                if shuffle_id in active:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of the active shuffle ``id`` for ``worker``.

            On success the reply is ``{"status": "OK", "run_spec": ...}`` with the
            spec wrapped in ``ToPickle``. A ``KeyError`` from the lookup is turned
            into a ``P2PConsistencyError``; every ``P2PConsistencyError`` is
            returned as an error message instead of being raised.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    reply: RunSpecMessage = {
                        "status": "OK",
                        "run_spec": ToPickle(spec),
                    }
                    return reply
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as err:
                return error_message(err)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Look up the active shuffle ``id`` and add ``worker`` as participant.

            Returns the shuffle's ``run_spec``. Raises ``KeyError`` if ``id`` is
            not in ``active_shuffles`` and ``P2PConsistencyError`` if ``worker``
            is not among the scheduler's workers.
            """
            if worker in self.scheduler.workers:
                shuffle_state = self.active_shuffles[id]
                shuffle_state.participating_workers.add(worker)
                return shuffle_state.run_spec

            # This should never happen
            raise P2PConsistencyError(
                f"Scheduler is unaware of this worker {worker!r}"
            )  # pragma: nocover
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:43977', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(10, 10), dtype=float64, chunksize=(10, 10), chunktype=numpy.ndarray>
chunks = (None, (5, 5))
ws = (<Worker 'tcp://127.0.0.1:44945', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:35293', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.12/site-packages/dask/dataframe/__init__.py'>

    @pytest.mark.parametrize(
        "x, chunks",
        [
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
            (da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
            (da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
        """
        pytest.importorskip("pandas")
        dd = pytest.importorskip("dask.dataframe")
        y = dd.from_array(x).values
        result = y.rechunk(chunks, method="p2p")
        expected = x.rechunk(chunks, method="p2p")
    
        assert_chunks_match(result.chunks, expected.chunks)
>       assert_eq(await c.compute(result), await c.compute(expected))

distributed/shuffle/tests/test_rechunk.py:757: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
    self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        """An ``OKMessage`` that additionally carries a shuffle run spec."""

        # Either the spec itself or the spec wrapped in ``ToPickle``
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        """State and operations of one run (``run_id``) of the P2P shuffle ``id``.

        An instance owns two buffers: a comm buffer through which shards are sent
        to other addresses (``send``) and a disk-or-memory buffer in which
        received data is kept until ``get_output_partition`` reads it. Subclasses
        provide the data-specific parts through the abstract methods
        (``_receive``, ``_shard_partition``, ``_get_output_partition``,
        ``_get_assigned_worker``, ``read``, ``deserialize``).
        """

        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop

        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float

        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Store the run's collaborators and create its buffers.

            ``disk`` selects the receive-side buffer: a ``DiskShardsBuffer`` in
            ``directory`` (limited by ``memory_limiter_disk``) when true, a
            ``MemoryShardsBuffer`` otherwise. The comm buffer is limited by
            ``memory_limiter_comms``; its message size limit, its concurrency and
            the retry settings are read from the ``distributed.p2p.comm.*``
            dask config.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )

        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"

        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"

        def __hash__(self) -> int:
            # The hash is the run id itself
            return self.run_id

        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This is by design so that one can have a holistic view of the whole shuffle
            process.

            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """

            def callback(label: Hashable, value: float, unit: str) -> None:
                # Normalise the label to a tuple and drop a leading "p2p-" prefix,
                # since "p2p" becomes the first element of the metric name anyway
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)

                self.digest_metric(name, value)

            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield

        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Call ``shuffle_barrier`` on the scheduler for this run.

            ``consistent`` is reported as true only if every id in ``run_ids``
            equals this run's ``run_id``. Returns this run's ``run_id``.
            """
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id

        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Call ``shuffle_receive`` on ``address`` with ``shards`` (single attempt)."""
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )

        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address`` through ``_send``, wrapped in ``retry``.

            The retry count and delays come from ``RETRY_COUNT``,
            ``RETRY_DELAY_MIN`` and ``RETRY_DELAY_MAX``.
            """
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards

            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)

            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )

        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` in ``self.executor`` and return its result.

            The call is metered under the ``"offload"`` label.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )

        def heartbeat(self) -> dict[str, Any]:
            """Return the heartbeats of both buffers plus the run's start time.

            The comm heartbeat is extended with ``"read"``, set to ``total_recvd``.
            """
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }

        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Write ``data`` to the comm buffer; raises if the run is closed."""
            self.raise_if_closed()
            await self._comm_buffer.write(data)

        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            """Write ``data`` to the disk buffer; raises if the run is closed.

            Each index key is converted to a string by joining its elements
            with ``"_"``, the same form ``_read_from_disk`` uses.
            """
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )

        def raise_if_closed(self) -> None:
            """Raise if the run is closed.

            The stored ``_exception`` is raised if there is one, otherwise a
            ``ShuffleClosedError``. Does nothing while the run is open.
            """
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")

        async def inputs_done(self) -> None:
            """Mark the run as transferred and flush the comm buffer.

            An exception reported by the comm buffer afterwards is stored in
            ``_exception`` and re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise

        async def _flush_comm(self) -> None:
            """Flush the comm buffer; raises if the run is closed."""
            self.raise_if_closed()
            await self._comm_buffer.flush()

        async def flush_receive(self) -> None:
            """Flush the disk buffer; raises if the run is closed."""
            self.raise_if_closed()
            await self._disk_buffer.flush()

        async def close(self) -> None:
            """Close the comm buffer, then the disk buffer, then set ``_closed_event``.

            A call made while ``closed`` is already set only waits for the event.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()

        def fail(self, exception: Exception) -> None:
            """Store ``exception`` in ``_exception`` unless the run is already closed."""
            if not self.closed:
                self._exception = exception

        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            """Read from the disk buffer the entry for index ``id`` (joined with ``"_"``)."""
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))

        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Pass received shards to ``_receive`` and report the outcome.

            Returns ``{"status": "OK"}``, or the ``error_message`` of a
            ``P2PConsistencyError`` raised while handling ``data``.
            """
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)

        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Return normally only if partition ``i`` is assigned to this worker.

            Otherwise the scheduler is asked to restrict task ``key`` to the
            assigned worker and ``Reschedule`` is raised; an ``"error"``
            response raises ``RuntimeError`` instead.
            """
            assigned_worker = self._get_assigned_worker(i)

            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()

        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""

        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""

        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            """Shard the input partition ``data`` and write the shards to the comm buffer.

            Returns this run's ``run_id``. Raises ``RuntimeError`` if the run
            is already marked as transferred.
            """
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                # The coroutine is executed through ``sync`` on ``self._loop``
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id

        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""

        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Return the output partition ``partition_id`` for task ``key``.

            The partition must be assigned to this worker (see
            ``_ensure_output_worker``) and the run must already be marked as
            transferred, otherwise ``RuntimeError`` is raised. The disk buffer
            is flushed before the partition is produced by
            ``_get_output_partition``.
            """
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)

        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""

        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""

        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""

        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        """Return the ``"shuffle"`` plugin of the worker given by ``get_worker()``.

        Raises
        ------
        RuntimeError
            If ``get_worker()`` raises ``ValueError`` or if the worker has no
            plugin registered under ``"shuffle"``.
        """
        from distributed import get_worker

        try:
            current_worker = get_worker()
        except ValueError as exc:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from exc

        try:
            plugin = current_worker.plugins["shuffle"]
        except KeyError as exc:
            raise RuntimeError(
                f"The worker {current_worker.address} does not have a P2P shuffle plugin."
            ) from exc
        return plugin  # type: ignore
    
    
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        """Return the barrier key of ``shuffle_id``: ``_BARRIER_PREFIX`` followed by the id."""
        prefix = _BARRIER_PREFIX
        return prefix + shuffle_id
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        """Extract the shuffle id from a barrier key.

        Returns the part of ``key`` after ``_BARRIER_PREFIX`` as a ``ShuffleId``,
        or ``None`` when ``key`` is not a string starting with that prefix.
        """
        if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
            suffix = key[len(_BARRIER_PREFIX) :]
            return ShuffleId(suffix)
        return None
    
    
    class ShuffleType(Enum):
        """Enumeration with one member per shuffle type: dataframe shuffle and array rechunk."""

        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        """Immutable description of one run of a shuffle.

        Bundles the ``ShuffleSpec`` with the partition-to-worker mapping and
        the span id of the run.
        """

        # Not an ``__init__`` argument: every new instance takes the next value
        # of a single ``itertools.count(1)`` created when the class is defined
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        # Maps a partition id to a worker (presumably the worker address;
        # ``ShuffleSpec.create_new_run`` uses the values as participating workers)
        worker_for: dict[_T_partition_id, str]
        span_id: str | None

        @property
        def id(self) -> ShuffleId:
            """Id of the shuffle, taken from ``spec``."""
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        """Abstract, immutable specification of a shuffle.

        Concrete subclasses define the output partitions, how a worker is picked
        for a partition and how a run is created on a worker.
        """

        id: ShuffleId
        # NOTE(review): presumably selects the disk-backed buffer (matching the
        # ``disk`` argument of ``ShuffleRun.__init__``) -- confirm in subclasses
        disk: bool

        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""

        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""

        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            """Create the scheduler-side state of a new run of this shuffle.

            A new ``ShuffleRunSpec`` is built from this spec, ``worker_for`` and
            ``span_id``. The participating workers are initialised with the set
            of values of ``worker_for``.
            """
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )

        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        """Scheduler-side state of one shuffle run.

        ``eq=False`` keeps the default identity-based equality; the hash is
        derived from the run id (see ``__hash__``).
        """

        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        # ``None`` until the run is archived (see ``archived``); not an
        # ``__init__`` argument
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False

        @property
        def id(self) -> ShuffleId:
            """Id of the shuffle, taken from ``run_spec``."""
            return self.run_spec.id

        @property
        def run_id(self) -> int:
            """Id of the run, taken from ``run_spec``."""
            return self.run_spec.run_id

        @property
        def archived(self) -> bool:
            """Whether ``_archived_by`` has been set."""
            return self._archived_by is not None

        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"

        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P 01992e7a5ca3de072c0c2984a1548c6e failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x9-chunks9] (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 9f169a5772161f491fce47676e173277 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Register the shuffle handlers and this plugin on ``scheduler``.

            The statement order is kept as-is: the plugin is added to the
            scheduler under the name ``"shuffle"`` before the remaining
            bookkeeping containers are created.
            """
            self.scheduler = scheduler
            # Scheduler RPC operations served by this plugin
            shuffle_handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler.handlers.update(shuffle_handlers)
            # shuffle id -> worker address -> latest heartbeat payload
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

            The ``scheduler`` argument is not used; registration goes through
            ``self.scheduler``.
            """
            pickled_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None,
                pickled_plugin,
                name="shuffle",
                idempotent=False,
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the ids of all currently active shuffles."""
            return set(self.active_shuffles.keys())
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Run the barrier of shuffle ``id`` across its participating workers.

            If the caller reports inconsistent inputs the shuffle is restarted via
            ``self._restart_shuffle``. Otherwise ``shuffle_inputs_done`` is
            broadcast to every participating worker; workers whose result is an
            ``OSError`` are retried after a short sleep.

            Raises
            ------
            KeyError
                If ``id`` is not an active shuffle.
            ValueError
                If ``run_id`` does not match the active run.
            P2PIllegalStateError
                If a worker that still needs a retry is unknown to the scheduler
                while the shuffle is not archived.
            RuntimeError
                If a worker returns an unexpected (non-``OSError``) result, a
                worker left during the shuffle, or three broadcast rounds made
                no progress.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Collect the workers to retry: a ``None`` result means success,
                # an ``OSError`` is retried, anything else aborts the barrier.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    # The counter is cumulative: it is not reset when a later
                    # round does make progress.
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` if ``run_id`` is the active run.

            Returns ``{"status": "OK"}`` on success. A ``P2PConsistencyError``
            (stale or invalid ``run_id``) is converted into an error message
            instead of propagating.
            """
            try:
                shuffle = self.active_shuffles[id]
                active_run_id = shuffle.run_id
                if active_run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if active_run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge a worker's per-shuffle heartbeat payload into ``self.heartbeats``.

            Entries for shuffles that are not currently active are ignored.

            Parameters
            ----------
            ws
                State of the reporting worker; its ``address`` keys the payload.
            data
                Mapping of shuffle id to the heartbeat dict for that shuffle.
            """
            # Test membership on the dict itself instead of rebuilding a set of
            # all active ids for every entry in ``data``.
            active = self.active_shuffles
            for shuffle_id, d in data.items():
                if shuffle_id in active:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of active shuffle ``id`` for ``worker``.

            A missing shuffle (``KeyError``) is re-raised as a
            ``P2PConsistencyError``; every ``P2PConsistencyError`` is returned as
            an error message rather than raised.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    response = {"status": "OK", "run_spec": ToPickle(spec)}
                except KeyError as lookup_error:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from lookup_error
            except P2PConsistencyError as consistency_error:
                return error_message(consistency_error)
            return response
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Look up active shuffle ``id`` and record ``worker`` as a participant.

            Raises ``KeyError`` if ``id`` is not active and
            ``P2PConsistencyError`` if the scheduler does not know ``worker``.
            """
            known_workers = self.scheduler.workers
            if worker not in known_workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            shuffle_state = self.active_shuffles[id]
            shuffle_state.participating_workers.add(worker)
            return shuffle_state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            """Return the ``ShuffleSpec`` stored on the barrier task of ``shuffle_id``."""
            barrier_ts = self.scheduler.tasks[barrier_key(shuffle_id)]
            barrier_run_spec = barrier_ts.run_spec
            # The barrier task is expected to carry a P2PBarrierTask
            assert isinstance(barrier_run_spec, P2PBarrierTask)
            return barrier_run_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            """Create, register and return a new run for ``shuffle_id``.

            The new state becomes the active shuffle for ``shuffle_id``, is
            added to ``self._shuffles`` and records ``worker`` as a participant.
            """
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            shuffle_spec = self._retrieve_spec(shuffle_id)
            assignment = self._calculate_worker_for(shuffle_spec)
            self._ensure_output_tasks_are_non_rootish(shuffle_spec)
            span_id = self.scheduler.tasks[key].group.span_id
            new_state = shuffle_spec.create_new_run(
                worker_for=assignment,
                span_id=span_id,
            )
            self.active_shuffles[shuffle_id] = new_state
            self._shuffles[shuffle_id].add(new_state)
            new_state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return new_state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Register the shuffle handlers and this plugin on ``scheduler``.

            The statement order is kept as-is: the plugin is added to the
            scheduler under the name ``"shuffle"`` before the remaining
            bookkeeping containers are created.
            """
            self.scheduler = scheduler
            # Scheduler RPC operations served by this plugin
            shuffle_handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler.handlers.update(shuffle_handlers)
            # shuffle id -> worker address -> latest heartbeat payload
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

            The ``scheduler`` argument is not used; registration goes through
            ``self.scheduler``.
            """
            pickled_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None,
                pickled_plugin,
                name="shuffle",
                idempotent=False,
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the ids of all currently active shuffles."""
            return set(self.active_shuffles.keys())
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Run the barrier of shuffle ``id`` across its participating workers.

            If the caller reports inconsistent inputs the shuffle is restarted via
            ``self._restart_shuffle``. Otherwise ``shuffle_inputs_done`` is
            broadcast to every participating worker; workers whose result is an
            ``OSError`` are retried after a short sleep.

            Raises
            ------
            KeyError
                If ``id`` is not an active shuffle.
            ValueError
                If ``run_id`` does not match the active run.
            P2PIllegalStateError
                If a worker that still needs a retry is unknown to the scheduler
                while the shuffle is not archived.
            RuntimeError
                If a worker returns an unexpected (non-``OSError``) result, a
                worker left during the shuffle, or three broadcast rounds made
                no progress.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Collect the workers to retry: a ``None`` result means success,
                # an ``OSError`` is retried, anything else aborts the barrier.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    # The counter is cumulative: it is not reset when a later
                    # round does make progress.
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` if ``run_id`` is the active run.

            Returns ``{"status": "OK"}`` on success. A ``P2PConsistencyError``
            (stale or invalid ``run_id``) is converted into an error message
            instead of propagating.
            """
            try:
                shuffle = self.active_shuffles[id]
                active_run_id = shuffle.run_id
                if active_run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if active_run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge a worker's per-shuffle heartbeat payload into ``self.heartbeats``.

            Entries for shuffles that are not currently active are ignored.

            Parameters
            ----------
            ws
                State of the reporting worker; its ``address`` keys the payload.
            data
                Mapping of shuffle id to the heartbeat dict for that shuffle.
            """
            # Test membership on the dict itself instead of rebuilding a set of
            # all active ids for every entry in ``data``.
            active = self.active_shuffles
            for shuffle_id, d in data.items():
                if shuffle_id in active:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of active shuffle ``id`` for ``worker``.

            A missing shuffle (``KeyError``) is re-raised as a
            ``P2PConsistencyError``; every ``P2PConsistencyError`` is returned as
            an error message rather than raised.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    response = {"status": "OK", "run_spec": ToPickle(spec)}
                except KeyError as lookup_error:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from lookup_error
            except P2PConsistencyError as consistency_error:
                return error_message(consistency_error)
            return response
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: '9f169a5772161f491fce47676e173277'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    # Distinct static type for shuffle identifiers; a plain ``str`` at runtime.
    ShuffleId = NewType("ShuffleId", str)
    # Tuple of integer indices, one per dimension.
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    # Type of the identifier of an output partition
    _T_partition_id = TypeVar("_T_partition_id")
    # Type of the partition data itself
    _T_partition_type = TypeVar("_T_partition_type")
    # Generic return type, used by ``ShuffleRun.offload``
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        # Extends OKMessage with the run spec of a shuffle, either as the plain
        # object or wrapped in ToPickle for transfer.
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Set up the state and the disk/comm buffers of a shuffle run.

            Parameters
            ----------
            id, run_id, span_id
                Identifiers of the shuffle, of this run, and of its span.
            local_address
                Address of the worker this run lives on.
            directory
                Directory handed to the disk buffer (only used when ``disk``).
            executor
                Thread pool used by ``offload``.
            rpc
                Factory returning an RPC handle for a peer address.
            digest_metric
                Callback receiving ``(name, value)`` for captured metrics.
            scheduler
                RPC handle to the scheduler.
            memory_limiter_disk, memory_limiter_comms
                Limiters passed to the disk and comm buffers respectively.
            disk
                If true use a ``DiskShardsBuffer``, else a ``MemoryShardsBuffer``.
            loop
                Event loop stored on the run.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    # Comm buffer limits are read from the dask config
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            # Set to True by ``inputs_done``; ``add_partition`` refuses input afterwards
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            # Retry policy used by ``send``, read from the dask config
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            """Return a debug representation listing ids, address and state flags."""
            cls_name = self.__class__.__name__
            return (
                f"<{cls_name}: id={self.id!r}, run_id={self.run_id!r}, "
                f"local_address={self.local_address!r}, closed={self.closed!r}, "
                f"transferred={self.transferred!r}>"
            )
    
        def __str__(self) -> str:
            """Return ``ClassName<id[run_id]> on local_address``."""
            cls_name = self.__class__.__name__
            return f"{cls_name}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            """Hash the run by its ``run_id``."""
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Forward context_meter metrics to ``digest_metric`` while active.

            Metrics are reported under the key

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            **Note 1:** For metrics not logged by a background task
            (where='foreground') this duplicates the metric that is also
            recorded under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This is intentional: it gives a holistic view of the whole shuffle
            process.

            **Note 2:** Metrics are written to Worker.digests immediately.
            Storing them temporarily on the ShuffleRun would lose everything
            recorded between the last heartbeat and the deletion of the
            ShuffleRun object at the end of a run.
            """

            def callback(label: Hashable, value: float, unit: str) -> None:
                # Normalize the label to a tuple and strip a leading "p2p-"
                parts = label if isinstance(label, tuple) else (label,)
                head = parts[0]
                if isinstance(head, str) and head.startswith("p2p-"):
                    parts = (head[len("p2p-") :], *parts[1:])

                self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

            offloadable = "background" in where
            with context_meter.add_callback(callback, allow_offload=offloadable):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Report to the scheduler whether all ``run_ids`` match this run.

            Returns this run's ``run_id`` once the scheduler-side barrier call
            has completed.
            """
            self.raise_if_closed()
            mismatch = any(run_id != self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id,
                run_id=self.run_id,
                consistent=not mismatch,
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to the peer at ``address`` via ``shuffle_receive``."""
            self.raise_if_closed()
            peer = self.rpc(address)
            return await peer.shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address``, retrying per the configured policy.

            When the mean shard size is below 65536 the shards are pickled into
            one bytes blob before sending; ``receive`` unpickles it again.
            """
            # Don't send buffers individually over the tcp comms.
            # Instead, merge everything into an opaque bytes blob, send it all at once,
            # and unpickle it on the other side.
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
            shards_or_bytes: list | bytes = (
                pickle.dumps(shards) if _mean_shard_size(shards) < 65536 else shards
            )

            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)

            retry_policy = {
                "count": self.RETRY_COUNT,
                "delay_min": self.RETRY_DELAY_MIN,
                "delay_max": self.RETRY_DELAY_MAX,
            }
            return await retry(_send, **retry_policy)
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` in ``self.executor``, metered as "offload"."""
            self.raise_if_closed()
            with context_meter.meter("offload"):
                result = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
            return result
    
        def heartbeat(self) -> dict[str, Any]:
            """Return buffer heartbeats plus the start time of this run.

            The comm heartbeat is extended with a ``"read"`` entry holding
            ``self.total_recvd``.
            """
            comm_stats = self._comm_buffer.heartbeat()
            comm_stats["read"] = self.total_recvd
            disk_stats = self._disk_buffer.heartbeat()
            return {"disk": disk_stats, "comm": comm_stats, "start": self.start_time}
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Write ``data`` to the comm buffer; raises if the run is closed."""
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            """Write ``data`` to the disk buffer, keyed by the "_"-joined index."""
            self.raise_if_closed()
            keyed = {"_".join(map(str, index)): shards for index, shards in data.items()}
            await self._disk_buffer.write(keyed)
    
        def raise_if_closed(self) -> None:
            """Raise if this run is closed.

            Re-raises the stored exception when there is one, otherwise raises
            ``ShuffleClosedError``.
            """
            if not self.closed:
                return
            if self._exception:
                raise self._exception
            raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark the transfer as finished and flush the comm buffer.

            Any exception reported by the comm buffer is stored on the run and
            re-raised.
            """
            self.raise_if_closed()
            # From here on ``add_partition`` rejects further input
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                # Remember the failure so ``raise_if_closed`` can re-raise it later
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            """Flush the comm buffer; raises if the run is closed."""
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush the disk buffer; raises if the run is closed."""
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            """Close the comm and disk buffers of this run.

            The first caller performs the shutdown; concurrent or later callers
            wait until that shutdown has finished. Exceptions raised while
            closing a buffer propagate to the first caller.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            self.closed = True
            try:
                await self._comm_buffer.close()
                await self._disk_buffer.close()
            finally:
                # Always release waiters: if a buffer fails to close, callers
                # blocked on the event above would otherwise wait forever.
                self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            """Store ``exception`` on the run unless the run is already closed."""
            if self.closed:
                return
            self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            """Read the shards stored in the disk buffer under the "_"-joined ``id``."""
            self.raise_if_closed()
            buffer_key = "_".join(map(str, id))
            return self._disk_buffer.read(buffer_key)
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Hand received shards to ``_receive`` and report the outcome.

            Returns ``{"status": "OK"}`` on success; a ``P2PConsistencyError``
            is returned as an error message instead of being raised.
            """
            try:
                # Unpack opaque blob. See send()
                # NOTE(review): pickle.loads runs on data sent by a peer;
                # presumably peers are trusted cluster workers -- confirm.
                shards = (
                    cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                    if isinstance(data, bytes)
                    else data
                )
                await self._receive(shards)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Make sure task ``key`` runs on the worker assigned to partition ``i``.

            Returns normally when that worker is the local one. Otherwise the
            scheduler is asked to restrict the task to the assigned worker and
            ``Reschedule`` is raised; an error response raises ``RuntimeError``.
            """
            target = self._get_assigned_worker(i)
            if target == self.local_address:
                return

            response = await self.scheduler.shuffle_restrict_task(
                id=self.id, run_id=self.run_id, key=key, worker=target
            )
            status = response["status"]
            if status == "error":
                raise RuntimeError(response["message"])
            assert status == "OK"
            raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition

            Abstract hook. ``_ensure_output_worker`` compares the returned address
            with ``local_address`` to decide whether the task must be rescheduled.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run

            Abstract hook. ``receive`` awaits it with the (already unpickled) list of
            shards and turns a ``P2PConsistencyError`` raised here into an error
            message.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…orker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Attach the plugin to ``scheduler`` and set up its bookkeeping.

            Registers the ``shuffle_*`` handlers on ``scheduler.handlers`` and adds
            this object as the scheduler plugin named ``"shuffle"``.
            """
            self.scheduler = scheduler
            # Expose the plugin's entry points as scheduler handlers.
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            # Nested mapping filled by heartbeat():
            # heartbeats[shuffle_id][worker_address] -> dict of reported values
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            # NOTE(review): the plugin is added before _shuffles,
            # _archived_by_stimulus and _shift_counter exist -- confirm that
            # add_plugin does not invoke any hook that reads them.
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a new ``ShuffleWorkerPlugin`` under the name "shuffle".

            The plugin is serialized with ``dumps`` and passed to
            ``self.scheduler.register_worker_plugin`` with ``idempotent=False``.
            """
            serialized_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, serialized_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the ids of all currently active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Broadcast ``shuffle_inputs_done`` to all participating workers.

            Parameters
            ----------
            id
                Id of the active shuffle.
            run_id
                Run the caller believes to be current; must match the active run.
            consistent
                If false, the shuffle is restarted instead of running the barrier.

            Raises
            ------
            KeyError
                If no shuffle with ``id`` is active.
            ValueError
                If ``run_id`` does not match the active run.
            P2PIllegalStateError
                If a failing worker is unknown to the scheduler although the
                shuffle has not been archived.
            RuntimeError
                If a worker answers with a non-``OSError`` error, a failing worker
                left the cluster, or the broadcast repeatedly makes no progress.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Keep only the workers whose broadcast failed with an OSError; they
                # are retried. A None response means success; anything else aborts.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    # Give up after three rounds in which no worker recovered.
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge heartbeat data reported by worker ``ws`` into ``self.heartbeats``.

            ``data`` maps shuffle ids to dicts of reported values. Entries for
            shuffles that are not active are ignored.
            """
            # Test membership on the dict directly instead of rebuilding the set of
            # ids via shuffle_ids() for every reported shuffle.
            active = self.active_shuffles
            for shuffle_id, d in data.items():
                if shuffle_id in active:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of active shuffle ``id`` to ``worker``.

            On success the reply carries the run spec wrapped in ``ToPickle``. A
            ``KeyError`` from the lookup is turned into a ``P2PConsistencyError``,
            and any ``P2PConsistencyError`` is returned as an error message.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(spec)}
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as inconsistency:
                return error_message(inconsistency)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Look up active shuffle ``id`` and record ``worker`` as a participant.

            Raises ``P2PConsistencyError`` if the scheduler does not know ``worker``
            and ``KeyError`` if no shuffle with ``id`` is active.
            """
            known_workers = self.scheduler.workers
            if worker not in known_workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            shuffle = self.active_shuffles[id]
            shuffle.participating_workers.add(worker)
            return shuffle.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:40785', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(10, 10), dtype=float64, chunksize=(10, 2), chunktype=numpy.ndarray>
chunks = (None, 5)
ws = (<Worker 'tcp://127.0.0.1:39757', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:41009', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.12/site-packages/dask/dataframe/__init__.py'>

    @pytest.mark.parametrize(
        "x, chunks",
        [
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
            (da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
            (da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
        """
        pytest.importorskip("pandas")
        dd = pytest.importorskip("dask.dataframe")
        y = dd.from_array(x).values
        result = y.rechunk(chunks, method="p2p")
        expected = x.rechunk(chunks, method="p2p")
    
        assert_chunks_match(result.chunks, expected.chunks)
>       assert_eq(await c.compute(result), await c.compute(expected))

distributed/shuffle/tests/test_rechunk.py:757: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
    self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        """``OKMessage`` that additionally carries a shuffle run spec."""

        # Either the spec itself or the spec wrapped in ToPickle
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        """State and operations of one run of a P2P shuffle on a worker.

        A run is identified by ``id`` and ``run_id``. It owns a comm buffer that
        sends shards to other workers (``_comm_buffer``) and a disk or in-memory
        buffer (``_disk_buffer``). Subclasses provide the abstract hooks that
        shard input partitions, receive shards and build output partitions.
        """

        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop

        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float

        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Store the run's identity and collaborators and create its buffers.

            ``directory`` and ``memory_limiter_disk`` are only used when ``disk`` is
            true (a ``DiskShardsBuffer`` is created); otherwise a
            ``MemoryShardsBuffer`` is used. ``memory_limiter_comms`` limits the comm
            buffer. Message size limit, concurrency and retry settings are read from
            the ``distributed.p2p.comm`` dask config.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            # Set to True by inputs_done(); add_partition() refuses to run afterwards
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )

        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"

        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"

        def __hash__(self) -> int:
            # Hash is the run_id itself
            return self.run_id

        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This is by design so that one can have a holistic view of the whole shuffle
            process.

            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """

            def callback(label: Hashable, value: float, unit: str) -> None:
                # Normalize the label to a tuple and strip a leading "p2p-" from
                # its first element before building the metric name.
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)

                self.digest_metric(name, value)

            # Offloading is only allowed for the "background..." contexts
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield

        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Call the scheduler's shuffle barrier and return this run's ``run_id``.

            ``consistent`` is reported as true only if every entry of ``run_ids``
            equals this run's ``run_id``. Raises if the run is closed.
            """
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id

        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``shuffle_receive`` on the worker at ``address``.

            Raises if the run is closed.
            """
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )

        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address``, retrying according to the retry settings.

            Shards with a mean size below 65536 are pickled into a single bytes blob
            first; ``receive`` unpickles such blobs.
            """
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards

            # retry() is given a zero-argument callable producing a fresh coroutine
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)

            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )

        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` in ``self.executor`` and return its result.

            The call is metered under the "offload" label. Raises if the run is
            closed.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )

        def heartbeat(self) -> dict[str, Any]:
            """Return the buffers' heartbeats and the start time of this run.

            The comm heartbeat is extended with ``"read"`` set to ``total_recvd``.
            """
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }

        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Write ``data`` to the comm buffer; raises if the run is closed."""
            self.raise_if_closed()
            await self._comm_buffer.write(data)

        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            """Write ``data`` to the disk buffer; raises if the run is closed.

            Each index key is converted to a string by joining its elements with "_".
            """
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )

        def raise_if_closed(self) -> None:
            """Raise if this run is closed.

            Raises the stored ``_exception`` if there is one, otherwise
            ``ShuffleClosedError``.
            """
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")

        async def inputs_done(self) -> None:
            """Mark the transfer as finished and flush the comm buffer.

            An exception raised by the comm buffer's ``raise_on_exception`` is stored
            in ``_exception`` and re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise

        async def _flush_comm(self) -> None:
            """Flush the comm buffer; raises if the run is closed."""
            self.raise_if_closed()
            await self._comm_buffer.flush()

        async def flush_receive(self) -> None:
            """Flush the disk buffer; raises if the run is closed."""
            self.raise_if_closed()
            await self._disk_buffer.flush()

        async def close(self) -> None:
            """Close both buffers once; later callers wait for ``_closed_event``."""
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()

        def fail(self, exception: Exception) -> None:
            """Store ``exception`` as this run's failure unless it is already closed."""
            if not self.closed:
                self._exception = exception

        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            """Read from the disk buffer under the "_"-joined form of ``id``."""
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))

        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Pass incoming shards to ``_receive`` and report the outcome.

            Returns ``{"status": "OK"}`` on success; a ``P2PConsistencyError`` is
            returned as an error message instead of being raised.
            """
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    # NOTE(review): pickle.loads runs on bytes received from a peer;
                    # this is only safe if every peer is trusted.
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)

        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Make sure output partition ``i`` is handled on its assigned worker.

            If the assigned worker is not this worker, the scheduler is asked to
            restrict task ``key`` to it and ``Reschedule`` is raised. An error reply
            from the scheduler is raised as ``RuntimeError``.
            """
            assigned_worker = self._get_assigned_worker(i)

            if assigned_worker != self.local_address:
                result = await self.scheduler.shuffle_restrict_task(
                    id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
                )
                if result["status"] == "error":
                    raise RuntimeError(result["message"])
                assert result["status"] == "OK"
                raise Reschedule()

        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition"""

        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run"""

        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            """Shard an input partition and write the shards to the comm buffer.

            Returns this run's ``run_id``. Raises if the run is closed and
            ``RuntimeError`` if ``inputs_done`` has already marked the transfer as
            finished.
            """
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                # Run the coroutine on the run's event loop via sync()
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id

        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers"""

        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Return output partition ``partition_id`` of this run.

            First ensures the task runs on the assigned worker (which may raise
            ``Reschedule``), then flushes the disk buffer and delegates to
            ``_get_output_partition``. Raises ``RuntimeError`` if the transfer has
            not been marked as finished yet.
            """
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)

        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run"""

        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk"""

        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards"""

        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling"""
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        """Return the "shuffle" plugin of the worker this code is running on.

        Raises ``RuntimeError`` if ``get_worker`` fails with ``ValueError`` or if
        the worker has no plugin registered under ``"shuffle"``.
        """
        from distributed import get_worker

        try:
            worker = get_worker()
        except ValueError as not_on_worker:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from not_on_worker

        try:
            plugin = worker.plugins["shuffle"]
        except KeyError as missing_plugin:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from missing_plugin
        return plugin  # type: ignore
    
    
    # Prefix shared by barrier_key() and id_from_key()
    _BARRIER_PREFIX = "shuffle-barrier-"


    def barrier_key(shuffle_id: ShuffleId) -> str:
        """Return the barrier key of ``shuffle_id``: the barrier prefix plus the id."""
        key = _BARRIER_PREFIX + shuffle_id
        return key
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        """Extract the shuffle id from a barrier key.

        Returns ``None`` unless ``key`` is a string starting with the barrier
        prefix; otherwise returns the remainder of the key as a ``ShuffleId``.
        """
        if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
            return ShuffleId(key[len(_BARRIER_PREFIX) :])
        return None
    
    
    class ShuffleType(Enum):
        """Kinds of P2P shuffle, with a descriptive string as each member's value."""

        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        """Immutable description of one run of a shuffle."""

        # Not an __init__ argument: each new instance takes the next value of a
        # single counter that starts at 1 and is created once with the class.
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        # Maps a partition id to a worker address (str)
        worker_for: dict[_T_partition_id, str]
        span_id: str | None

        @property
        def id(self) -> ShuffleId:
            """Id of the shuffle, taken from ``spec``."""
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        """Immutable, abstract specification of a shuffle.

        Subclasses define the output partitions, how a worker is picked for a
        partition and how a run is created on a worker.
        """

        id: ShuffleId
        # Presumably selects the disk-backed buffer in the run (see the ``disk``
        # argument of ShuffleRun.__init__) -- confirm in the subclasses.
        disk: bool

        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""

        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""

        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            """Create the scheduler-side state of a new run of this shuffle.

            The new state wraps a fresh ``ShuffleRunSpec`` and starts with the
            workers named in ``worker_for`` as participating workers.
            """
            return SchedulerShuffleState(
                run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id),
                participating_workers=set(worker_for.values()),
            )

        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        """Scheduler-side state of a single shuffle run.

        ``eq=False`` keeps the default identity-based equality; the hash is
        derived from ``run_id``.
        """

        run_spec: ShuffleRunSpec
        # Addresses of the workers taking part in this run
        participating_workers: set[str]
        # None until the run is archived; see ``archived``
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False

        @property
        def id(self) -> ShuffleId:
            """Id of the shuffle this run belongs to."""
            return self.run_spec.id

        @property
        def run_id(self) -> int:
            """Id of this run, taken from ``run_spec``."""
            return self.run_spec.run_id

        @property
        def archived(self) -> bool:
            """Whether ``_archived_by`` has been set."""
            return self._archived_by is not None

        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"

        def __hash__(self) -> int:
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P 9f169a5772161f491fce47676e173277 failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x10-chunks10] (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 9f169a5772161f491fce47676e173277 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

            The ``scheduler`` argument is not used; registration goes through
            ``self.scheduler``.
            """
            plugin = ShuffleWorkerPlugin()
            pickled_plugin = dumps(plugin)
            await self.scheduler.register_worker_plugin(
                None, pickled_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of the active shuffle ``id`` as a message.

            On success the reply is ``{"status": "OK", "run_spec": ...}`` with
            the spec wrapped in ``ToPickle``. A ``KeyError`` from the lookup is
            converted into a ``P2PConsistencyError``, and any
            ``P2PConsistencyError`` is returned as an error message.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    reply = {"status": "OK", "run_spec": ToPickle(spec)}
                    return reply
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as err:
                return error_message(err)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            """Return the ``ShuffleSpec`` carried by the barrier task of ``shuffle_id``.

            The barrier task is looked up in ``self.scheduler.tasks`` by its
            barrier key; its run spec must be a ``P2PBarrierTask``.
            """
            key = barrier_key(shuffle_id)
            barrier_task_state = self.scheduler.tasks[key]
            barrier_task_spec = barrier_task_state.run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            """Create a new run for ``shuffle_id`` and make it the active one.

            ``key`` is the task that triggered the creation and ``worker`` the
            worker it runs on; the worker is recorded as a participant of the
            new run. Returns the run spec of the newly created state.
            """
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            # Validation helpers; judging by their names they raise on failure
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            # The span id is taken from the task group of the triggering task
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            # Replaces any previously active state for this id; every state
            # created for the id is also remembered in ``_shuffles``.
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.

            The ``scheduler`` argument is not used; registration goes through
            ``self.scheduler``.
            """
            plugin = ShuffleWorkerPlugin()
            pickled_plugin = dumps(plugin)
            await self.scheduler.register_worker_plugin(
                None, pickled_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of the active shuffle ``id`` as a message.

            On success the reply is ``{"status": "OK", "run_spec": ...}`` with
            the spec wrapped in ``ToPickle``. A ``KeyError`` from the lookup is
            converted into a ``P2PConsistencyError``, and any
            ``P2PConsistencyError`` is returned as an error message.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    reply = {"status": "OK", "run_spec": ToPickle(spec)}
                    return reply
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as err:
                return error_message(err)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: '9f169a5772161f491fce47676e173277'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        """``OKMessage`` extended with the run spec of a shuffle."""

        # The run spec itself, or the same spec wrapped in ``ToPickle``
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Store the run's identity and collaborators and create its buffers.

            Parameters
            ----------
            id, run_id, span_id, local_address
                Stored unchanged on the instance.
            directory
                Passed to the ``DiskShardsBuffer``; only used when ``disk`` is true.
            executor, rpc, digest_metric, scheduler
                Stored unchanged on the instance.
            memory_limiter_disk
                Limiter handed to the disk buffer (only when ``disk`` is true).
            memory_limiter_comms
                Limiter handed to the comm buffer.
            disk
                Selects a ``DiskShardsBuffer`` (true) or a ``MemoryShardsBuffer``
                (false) as ``_disk_buffer``.
            loop
                Stored as ``_loop``.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    # Message size and concurrency limits come from the dask config
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            # Retry settings used by ``send``; the delays are parsed with
            # seconds as the default unit.
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Forward context_meter metrics to ``digest_metric`` while the context is open.

            Metrics are recorded as

                {('p2p', <span id>, 'foreground|background...', label, unit): value}

            A leading ``"p2p-"`` on the first label element is stripped before
            the name is built.

            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under

                {('execute', <span id>, <task prefix>, label, unit): value}

            This is by design so that one can have a holistic view of the whole shuffle
            process.

            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
            prefix = "p2p-"

            def callback(label: Hashable, value: float, unit: str) -> None:
                # Normalise the label to a tuple so it can be spliced into the name
                parts = label if isinstance(label, tuple) else (label,)
                head = parts[0]
                if isinstance(head, str) and head.startswith(prefix):
                    parts = (head[len(prefix) :], *parts[1:])
                self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

            # Offloading is only allowed for the background contexts
            offload_ok = "background" in where
            with context_meter.add_callback(callback, allow_offload=offload_ok):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to the worker at ``address`` via ``shuffle_receive``.

            The payload is wrapped with ``to_serialize`` and tagged with this
            run's shuffle id and run id.
            """
            self.raise_if_closed()
            peer = self.rpc(address)
            payload = to_serialize(shards)
            return await peer.shuffle_receive(
                data=payload,
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address``, retrying according to the retry settings.

            Shards with a mean size below 65536 are pickled into one bytes blob
            first; larger ones are sent as they are.
            """
            # Don't send small buffers individually over the tcp comms.
            # Instead, merge everything into an opaque bytes blob, send it all at once,
            # and unpickle it on the other side.
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
            merge_into_blob = _mean_shard_size(shards) < 65536
            shards_or_bytes: list | bytes = (
                pickle.dumps(shards) if merge_into_blob else shards
            )

            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)

            result = await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
            return result
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.

            The call is metered under the label ``"offload"``.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                outcome = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
            return outcome
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Make sure output partition ``i`` is being handled on its assigned worker.
            Returns normally when the assigned worker is ``local_address``. Otherwise
            asks the scheduler (``shuffle_restrict_task``) to restrict ``key`` to the
            assigned worker and raises ``Reschedule``; an ``"error"`` reply raises
            ``RuntimeError`` with the reply's message.
            """
            target = self._get_assigned_worker(i)
            if target == self.local_address:
                return
    
            response = await self.scheduler.shuffle_restrict_task(
                id=self.id, run_id=self.run_id, key=key, worker=target
            )
            status = response["status"]
            if status == "error":
                raise RuntimeError(response["message"])
            assert status == "OK"
            raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Return the address of the worker assigned to output partition ``i``.
            Abstract hook with no implementation here; ``_ensure_output_worker``
            compares the result with ``local_address``.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run.
            Abstract hook with no implementation here; ``receive`` awaits it with the
            (already unpickled) list of ``(partition id, shard)`` tuples.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…/_worker_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Attach the plugin to ``scheduler``.
            Adds the four ``shuffle_*`` entries to ``scheduler.handlers``, registers
            this object as the plugin named ``"shuffle"`` and initialises the
            plugin's bookkeeping containers.
            """
            self.scheduler = scheduler
            shuffle_handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler.handlers.update(shuffle_handlers)
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            # Registration order is kept exactly as before: the plugin is added
            # to the scheduler before the remaining containers are created.
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a pickled ``ShuffleWorkerPlugin`` under the name ``"shuffle"``.
            The registration goes through ``self.scheduler``; the ``scheduler``
            argument is not used.
            """
            serialized_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, serialized_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the ids of all currently active shuffles."""
            active_ids: set[ShuffleId] = set()
            active_ids.update(self.active_shuffles)
            return active_ids
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Broadcast ``shuffle_inputs_done`` to all workers participating in a shuffle.
            Parameters
            ----------
            id
                Id of the shuffle, looked up in ``active_shuffles``.
            run_id
                Run id the caller believes is current; must equal the active run's.
            consistent
                If false, the shuffle is restarted instead of broadcasting.
            Raises
            ------
            ValueError
                ``run_id`` does not match the active shuffle.
            RuntimeError
                A worker answered with something other than ``None`` or an
                ``OSError``, a worker that still needs a retry is unknown to the
                scheduler, or three retry rounds made no progress.
            P2PIllegalStateError
                A worker is unknown to the scheduler but the shuffle is not archived.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Collect the workers to retry: a None reply is a success, an
                # OSError reply is retried, anything else aborts the barrier.
                # (The list was previously rebuilt a second time right after this
                # loop; that produced the same list and has been dropped.)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    # no_progress counts every round in which the number of
                    # failing workers did not shrink; it is never reset.
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` if ``run_id`` matches the active run.
            Returns ``{"status": "OK"}`` on success. A run id that is older or newer
            than the active shuffle's raises ``P2PConsistencyError`` internally,
            which is returned as an ``error_message``.
            """
            try:
                shuffle = self.active_shuffles[id]
                current_run_id = shuffle.run_id
                if current_run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if current_run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                task_state = self.scheduler.tasks[key]
                self._set_restriction(task_state, worker)
            except P2PConsistencyError as err:
                return error_message(err)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge the per-shuffle heartbeat payloads sent by worker ``ws``.
            ``data`` maps shuffle ids to update dicts. Entries for shuffles that are
            not currently active are ignored; the rest are merged into
            ``self.heartbeats[shuffle_id][ws.address]``.
            """
            # Build the set of active ids once instead of once per entry.
            active_ids = self.shuffle_ids()
            for shuffle_id, d in data.items():
                if shuffle_id in active_ids:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of active shuffle ``id`` wrapped for pickling.
            A ``KeyError`` raised while building the reply is reported as "No
            active shuffle"; every ``P2PConsistencyError`` is returned as an
            ``error_message`` instead of being raised.
            """
            try:
                try:
                    spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(spec)}
                except KeyError as missing:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from missing
            except P2PConsistencyError as err:
                return error_message(err)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Add ``worker`` to the participants of shuffle ``id`` and return its run spec.
            Raises ``P2PConsistencyError`` if ``worker`` is not in
            ``scheduler.workers``; a lookup failure on ``active_shuffles`` is left
            to the caller.
            """
            known_workers = self.scheduler.workers
            if worker not in known_workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            shuffle_state = self.active_shuffles[id]
            shuffle_state.participating_workers.add(worker)
            return shuffle_state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:33929', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(10, 10), dtype=float64, chunksize=(10, 2), chunktype=numpy.ndarray>
chunks = {1: 5}
ws = (<Worker 'tcp://127.0.0.1:35765', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:34181', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.12/site-packages/dask/dataframe/__init__.py'>

    @pytest.mark.parametrize(
        "x, chunks",
        [
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
            (da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
            (da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
        """
        pytest.importorskip("pandas")
        dd = pytest.importorskip("dask.dataframe")
        y = dd.from_array(x).values
        result = y.rechunk(chunks, method="p2p")
        expected = x.rechunk(chunks, method="p2p")
    
        assert_chunks_match(result.chunks, expected.chunks)
>       assert_eq(await c.compute(result), await c.compute(expected))

distributed/shuffle/tests/test_rechunk.py:757: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
    self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    # Distinct static type for shuffle identifiers; a plain ``str`` at runtime.
    ShuffleId = NewType("ShuffleId", str)
    # Tuple-of-ints index; ``_read_from_disk`` / ``_write_to_disk`` join its parts with "_".
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    # Type variables used by the generic shuffle classes below.
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        # OK reply that additionally carries a run spec, either bare or wrapped
        # in ToPickle (the scheduler plugin's ``get`` sends the wrapped form).
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Store the run's identity and collaborators and create its buffers.
            ``disk`` selects the receive-side buffer: a ``DiskShardsBuffer`` rooted
            at ``directory`` and limited by ``memory_limiter_disk`` when true, a
            ``MemoryShardsBuffer`` otherwise. The comm buffer is always a
            ``CommShardsBuffer`` sized from the ``distributed.p2p.comm.*`` config
            and limited by ``memory_limiter_comms``. Retry settings are read from
            ``distributed.p2p.comm.retry.*``.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            # Run state: flipped/updated by inputs_done(), fail() and close().
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            # Retry policy used by send().
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            """Return a debug representation listing id, run id, address and state flags."""
            cls_name = self.__class__.__name__
            return (
                f"<{cls_name}: id={self.id!r}, run_id={self.run_id!r}, "
                f"local_address={self.local_address!r}, closed={self.closed!r}, "
                f"transferred={self.transferred!r}>"
            )
    
        def __str__(self) -> str:
            """Return ``ClassName<id[run_id]> on local_address``."""
            cls_name = self.__class__.__name__
            return f"{cls_name}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            """Hash a run by its ``run_id``."""
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            and forward each of them to ``digest_metric`` for the duration of the
            ``with`` block. A leading ``"p2p-"`` on the first label element is removed.
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
            prefix = "p2p-"
    
            def _record_metric(label: Hashable, value: float, unit: str) -> None:
                parts = label if isinstance(label, tuple) else (label,)
                head = parts[0]
                if isinstance(head, str) and head.startswith(prefix):
                    parts = (head[len(prefix) :], *parts[1:])
                self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)
    
            offloadable = "background" in where
            with context_meter.add_callback(_record_metric, allow_offload=offloadable):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Report to the scheduler whether all ``run_ids`` equal this run's id.
            Calls the scheduler's ``shuffle_barrier`` with the ``consistent`` flag
            and returns this run's ``run_id``.
            """
            self.raise_if_closed()
            own_run_id = self.run_id
            consistent = not any(other != own_run_id for other in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to the worker at ``address`` via its ``shuffle_receive`` rpc."""
            self.raise_if_closed()
            peer = self.rpc(address)
            return await peer.shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address``, retrying according to the RETRY_* settings.
            When the mean shard size is below 65536 the whole list is pickled into a
            single bytes blob first; ``receive()`` unpickles it on the other side.
            """
            # Don't send buffers individually over the tcp comms.
            # Instead, merge everything into an opaque bytes blob, send it all at once,
            # and unpickle it on the other side.
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
            small_shards = _mean_shard_size(shards) < 65536
            shards_or_bytes: list | bytes = pickle.dumps(shards) if small_shards else shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                # A fresh coroutine is created for every retry attempt.
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor`` and return its result.
            The call is metered under the ``"offload"`` label.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                result = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
            return result
    
        def heartbeat(self) -> dict[str, Any]:
            """Return the heartbeat payload of this run.
            The result has the keys ``"disk"`` and ``"comm"`` (the buffers' own
            heartbeats, the latter extended with ``"read"`` = ``total_recvd``) and
            ``"start"`` (``start_time``).
            """
            comm_stats = self._comm_buffer.heartbeat()
            comm_stats["read"] = self.total_recvd
            disk_stats = self._disk_buffer.heartbeat()
            return dict(disk=disk_stats, comm=comm_stats, start=self.start_time)
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Write ``data`` to the comm shards buffer after checking the run is not closed."""
            self.raise_if_closed()
            comm_buffer = self._comm_buffer
            await comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            """Write ``data`` to the disk buffer, keyed by the "_"-joined index parts."""
            self.raise_if_closed()
            named_shards = {
                "_".join(map(str, index)): shard for index, shard in data.items()
            }
            await self._disk_buffer.write(named_shards)
    
        def raise_if_closed(self) -> None:
            """Raise if this run has been closed; do nothing otherwise.
            A stored ``_exception`` takes precedence; without one a
            ``ShuffleClosedError`` is raised.
            """
            if not self.closed:
                return
            if self._exception:
                raise self._exception
            raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark the transfer as finished and flush the comm buffer.
            Sets ``transferred``, flushes, then re-raises any exception reported by
            the comm buffer after storing it in ``_exception``.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as comm_error:
                # Keep the failure so raise_if_closed() can re-raise it later.
                self._exception = comm_error
                raise
    
        async def _flush_comm(self) -> None:
            """Flush the comm shards buffer after checking the run is not closed."""
            self.raise_if_closed()
            comm_buffer = self._comm_buffer
            await comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush the disk shards buffer after checking the run is not closed."""
            self.raise_if_closed()
            disk_buffer = self._disk_buffer
            await disk_buffer.flush()
    
        async def close(self) -> None:
            """Close both shard buffers and mark this run as closed.
            If the run is already flagged as closed, wait for ``_closed_event``
            (set by the call that performs the close) and return.
            The disk buffer is closed and ``_closed_event`` is set even if closing
            the comm buffer raises, so a failed close cannot leave the disk buffer
            open or leave concurrent callers waiting on the event forever. The
            exception still propagates to the caller.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            try:
                try:
                    await self._comm_buffer.close()
                finally:
                    await self._disk_buffer.close()
            finally:
                # Always release waiters, including on failure.
                self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            """Store ``exception`` in ``_exception`` unless the run is already closed."""
            if self.closed:
                return
            self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            """Read from the disk buffer under the name ``"_".join`` of the index parts."""
            self.raise_if_closed()
            buffer_name = "_".join(map(str, id))
            return self._disk_buffer.read(buffer_name)
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Pass received shards on to ``_receive`` and report the outcome.
            ``data`` is either the shard list itself or a pickled blob of it (see
            ``send()``). Returns ``{"status": "OK"}`` on success; a
            ``P2PConsistencyError`` is converted with ``error_message``. Any other
            exception propagates.
            """
            try:
                # Unpack opaque blob. See send()
                # NOTE(review): pickle.loads runs on bytes received over rpc;
                # presumably peers are trusted workers - confirm.
                shards = (
                    cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                    if isinstance(data, bytes)
                    else data
                )
                await self._receive(shards)
            except P2PConsistencyError as err:
                return error_message(err)
            return {"status": "OK"}
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Make sure output partition ``i`` is being handled on its assigned worker.
            Returns normally when the assigned worker is ``local_address``. Otherwise
            asks the scheduler (``shuffle_restrict_task``) to restrict ``key`` to the
            assigned worker and raises ``Reschedule``; an ``"error"`` reply raises
            ``RuntimeError`` with the reply's message.
            """
            target = self._get_assigned_worker(i)
            if target == self.local_address:
                return
    
            response = await self.scheduler.shuffle_restrict_task(
                id=self.id, run_id=self.run_id, key=key, worker=target
            )
            status = response["status"]
            if status == "error":
                raise RuntimeError(response["message"])
            assert status == "OK"
            raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Return the address of the worker assigned to output partition ``i``.
            Abstract hook with no implementation here; ``_ensure_output_worker``
            compares the result with ``local_address``.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run.
            Abstract hook with no implementation here; ``receive`` awaits it with the
            (already unpickled) list of ``(partition id, shard)`` tuples.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            """Shard one input partition and hand the shards to the comm buffer.
            Validates ``data``, splits it with ``_shard_partition`` (metered as
            ``p2p-shard-partition-*``) and writes the shards via ``_write_to_comm``.
            Returns this run's ``run_id``. Raises ``RuntimeError`` once
            ``transferred`` is set, and whatever ``raise_if_closed`` raises when
            the run is closed.
            """
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and in the "p2p" contexts
            self.validate_data(data)
            with self._capture_metrics("foreground"):
                with (
                    context_meter.meter("p2p-shard-partition-noncpu"),
                    context_meter.meter("p2p-shard-partition-cpu", func=thread_time),
                ):
                    shards = self._shard_partition(data, partition_id)
                # The write is a coroutine; sync() runs it on self._loop
                # (presumably blocking this thread until it finishes - confirm).
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers.
            Abstract hook with no implementation here; ``add_partition`` passes the
            returned mapping to ``_write_to_comm``.
            """
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Return output partition ``partition_id`` for task ``key``.
            First ensures the task runs on the assigned worker (which may raise
            ``Reschedule``), then requires the transfer to be finished, flushes the
            receive buffer and delegates to ``_get_output_partition`` (metered as
            ``p2p-get-output-*``). Raises ``RuntimeError`` if ``transferred`` is
            not set yet.
            """
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            with (
                # Log metrics both in the "execute" and in the "p2p" contexts
                self._capture_metrics("foreground"),
                context_meter.meter("p2p-get-output-noncpu"),
                context_meter.meter("p2p-get-output-cpu", func=thread_time),
            ):
                return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run.
            Abstract hook with no implementation here; called by
            ``get_output_partition`` after the receive buffer has been flushed.
            """
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk.
            Abstract hook with no implementation here; passed as ``read=`` to the
            ``DiskShardsBuffer`` in ``__init__``. The meaning of the returned
            ``int`` is not visible here (presumably a byte count - confirm).
            """
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards.
            Abstract hook with no implementation here; passed as ``deserialize=``
            to the ``MemoryShardsBuffer`` in ``__init__``.
            """
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling.
            The base implementation accepts everything; ``add_partition`` calls it
            before sharding.
            """
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        """Return the ``"shuffle"`` plugin of the worker this code is running on.
        Raises ``RuntimeError`` when ``get_worker()`` fails with ``ValueError`` or
        when the worker has no plugin registered under ``"shuffle"``.
        """
        from distributed import get_worker
    
        try:
            worker = get_worker()
        except ValueError as not_on_worker:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from not_on_worker
    
        try:
            plugin = worker.plugins["shuffle"]
        except KeyError as missing_plugin:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from missing_plugin
        return plugin  # type: ignore
    
    
    # Key prefix shared by barrier_key() and id_from_key().
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        """Return the barrier key for ``shuffle_id``: the barrier prefix plus the id."""
        return _BARRIER_PREFIX + shuffle_id
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        """Extract the shuffle id from a barrier key.
        Returns ``None`` when ``key`` is not a string starting with the barrier
        prefix; otherwise the remainder of the key as a ``ShuffleId``.
        """
        if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
            return ShuffleId(key[len(_BARRIER_PREFIX) :])
        return None
    
    
    class ShuffleType(Enum):
        """Kinds of P2P operation: dataframe shuffle and array rechunk."""
    
        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        """Immutable description of one run of a shuffle."""
    
        # Not an __init__ argument: every new instance takes the next value of a
        # single counter (starting at 1) created when the class is defined.
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        # The shuffle this run belongs to.
        spec: ShuffleSpec
        # Maps each partition id to a worker address.
        worker_for: dict[_T_partition_id, str]
        span_id: str | None
    
        @property
        def id(self) -> ShuffleId:
            """Id of the underlying ``spec``."""
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        """Abstract, immutable specification of a shuffle.
        Holds the shuffle ``id`` and the ``disk`` flag; subclasses supply the
        partitioning and the worker-side run factory.
        """
    
        id: ShuffleId
        disk: bool
    
        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""
    
        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""
    
        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            """Create the scheduler-side state for a fresh run of this shuffle.
            The initial participants are the workers named in ``worker_for``.
            """
            new_run_spec = ShuffleRunSpec(
                spec=self, worker_for=worker_for, span_id=span_id
            )
            participants = set(worker_for.values())
            return SchedulerShuffleState(
                run_spec=new_run_spec,
                participating_workers=participants,
            )
    
        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        """Scheduler-side state of one shuffle run.
        Pairs the run spec with the mutable set of participating workers.
        Instances hash by ``run_id``; ``eq=False`` keeps identity comparison.
        """
    
        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False
    
        @property
        def id(self) -> ShuffleId:
            """Shuffle id taken from the run spec."""
            spec = self.run_spec
            return spec.id
    
        @property
        def run_id(self) -> int:
            """Run id taken from the run spec."""
            spec = self.run_spec
            return spec.run_id
    
        @property
        def archived(self) -> bool:
            """Whether ``_archived_by`` has been set."""
            if self._archived_by is None:
                return False
            return True
    
        def __str__(self) -> str:
            cls_name = self.__class__.__name__
            return f"{cls_name}<{self.id}[{self.run_id}]>"
    
        def __hash__(self) -> int:
            run_id = self.run_id
            return hash(run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P 9f169a5772161f491fce47676e173277 failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError

Check warning on line 0 in distributed.shuffle.tests.test_rechunk

See this annotation in the file changed.

@github-actions github-actions / Unit Test Results

10 out of 11 runs failed: test_rechunk_with_fully_unknown_dimension[x11-chunks11] (distributed.shuffle.tests.test_rechunk)

artifacts/macos-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_expr-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.10-no_queue-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-3.12-default-notci1/pytest.xml [took 0s]
artifacts/ubuntu-latest-mindeps-pandas-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.10-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.11-default-notci1/pytest.xml [took 0s]
artifacts/windows-latest-3.12-default-notci1/pytest.xml [took 0s]
Raw output
RuntimeError: P2P 9f169a5772161f491fce47676e173277 failed during transfer phase
from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
            state = self.active_shuffles[id]
            state.participating_workers.add(worker)
            return state.run_spec
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
            assert isinstance(barrier_task_spec, P2PBarrierTask)
            return barrier_task_spec.spec
    
        def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec:
            # FIXME: The current implementation relies on the barrier task to be
            # known by its name. If the name has been mangled, we cannot guarantee
            # that the shuffle works as intended and should fail instead.
            self._raise_if_barrier_unknown(shuffle_id)
            self._raise_if_task_not_processing(key)
            spec = self._retrieve_spec(shuffle_id)
            worker_for = self._calculate_worker_for(spec)
            self._ensure_output_tasks_are_non_rootish(spec)
            state = spec.create_new_run(
                worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id
            )
            self.active_shuffles[shuffle_id] = state
            self._shuffles[shuffle_id].add(state)
            state.participating_workers.add(worker)
            logger.warning(
                "Shuffle %s initialized by task %r executed on worker %s",
                shuffle_id,
                key,
                worker,
            )
            return state.run_spec
    
        def get_or_create(
            self,
            shuffle_id: ShuffleId,
            key: Key,
            worker: str,
        ) -> RunSpecMessage | ErrorMessage:
            try:
>               run_spec = self._get(shuffle_id, worker)

distributed/shuffle/_scheduler_plugin.py:229: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            self.scheduler = scheduler
            self.scheduler.handlers.update(
                {
                    "shuffle_barrier": self.barrier,
                    "shuffle_get": self.get,
                    "shuffle_get_or_create": self.get_or_create,
                    "shuffle_restrict_task": self.restrict_task,
                }
            )
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            worker_plugin = ShuffleWorkerPlugin()
            await self.scheduler.register_worker_plugin(
                None, dumps(worker_plugin), name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            return set(self.active_shuffles)
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                workers = [w for w, r in res.items() if r is not None]
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            try:
                shuffle = self.active_shuffles[id]
                if shuffle.run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                elif shuffle.run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                ts = self.scheduler.tasks[key]
                self._set_restriction(ts, worker)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            for shuffle_id, d in data.items():
                if shuffle_id in self.shuffle_ids():
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            try:
                try:
                    run_spec = self._get(id, worker)
                    return {"status": "OK", "run_spec": ToPickle(run_spec)}
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            if worker not in self.scheduler.workers:
                # This should never happen
                raise P2PConsistencyError(
                    f"Scheduler is unaware of this worker {worker!r}"
                )  # pragma: nocover
>           state = self.active_shuffles[id]
E           KeyError: '9f169a5772161f491fce47676e173277'

distributed/shuffle/_scheduler_plugin.py:190: KeyError

During handling of the above exception, another exception occurred:

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False
    
            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
    
                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )
    
            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)
    
            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop
    
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"
    
        def __str__(self) -> str:
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"
    
        def __hash__(self) -> int:
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Capture context_meter metrics as
    
                {('p2p', <span id>, 'foreground|background...', label, unit): value}
    
            **Note 1:** When the metric is not logged by a background task
            (where='foreground'), this produces a duplicated metric under
    
                {('execute', <span id>, <task prefix>, label, unit): value}
    
            This is by design so that one can have a holistic view of the whole shuffle
            process.
    
            **Note 2:** We're immediately writing to Worker.digests.
            We don't temporarily store metrics under ShuffleRun as we would lose those
            recorded between the heartbeat and when the ShuffleRun object is deleted at the
            end of a run.
            """
    
            def callback(label: Hashable, value: float, unit: str) -> None:
                if not isinstance(label, tuple):
                    label = (label,)
                if isinstance(label[0], str) and label[0].startswith("p2p-"):
                    label = (label[0][len("p2p-") :], *label[1:])
                name = ("p2p", self.span_id, where, *label, unit)
    
                self.digest_metric(name, value)
    
            with context_meter.add_callback(callback, allow_offload="background" in where):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            self.raise_if_closed()
            consistent = all(run_id == self.run_id for run_id in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=consistent
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            self.raise_if_closed()
            return await self.rpc(address).shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            if _mean_shard_size(shards) < 65536:
                # Don't send buffers individually over the tcp comms.
                # Instead, merge everything into an opaque bytes blob, send it all at once,
                # and unpickle it on the other side.
                # Performance tests informing the size threshold:
                # https://github.com/dask/distributed/pull/8318
                shards_or_bytes: list | bytes = pickle.dumps(shards)
            else:
                shards_or_bytes = shards
    
            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)
    
            return await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            self.raise_if_closed()
            with context_meter.meter("offload"):
                return await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
    
        def heartbeat(self) -> dict[str, Any]:
            comm_heartbeat = self._comm_buffer.heartbeat()
            comm_heartbeat["read"] = self.total_recvd
            return {
                "disk": self._disk_buffer.heartbeat(),
                "comm": comm_heartbeat,
                "start": self.start_time,
            }
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            self.raise_if_closed()
            await self._disk_buffer.write(
                {"_".join(str(i) for i in k): v for k, v in data.items()}
            )
    
        def raise_if_closed(self) -> None:
            if self.closed:
                if self._exception:
                    raise self._exception
                raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return
    
            self.closed = True
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            if not self.closed:
                self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            self.raise_if_closed()
            return self._disk_buffer.read("_".join(str(i) for i in id))
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            try:
                if isinstance(data, bytes):
                    # Unpack opaque blob. See send()
                    data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                await self._receive(data)
                return {"status": "OK"}
            except P2PConsistencyError as e:
                return error_message(e)
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Check that this worker is the one assigned to output partition ``i``.

            Returns normally when the assigned worker is ``self.local_address``.
            Otherwise the scheduler is asked to restrict task ``key`` to the
            assigned worker; an "error" reply raises ``RuntimeError`` with the
            reply's message, and an "OK" reply is followed by ``Reschedule``.
            """
            target = self._get_assigned_worker(i)
            if target == self.local_address:
                return

            response = await self.scheduler.shuffle_restrict_task(
                id=self.id, run_id=self.run_id, key=key, worker=target
            )
            if response["status"] == "error":
                raise RuntimeError(response["message"])
            assert response["status"] == "OK"
            raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition

            Abstract hook. ``_ensure_output_worker`` compares the returned address
            with ``self.local_address``.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run

            Abstract hook awaited by ``receive``, which unpickles a ``bytes``
            payload before passing it on.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            # Log metrics both in the "execute" and…_plugin.py:222: in _refresh
    result = await self._fetch(shuffle_id=shuffle_id, key=key)
distributed/shuffle/_worker_plugin.py:190: in _fetch
    response = await self._plugin.worker.scheduler.shuffle_get_or_create(
distributed/core.py:1259: in send_recv_from_rpc
    return await send_recv(comm=comm, op=key, **kwargs)
distributed/core.py:1043: in send_recv
    raise exc.with_traceback(tb)
distributed/core.py:832: in _handle_comm
    result = handler(**msg)
distributed/shuffle/_scheduler_plugin.py:234: in get_or_create
    run_spec = self._create(shuffle_id, key, worker)
distributed/shuffle/_scheduler_plugin.py:205: in _create
    spec = self._retrieve_spec(shuffle_id)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import asyncio
    import contextlib
    import itertools
    import logging
    from collections import defaultdict
    from typing import TYPE_CHECKING, Any
    
    from dask.typing import Key
    
    from distributed.core import ErrorMessage, OKMessage, error_message
    from distributed.diagnostics.plugin import SchedulerPlugin
    from distributed.metrics import time
    from distributed.protocol.pickle import dumps
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._core import (
        P2PBarrierTask,
        RunSpecMessage,
        SchedulerShuffleState,
        ShuffleId,
        ShuffleRunSpec,
        ShuffleSpec,
        barrier_key,
        id_from_key,
    )
    from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError
    from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    from distributed.utils import log_errors
    
    if TYPE_CHECKING:
        from distributed.scheduler import (
            Recs,
            Scheduler,
            TaskState,
            TaskStateState,
            WorkerState,
        )
    
    logger = logging.getLogger(__name__)
    
    
    class ShuffleSchedulerPlugin(SchedulerPlugin):
        """
        Shuffle plugin for the scheduler
        This coordinates the individual worker plugins to ensure correctness
        and collects heartbeat messages for the dashboard.
        See Also
        --------
        ShuffleWorkerPlugin
        """
    
        scheduler: Scheduler
        active_shuffles: dict[ShuffleId, SchedulerShuffleState]
        heartbeats: defaultdict[ShuffleId, dict]
        _shuffles: defaultdict[ShuffleId, set[SchedulerShuffleState]]
        _archived_by_stimulus: defaultdict[str, set[SchedulerShuffleState]]
        _shift_counter: itertools.count[int]
    
        def __init__(self, scheduler: Scheduler):
            """Attach the plugin to ``scheduler``.

            Installs the four ``shuffle_*`` handlers, registers the plugin under
            the name "shuffle" and creates the empty bookkeeping containers.
            """
            self.scheduler = scheduler
            shuffle_handlers = {
                "shuffle_barrier": self.barrier,
                "shuffle_get": self.get,
                "shuffle_get_or_create": self.get_or_create,
                "shuffle_restrict_task": self.restrict_task,
            }
            self.scheduler.handlers.update(shuffle_handlers)
            self.heartbeats = defaultdict(lambda: defaultdict(dict))
            self.active_shuffles = {}
            # Registration stays between the attribute assignments, as before.
            self.scheduler.add_plugin(self, name="shuffle")
            self._shuffles = defaultdict(set)
            self._archived_by_stimulus = defaultdict(set)
            self._shift_counter = itertools.count()
    
        async def start(self, scheduler: Scheduler) -> None:
            """Register a ``ShuffleWorkerPlugin`` under the name "shuffle".

            The registration goes through ``self.scheduler``; the ``scheduler``
            argument itself is not used.
            """
            pickled_plugin = dumps(ShuffleWorkerPlugin())
            await self.scheduler.register_worker_plugin(
                None, pickled_plugin, name="shuffle", idempotent=False
            )
    
        def shuffle_ids(self) -> set[ShuffleId]:
            """Return a new set holding the ids of all active shuffles."""
            return {*self.active_shuffles}
    
        async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
            """Handle the ``shuffle_barrier`` request for an active shuffle.

            If ``consistent`` is false the shuffle is restarted. Otherwise a
            ``shuffle_inputs_done`` message is broadcast to the participating
            workers, and workers whose reply is an ``OSError`` are retried after a
            short sleep.

            Raises
            ------
            KeyError
                If ``id`` is not an active shuffle.
            ValueError
                If ``run_id`` differs from the run id of the active shuffle.
            P2PIllegalStateError
                If a worker to retry is unknown to the scheduler while the shuffle
                is not archived.
            RuntimeError
                On a reply that is neither ``None`` nor an ``OSError``, if a worker
                to retry is unknown to the scheduler, or once three rounds have
                left the number of pending workers unchanged.
            """
            shuffle = self.active_shuffles[id]
            if shuffle.run_id != run_id:
                raise ValueError(f"{run_id=} does not match {shuffle}")
            if not consistent:
                logger.warning(
                    "Shuffle %s restarted due to data inconsistency during barrier",
                    shuffle.id,
                )
                return self._restart_shuffle(
                    shuffle.id,
                    self.scheduler,
                    stimulus_id=f"p2p-barrier-inconsistent-{time()}",
                )
            msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
            workers = list(shuffle.participating_workers)
            no_progress = 0
            while workers:
                res = await self.scheduler.broadcast(
                    msg=msg,
                    workers=workers,
                    on_error="return",
                )
                before = len(workers)
                # Collect the workers to retry: a ``None`` reply is a success, an
                # ``OSError`` reply is retried and anything else aborts the barrier.
                workers = []
                for w, r in res.items():
                    if r is None:
                        continue
                    if isinstance(r, OSError):
                        workers.append(w)
                    else:
                        raise RuntimeError(
                            f"Unexpected error encountered during P2P barrier: {r!r}"
                        )
                if workers:
                    logger.warning(
                        "Failure during broadcast of %s, retrying.",
                        shuffle.id,
                    )
                    if any(w not in self.scheduler.workers for w in workers):
                        if not shuffle.archived:
                            # If the shuffle is not yet archived, this could mean that the barrier task fails
                            # before the P2P restarting mechanism can kick in.
                            raise P2PIllegalStateError(
                                "Expected shuffle to be archived if participating worker is not known by scheduler"
                            )
                        raise RuntimeError(
                            f"Worker {workers} left during shuffle {shuffle}"
                        )
                    await asyncio.sleep(0.1)
                    if len(workers) == before:
                        no_progress += 1
                        if no_progress >= 3:
                            raise RuntimeError(
                                f"""Broadcast not making progress for {shuffle}.
                                Aborting. This is possibly due to overloaded
                                workers. Increasing config
                                `distributed.comm.timeouts.connect` timeout may
                                help."""
                            )
    
        def restrict_task(
            self, id: ShuffleId, run_id: int, key: Key, worker: str
        ) -> OKMessage | ErrorMessage:
            """Restrict task ``key`` to ``worker`` on behalf of shuffle run ``run_id``.

            The request is rejected with an error message when ``run_id`` does not
            equal the run id of the active shuffle. Lookup failures for ``id`` or
            ``key`` are not caught here.
            """
            try:
                shuffle = self.active_shuffles[id]
                current_run_id = shuffle.run_id
                if current_run_id > run_id:
                    raise P2PConsistencyError(
                        f"Request stale, expected {run_id=} for {shuffle}"
                    )
                if current_run_id < run_id:
                    raise P2PConsistencyError(
                        f"Request invalid, expected {run_id=} for {shuffle}"
                    )
                self._set_restriction(self.scheduler.tasks[key], worker)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        def heartbeat(self, ws: WorkerState, data: dict) -> None:
            """Merge the per-shuffle heartbeat payloads sent by worker ``ws``.

            ``data`` maps shuffle ids to dicts; each dict is merged into
            ``self.heartbeats[shuffle_id][ws.address]``. Entries whose shuffle id
            is not currently active are ignored.
            """
            # Build the set of active ids once rather than once per entry.
            active_ids = self.shuffle_ids()
            for shuffle_id, d in data.items():
                if shuffle_id in active_ids:
                    self.heartbeats[shuffle_id][ws.address].update(d)
    
        def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage:
            """Return the run spec of active shuffle ``id`` wrapped in an OK message.

            A ``KeyError`` from the lookup is re-raised as ``P2PConsistencyError``,
            and any ``P2PConsistencyError`` is returned as an error message.
            """
            try:
                try:
                    run_spec = self._get(id, worker)
                except KeyError as e:
                    raise P2PConsistencyError(
                        f"No active shuffle with {id=!r} found"
                    ) from e
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK", "run_spec": ToPickle(run_spec)}
    
        def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec:
            """Add ``worker`` to the participants of shuffle ``id`` and return its run spec.

            Raises ``P2PConsistencyError`` if the scheduler does not know
            ``worker`` and ``KeyError`` if ``id`` is not an active shuffle.
            """
            if worker in self.scheduler.workers:
                state = self.active_shuffles[id]
                state.participating_workers.add(worker)
                return state.run_spec
            # This should never happen
            raise P2PConsistencyError(
                f"Scheduler is unaware of this worker {worker!r}"
            )  # pragma: nocover
    
        def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec:
            barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec
>           assert isinstance(barrier_task_spec, P2PBarrierTask)
E           AssertionError

distributed/shuffle/_scheduler_plugin.py:196: AssertionError

The above exception was the direct cause of the following exception:

c = <Client: No scheduler connected>
s = <Scheduler 'tcp://127.0.0.1:39469', workers: 0, cores: 0, tasks: 0>
x = dask.array<ones_like, shape=(10, 10), dtype=float64, chunksize=(10, 2), chunktype=numpy.ndarray>
chunks = (None, (5, 5))
ws = (<Worker 'tcp://127.0.0.1:43215', name: 0, status: closed, stored: 0, running: 0/1, ready: 0, comm: 0, waiting: 0>, <Worker 'tcp://127.0.0.1:41579', name: 1, status: closed, stored: 0, running: 0/2, ready: 0, comm: 0, waiting: 0>)
dd = <module 'dask.dataframe' from '/home/runner/miniconda3/envs/dask-distributed/lib/python3.12/site-packages/dask/dataframe/__init__.py'>

    @pytest.mark.parametrize(
        "x, chunks",
        [
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, 5)),
            (da.ones(shape=(50, 10), chunks=(25, 10)), {1: 5}),
            (da.ones(shape=(50, 10), chunks=(25, 10)), (None, (5, 5))),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, 5)),
            (da.ones(shape=(100, 10), chunks=(5, 10)), {1: 5}),
            (da.ones(shape=(100, 10), chunks=(5, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 10)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 10)), (None, (5, 5))),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, 5)),
            (da.ones(shape=(10, 10), chunks=(10, 2)), {1: 5}),
            (da.ones(shape=(10, 10), chunks=(10, 2)), (None, (5, 5))),
        ],
    )
    @gen_cluster(client=True)
    async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks):
        """
        See Also
        --------
        dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension
        """
        pytest.importorskip("pandas")
        dd = pytest.importorskip("dask.dataframe")
        y = dd.from_array(x).values
        result = y.rechunk(chunks, method="p2p")
        expected = x.rechunk(chunks, method="p2p")
    
        assert_chunks_match(result.chunks, expected.chunks)
>       assert_eq(await c.compute(result), await c.compute(expected))

distributed/shuffle/tests/test_rechunk.py:757: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distributed/client.py:410: in _result
    raise exc.with_traceback(tb)
distributed/shuffle/_rechunk.py:169: in rechunk_transfer
    with handle_transfer_errors(id):
../../../miniconda3/envs/dask-distributed/lib/python3.12/contextlib.py:158: in __exit__
    self.gen.throw(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    from __future__ import annotations
    
    import abc
    import asyncio
    import contextlib
    import itertools
    import pickle
    import time
    from collections.abc import (
        Callable,
        Coroutine,
        Generator,
        Hashable,
        Iterable,
        Iterator,
        Sequence,
    )
    from concurrent.futures import ThreadPoolExecutor
    from dataclasses import dataclass, field
    from enum import Enum
    from functools import partial
    from pathlib import Path
    from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast
    
    from tornado.ioloop import IOLoop
    
    import dask.config
    from dask._task_spec import Task
    from dask.core import flatten
    from dask.typing import Key
    from dask.utils import parse_bytes, parse_timedelta
    
    from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message
    from distributed.exceptions import Reschedule
    from distributed.metrics import context_meter, thread_time
    from distributed.protocol import to_serialize
    from distributed.protocol.serialize import ToPickle
    from distributed.shuffle._comms import CommShardsBuffer
    from distributed.shuffle._disk import DiskShardsBuffer
    from distributed.shuffle._exceptions import (
        P2PConsistencyError,
        P2POutOfDiskError,
        ShuffleClosedError,
    )
    from distributed.shuffle._limiter import ResourceLimiter
    from distributed.shuffle._memory import MemoryShardsBuffer
    from distributed.utils import run_in_executor_with_context, sync
    from distributed.utils_comm import retry
    
    if TYPE_CHECKING:
        # TODO import from typing (requires Python >=3.10)
        from typing_extensions import ParamSpec, TypeAlias
    
        _P = ParamSpec("_P")
    
        # circular dependencies
        from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
    
    ShuffleId = NewType("ShuffleId", str)
    NDIndex: TypeAlias = tuple[int, ...]
    
    
    _T_partition_id = TypeVar("_T_partition_id")
    _T_partition_type = TypeVar("_T_partition_type")
    _T = TypeVar("_T")
    
    
    class RunSpecMessage(OKMessage):
        # An OKMessage that additionally carries a shuffle run spec, either
        # directly or wrapped in ToPickle.
        run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
    
    
    class ShuffleRun(Generic[_T_partition_id, _T_partition_type]):
        id: ShuffleId
        run_id: int
        span_id: str | None
        local_address: str
        executor: ThreadPoolExecutor
        rpc: Callable[[str], PooledRPCCall]
        digest_metric: Callable[[Hashable, float], None]
        scheduler: PooledRPCCall
        closed: bool
        _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer
        _comm_buffer: CommShardsBuffer
        received: set[_T_partition_id]
        total_recvd: int
        start_time: float
        _exception: Exception | None
        _closed_event: asyncio.Event
        _loop: IOLoop
    
        RETRY_COUNT: int
        RETRY_DELAY_MIN: float
        RETRY_DELAY_MAX: float
    
        def __init__(
            self,
            id: ShuffleId,
            run_id: int,
            span_id: str | None,
            local_address: str,
            directory: str,
            executor: ThreadPoolExecutor,
            rpc: Callable[[str], PooledRPCCall],
            digest_metric: Callable[[Hashable, float], None],
            scheduler: PooledRPCCall,
            memory_limiter_disk: ResourceLimiter,
            memory_limiter_comms: ResourceLimiter,
            disk: bool,
            loop: IOLoop,
        ):
            """Set up the run's state, its two buffers and the retry settings.

            ``disk`` selects between a ``DiskShardsBuffer`` created for
            ``directory`` (limited by ``memory_limiter_disk``) and a
            ``MemoryShardsBuffer``. The comm buffer's message size limit and
            concurrency, and the retry count and delays, are read from the
            ``distributed.p2p.comm.*`` dask config keys.
            """
            self.id = id
            self.run_id = run_id
            self.span_id = span_id
            self.local_address = local_address
            self.executor = executor
            self.rpc = rpc
            self.digest_metric = digest_metric
            self.scheduler = scheduler
            self.closed = False

            # Initialize buffers and start background tasks
            # Don't log metrics issued by the background tasks onto the dask task that
            # spawned this object
            with context_meter.clear_callbacks():
                with self._capture_metrics("background-disk"):
                    if disk:
                        self._disk_buffer = DiskShardsBuffer(
                            directory=directory,
                            read=self.read,
                            memory_limiter=memory_limiter_disk,
                        )
                    else:
                        self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)

                with self._capture_metrics("background-comms"):
                    max_message_size = parse_bytes(
                        dask.config.get("distributed.p2p.comm.message-bytes-limit")
                    )
                    concurrency_limit = dask.config.get("distributed.p2p.comm.concurrency")
                    self._comm_buffer = CommShardsBuffer(
                        send=self.send,
                        max_message_size=max_message_size,
                        memory_limiter=memory_limiter_comms,
                        concurrency_limit=concurrency_limit,
                    )

            # TODO: reduce number of connections to number of workers
            # MultiComm.max_connections = min(10, n_workers)

            self.transferred = False
            self.received = set()
            self.total_recvd = 0
            self.start_time = time.time()
            self._exception = None
            self._closed_event = asyncio.Event()
            self._loop = loop

            # Retry settings consumed by send()
            self.RETRY_COUNT = dask.config.get("distributed.p2p.comm.retry.count")
            self.RETRY_DELAY_MIN = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.min"), default="s"
            )
            self.RETRY_DELAY_MAX = parse_timedelta(
                dask.config.get("distributed.p2p.comm.retry.delay.max"), default="s"
            )
    
        def __repr__(self) -> str:
            """Return a debug string with id, run_id, local_address, closed and transferred."""
            details = ", ".join(
                (
                    f"id={self.id!r}",
                    f"run_id={self.run_id!r}",
                    f"local_address={self.local_address!r}",
                    f"closed={self.closed!r}",
                    f"transferred={self.transferred!r}",
                )
            )
            return f"<{self.__class__.__name__}: {details}>"
    
        def __str__(self) -> str:
            """Return ``ClassName<id[run_id]> on local_address``."""
            class_name = self.__class__.__name__
            run_label = f"{self.id}[{self.run_id}]"
            return f"{class_name}<{run_label}> on {self.local_address}"
    
        def __hash__(self) -> int:
            """Use ``run_id`` as the hash value."""
            return self.run_id
    
        @contextlib.contextmanager
        def _capture_metrics(self, where: str) -> Iterator[None]:
            """Forward context_meter metrics to ``self.digest_metric`` while active.

            Every metric is recorded under the key

                ('p2p', <span id>, <where>, *label, unit)

            where ``where`` is e.g. 'foreground' or 'background-...' and a leading
            "p2p-" is stripped from the first label element.

            **Note 1:** Metrics that are not logged by a background task
            (where='foreground') also show up, by design, under

                {('execute', <span id>, <task prefix>, label, unit): value}

            so that the whole shuffle process can be viewed in one place.

            **Note 2:** Metrics go straight to Worker.digests. Keeping them on the
            ShuffleRun first would lose whatever is recorded between the last
            heartbeat and the deletion of the ShuffleRun at the end of a run.
            """
            offloadable = "background" in where

            def forward(label: Hashable, value: float, unit: str) -> None:
                parts = label if isinstance(label, tuple) else (label,)
                head = parts[0]
                if isinstance(head, str) and head.startswith("p2p-"):
                    parts = (head[len("p2p-") :], *parts[1:])
                self.digest_metric(("p2p", self.span_id, where, *parts, unit), value)

            with context_meter.add_callback(forward, allow_offload=offloadable):
                yield
    
        async def barrier(self, run_ids: Sequence[int]) -> int:
            """Report the barrier of this run to the scheduler and return ``run_id``.

            The run counts as consistent only when every entry of ``run_ids``
            equals ``self.run_id``. Raises first if the run is already closed.
            """
            self.raise_if_closed()
            mismatch = any(other != self.run_id for other in run_ids)
            # TODO: Consider broadcast pinging once when the shuffle starts to warm
            # up the comm pool on scheduler side
            await self.scheduler.shuffle_barrier(
                id=self.id, run_id=self.run_id, consistent=not mismatch
            )
            return self.run_id
    
        async def _send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Make one ``shuffle_receive`` call to ``address`` and return its reply.

            Raises first if this run is already closed.
            """
            self.raise_if_closed()
            peer = self.rpc(address)
            reply = await peer.shuffle_receive(
                data=to_serialize(shards),
                shuffle_id=self.id,
                run_id=self.run_id,
            )
            return reply
    
        async def send(
            self, address: str, shards: list[tuple[_T_partition_id, Any]]
        ) -> OKMessage | ErrorMessage:
            """Send ``shards`` to ``address`` through ``retry``.

            When the mean shard size is below 65536 the shards are pickled into a
            single bytes blob that the receiving side unpickles; otherwise they
            are passed on unchanged.
            """
            # Don't send small buffers individually over the tcp comms.
            # Performance tests informing the size threshold:
            # https://github.com/dask/distributed/pull/8318
            merge_into_blob = _mean_shard_size(shards) < 65536
            shards_or_bytes: list | bytes = (
                pickle.dumps(shards) if merge_into_blob else shards
            )

            def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]:
                return self._send(address, shards_or_bytes)

            outcome = await retry(
                _send,
                count=self.RETRY_COUNT,
                delay_min=self.RETRY_DELAY_MIN,
                delay_max=self.RETRY_DELAY_MAX,
            )
            return outcome
    
        async def offload(
            self, func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            """Run ``func(*args, **kwargs)`` on ``self.executor`` under the "offload" meter.

            Raises first if this run is already closed.
            """
            self.raise_if_closed()
            with context_meter.meter("offload"):
                result = await run_in_executor_with_context(
                    self.executor, func, *args, **kwargs
                )
            return result
    
        def heartbeat(self) -> dict[str, Any]:
            """Return the disk and comm buffer heartbeats plus the start time.

            The comm heartbeat is extended with ``"read"`` set to
            ``self.total_recvd``.
            """
            comm_info = self._comm_buffer.heartbeat()
            comm_info["read"] = self.total_recvd
            disk_info = self._disk_buffer.heartbeat()
            return {"disk": disk_info, "comm": comm_info, "start": self.start_time}
    
        async def _write_to_comm(
            self, data: dict[str, tuple[_T_partition_id, Any]]
        ) -> None:
            """Write ``data`` to the comm buffer; raises first if this run is closed."""
            self.raise_if_closed()
            await self._comm_buffer.write(data)
    
        async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
            """Write ``data`` to the disk buffer under "_"-joined string keys.

            Raises first if this run is already closed.
            """
            self.raise_if_closed()
            keyed = {"_".join(map(str, index)): shard for index, shard in data.items()}
            await self._disk_buffer.write(keyed)
    
        def raise_if_closed(self) -> None:
            """Raise if this run has been closed; otherwise do nothing.

            A stored ``_exception`` is re-raised in preference to the generic
            ``ShuffleClosedError``.
            """
            if not self.closed:
                return
            if self._exception:
                raise self._exception
            raise ShuffleClosedError(f"{self} has already been closed")
    
        async def inputs_done(self) -> None:
            """Mark the transfer as finished and flush the comm buffer.

            An exception surfaced by the comm buffer after the flush is stored in
            ``_exception`` and then re-raised.
            """
            self.raise_if_closed()
            self.transferred = True
            await self._flush_comm()
            try:
                self._comm_buffer.raise_on_exception()
            except Exception as e:
                # Keep the failure so raise_if_closed() can re-raise it later.
                self._exception = e
                raise
    
        async def _flush_comm(self) -> None:
            """Flush the comm buffer; raises first if this run is already closed."""
            self.raise_if_closed()
            await self._comm_buffer.flush()
    
        async def flush_receive(self) -> None:
            """Flush the disk buffer; raises first if this run is already closed."""
            self.raise_if_closed()
            await self._disk_buffer.flush()
    
        async def close(self) -> None:
            """Close the comm buffer and then the disk buffer of this run.

            The first caller flips ``closed`` and performs the shutdown; a later
            caller only waits for ``_closed_event``. The disk buffer is closed and
            the event is set from ``finally`` clauses, so an exception raised while
            closing one buffer neither skips the other buffer nor leaves
            concurrent callers waiting on the event forever. The exception still
            propagates to the caller that performed the shutdown.
            """
            if self.closed:  # pragma: no cover
                await self._closed_event.wait()
                return

            self.closed = True
            try:
                try:
                    await self._comm_buffer.close()
                finally:
                    await self._disk_buffer.close()
            finally:
                self._closed_event.set()
    
        def fail(self, exception: Exception) -> None:
            """Remember ``exception`` as this run's failure cause.

            Does nothing once the run has been closed.
            """
            if self.closed:
                return
            self._exception = exception
    
        def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
            """Read from the disk buffer using the "_"-joined form of ``id`` as key.

            Raises first if this run is already closed.
            """
            self.raise_if_closed()
            buffer_key = "_".join(map(str, id))
            return self._disk_buffer.read(buffer_key)
    
        async def receive(
            self, data: list[tuple[_T_partition_id, Any]] | bytes
        ) -> OKMessage | ErrorMessage:
            """Hand incoming shards to ``_receive`` and report the outcome.

            A ``bytes`` payload is the opaque blob produced by ``send()`` and is
            unpickled first. A ``P2PConsistencyError`` is converted into an error
            message instead of propagating; other exceptions are not caught.
            """
            try:
                # NOTE(review): pickle.loads runs on data received from a peer.
                shards = (
                    cast(list[tuple[_T_partition_id, Any]], pickle.loads(data))
                    if isinstance(data, bytes)
                    else data
                )
                await self._receive(shards)
            except P2PConsistencyError as e:
                return error_message(e)
            return {"status": "OK"}
    
        async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:
            """Check that this worker is the one assigned to output partition ``i``.

            Returns normally when the assigned worker is ``self.local_address``.
            Otherwise the scheduler is asked to restrict task ``key`` to the
            assigned worker; an "error" reply raises ``RuntimeError`` with the
            reply's message, and an "OK" reply is followed by ``Reschedule``.
            """
            target = self._get_assigned_worker(i)
            if target == self.local_address:
                return

            response = await self.scheduler.shuffle_restrict_task(
                id=self.id, run_id=self.run_id, key=key, worker=target
            )
            if response["status"] == "error":
                raise RuntimeError(response["message"])
            assert response["status"] == "OK"
            raise Reschedule()
    
        @abc.abstractmethod
        def _get_assigned_worker(self, i: _T_partition_id) -> str:
            """Get the address of the worker assigned to the output partition

            Abstract hook. ``_ensure_output_worker`` compares the returned address
            with ``self.local_address``.
            """
    
        @abc.abstractmethod
        async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
            """Receive shards belonging to output partitions of this shuffle run

            Abstract hook awaited by ``receive``, which unpickles a ``bytes``
            payload before passing it on.
            """
    
        def add_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> int:
            """Shard one input partition, queue the shards for sending, return ``run_id``.

            Raises if the run is closed, and ``RuntimeError`` once the transfer has
            been marked as finished. ``validate_data`` runs before sharding.
            """
            self.raise_if_closed()
            if self.transferred:
                raise RuntimeError(f"Cannot add more partitions to {self}")
            self.validate_data(data)
            # Log metrics both in the "execute" and in the "p2p" contexts
            with self._capture_metrics("foreground"):
                # Only the sharding step is timed by the two meters.
                with context_meter.meter("p2p-shard-partition-noncpu"):
                    with context_meter.meter("p2p-shard-partition-cpu", func=thread_time):
                        shards = self._shard_partition(data, partition_id)
                sync(self._loop, self._write_to_comm, shards)
            return self.run_id
    
        @abc.abstractmethod
        def _shard_partition(
            self, data: _T_partition_type, partition_id: _T_partition_id
        ) -> dict[str, tuple[_T_partition_id, Any]]:
            """Shard an input partition by the assigned output workers

            Abstract hook called by ``add_partition``; the returned mapping is
            handed to ``_write_to_comm``.
            """
    
        def get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Return output partition ``partition_id`` for task ``key``.

            The output worker check runs first, then the disk buffer is flushed
            and the subclass hook ``_get_output_partition`` is called under the
            foreground metrics. Raises ``RuntimeError`` if the transfer has not
            been marked as finished yet.
            """
            self.raise_if_closed()
            sync(self._loop, self._ensure_output_worker, partition_id, key)
            if not self.transferred:
                raise RuntimeError("`get_output_partition` called before barrier task")
            sync(self._loop, self.flush_receive)
            # Log metrics both in the "execute" and in the "p2p" contexts
            with self._capture_metrics("foreground"):
                with context_meter.meter("p2p-get-output-noncpu"):
                    with context_meter.meter("p2p-get-output-cpu", func=thread_time):
                        return self._get_output_partition(partition_id, key, **kwargs)
    
        @abc.abstractmethod
        def _get_output_partition(
            self, partition_id: _T_partition_id, key: Key, **kwargs: Any
        ) -> _T_partition_type:
            """Get an output partition to the shuffle run

            Abstract hook called by ``get_output_partition`` after the disk buffer
            has been flushed.
            """
    
        @abc.abstractmethod
        def read(self, path: Path) -> tuple[Any, int]:
            """Read shards from disk

            Abstract hook passed to ``DiskShardsBuffer`` as its ``read`` callable.
            """
    
        @abc.abstractmethod
        def deserialize(self, buffer: Any) -> Any:
            """Deserialize shards

            Abstract hook passed to ``MemoryShardsBuffer`` as its ``deserialize``
            callable.
            """
    
        def validate_data(self, data: Any) -> None:
            """Validate payload data before shuffling

            The base implementation accepts everything; ``add_partition`` calls it
            before sharding.
            """
    
    
    def get_worker_plugin() -> ShuffleWorkerPlugin:
        """Return the "shuffle" plugin of the worker this code is running on.

        Raises ``RuntimeError`` when ``get_worker()`` fails with ``ValueError``
        or when the worker has no plugin registered as "shuffle".
        """
        from distributed import get_worker

        try:
            worker = get_worker()
        except ValueError as e:
            raise RuntimeError(
                "`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; "
                "please confirm that you've created a distributed Client and are submitting this computation through it."
            ) from e

        registered = worker.plugins
        try:
            plugin = registered["shuffle"]
        except KeyError as e:
            raise RuntimeError(
                f"The worker {worker.address} does not have a P2P shuffle plugin."
            ) from e
        return plugin  # type: ignore
    
    
    _BARRIER_PREFIX = "shuffle-barrier-"
    
    
    def barrier_key(shuffle_id: ShuffleId) -> str:
        """Return the barrier key for ``shuffle_id``: the barrier prefix followed by the id."""
        return "".join((_BARRIER_PREFIX, shuffle_id))
    
    
    def id_from_key(key: Key) -> ShuffleId | None:
        """Extract the shuffle id from a barrier key.

        Returns ``None`` when ``key`` is not a string starting with the barrier
        prefix.
        """
        if isinstance(key, str) and key.startswith(_BARRIER_PREFIX):
            return ShuffleId(key.removeprefix(_BARRIER_PREFIX))
        return None
    
    
    class ShuffleType(Enum):
        """Closed set of shuffle kinds, each with a string value."""

        DATAFRAME = "DataFrameShuffle"
        ARRAY_RECHUNK = "ArrayRechunk"
    
    
    @dataclass(frozen=True)
    class ShuffleRunSpec(Generic[_T_partition_id]):
        """Frozen description of one run of a shuffle spec."""

        # Not an __init__ argument: each new instance takes the next value of a
        # single itertools.count(1) created once with the class definition.
        run_id: int = field(init=False, default_factory=partial(next, itertools.count(1)))
        spec: ShuffleSpec
        # Maps each partition id to a worker address
        worker_for: dict[_T_partition_id, str]
        span_id: str | None

        @property
        def id(self) -> ShuffleId:
            """Id of the underlying ``spec``."""
            return self.spec.id
    
    
    @dataclass(frozen=True)
    class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
        """Frozen, abstract description of a shuffle identified by ``id``.

        Subclasses provide the output partitions, the worker selection and the
        factory for the worker-side run.
        """

        id: ShuffleId
        disk: bool

        @property
        @abc.abstractmethod
        def output_partitions(self) -> Generator[_T_partition_id]:
            """Output partitions"""

        @abc.abstractmethod
        def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
            """Pick a worker for a partition"""

        def create_new_run(
            self,
            worker_for: dict[_T_partition_id, str],
            span_id: str | None,
        ) -> SchedulerShuffleState:
            """Create the scheduler-side state for a new run of this spec.

            The participating workers start out as the set of values of
            ``worker_for``.
            """
            run_spec = ShuffleRunSpec(spec=self, worker_for=worker_for, span_id=span_id)
            participants = set(worker_for.values())
            return SchedulerShuffleState(
                run_spec=run_spec,
                participating_workers=participants,
            )

        @abc.abstractmethod
        def create_run_on_worker(
            self,
            run_id: int,
            span_id: str | None,
            worker_for: dict[_T_partition_id, str],
            plugin: ShuffleWorkerPlugin,
        ) -> ShuffleRun:
            """Create the new shuffle run on the worker."""
    
    
    @dataclass(eq=False)
    class SchedulerShuffleState(Generic[_T_partition_id]):
        """Scheduler-side state of one shuffle run.

        Declared with ``eq=False``, so the dataclass generates no ``__eq__``;
        hashing uses ``run_id``.
        """

        run_spec: ShuffleRunSpec
        participating_workers: set[str]
        # None until set; `archived` reports whether it has been set.
        _archived_by: str | None = field(default=None, init=False)
        _failed: bool = False

        @property
        def id(self) -> ShuffleId:
            """Shuffle id taken from the run spec."""
            return self.run_spec.id

        @property
        def run_id(self) -> int:
            """Run id taken from the run spec."""
            return self.run_spec.run_id

        @property
        def archived(self) -> bool:
            """True once ``_archived_by`` is no longer ``None``."""
            return self._archived_by is not None

        def __str__(self) -> str:
            """Return ``ClassName<id[run_id]>``."""
            return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"

        def __hash__(self) -> int:
            """Hash by ``run_id``."""
            return hash(self.run_id)
    
    
    @contextlib.contextmanager
    def handle_transfer_errors(id: ShuffleId) -> Iterator[None]:
        try:
            yield
        except ShuffleClosedError:
            raise Reschedule()
        except P2PConsistencyError:
            raise
        except P2POutOfDiskError:
            raise
        except Exception as e:
>           raise RuntimeError(f"P2P {id} failed during transfer phase") from e
E           RuntimeError: P2P 9f169a5772161f491fce47676e173277 failed during transfer phase

distributed/shuffle/_core.py:531: RuntimeError