From f340f1853e0f7906f3580ea5dfd32c0a152ccf5d Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 30 Oct 2024 17:18:36 +0100 Subject: [PATCH] Reduce P2P transfer task overhead (#8912) --- distributed/shuffle/_core.py | 60 +++++++++++ distributed/shuffle/_merge.py | 117 ++++++++++++++-------- distributed/shuffle/_rechunk.py | 25 ++--- distributed/shuffle/_scheduler_plugin.py | 23 +++-- distributed/shuffle/_shuffle.py | 56 +++++------ distributed/shuffle/_worker_plugin.py | 37 +++---- distributed/shuffle/tests/test_merge.py | 10 +- distributed/shuffle/tests/test_shuffle.py | 2 +- 8 files changed, 208 insertions(+), 122 deletions(-) diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py index 151436212c..e0528440de 100644 --- a/distributed/shuffle/_core.py +++ b/distributed/shuffle/_core.py @@ -25,7 +25,9 @@ from tornado.ioloop import IOLoop import dask.config +from dask._task_spec import Task, _inline_recursively from dask.core import flatten +from dask.sizeof import sizeof from dask.typing import Key from dask.utils import parse_bytes, parse_timedelta @@ -575,3 +577,61 @@ def p2p_barrier(id: ShuffleId, run_ids: list[int]) -> int: raise except Exception as e: raise RuntimeError(f"P2P {id} failed during barrier phase") from e + + +class P2PBarrierTask(Task): + spec: ShuffleSpec + + __slots__ = tuple(__annotations__) + + def __init__( + self, + key: Any, + func: Callable[..., Any], + /, + *args: Any, + spec: ShuffleSpec, + **kwargs: Any, + ): + self.spec = spec + super().__init__(key, func, *args, **kwargs) + + def copy(self) -> P2PBarrierTask: + self.unpack() + assert self.func is not None + return P2PBarrierTask( + self.key, self.func, *self.args, spec=self.spec, **self.kwargs + ) + + def __sizeof__(self) -> int: + return super().__sizeof__() + sizeof(self.spec) + + def __repr__(self) -> str: + return f"P2PBarrierTask({self.key!r})" + + def inline(self, dsk: dict[Key, Any]) -> P2PBarrierTask: + self.unpack() + new_args = _inline_recursively(self.args, dsk) + new_kwargs = _inline_recursively(self.kwargs, dsk) + assert self.func is not None + return P2PBarrierTask( + self.key, self.func, *new_args, spec=self.spec, **new_kwargs + ) + + def __getstate__(self) -> dict[str, Any]: + state = super().__getstate__() + state["spec"] = self.spec + return state + + def __setstate__(self, state: dict[str, Any]) -> None: + super().__setstate__(state) + self.spec = state["spec"] + + def __eq__(self, value: object) -> bool: + if not isinstance(value, P2PBarrierTask): + return False + if not super().__eq__(value): + return False + if self.spec != value.spec: + return False + return True diff --git a/distributed/shuffle/_merge.py b/distributed/shuffle/_merge.py index a3d7c59093..c7e62d5558 100644 --- a/distributed/shuffle/_merge.py +++ b/distributed/shuffle/_merge.py @@ -1,30 +1,37 @@ # mypy: ignore-errors from __future__ import annotations -from collections.abc import Iterable, Sequence +from collections.abc import Iterable from typing import TYPE_CHECKING, Any import dask +from dask._task_spec import GraphNode, Task, TaskRef from dask.base import is_dask_collection from dask.highlevelgraph import HighLevelGraph from dask.layers import Layer from dask.tokenize import tokenize +from dask.typing import Key from distributed.shuffle._arrow import check_minimal_arrow_version from distributed.shuffle._core import ( + P2PBarrierTask, ShuffleId, barrier_key, get_worker_plugin, p2p_barrier, ) -from distributed.shuffle._shuffle import shuffle_transfer +from distributed.shuffle._shuffle import 
DataFrameShuffleSpec, shuffle_transfer if TYPE_CHECKING: import pandas as pd from pandas._typing import IndexLabel, MergeHow, Suffixes + # TODO import from typing (requires Python >=3.10) + from typing_extensions import TypeAlias + from dask.dataframe.core import _Frame +_T_LowLevelGraph: TypeAlias = dict[Key, GraphNode] _HASH_COLUMN_NAME = "__hash_partition" @@ -148,21 +155,11 @@ def merge_transfer( input: pd.DataFrame, id: ShuffleId, input_partition: int, - npartitions: int, - meta: pd.DataFrame, - parts_out: set[int], - disk: bool, ): return shuffle_transfer( input=input, id=id, input_partition=input_partition, - npartitions=npartitions, - column=_HASH_COLUMN_NAME, - meta=meta, - parts_out=parts_out, - disk=disk, - drop_column=True, ) @@ -208,7 +205,7 @@ class HashJoinP2PLayer(Layer): suffixes: Suffixes indicator: bool meta_output: pd.DataFrame - parts_out: Sequence[int] + parts_out: set[int] name_input_left: str meta_input_left: pd.DataFrame @@ -241,7 +238,7 @@ def __init__( how: MergeHow = "inner", suffixes: Suffixes = ("_x", "_y"), indicator: bool = False, - parts_out: Sequence | None = None, + parts_out: Iterable[int] | None = None, annotations: dict | None = None, ) -> None: check_minimal_arrow_version() @@ -257,7 +254,10 @@ def __init__( self.suffixes = suffixes self.indicator = indicator self.meta_output = meta_output - self.parts_out = parts_out or list(range(npartitions)) + if parts_out: + self.parts_out = set(parts_out) + else: + self.parts_out = set(range(npartitions)) self.n_partitions_left = n_partitions_left self.n_partitions_right = n_partitions_right self.left_index = left_index @@ -325,7 +325,7 @@ def _dict(self): self._cached_dict = dsk return self._cached_dict - def _cull(self, parts_out: Sequence[str]): + def _cull(self, parts_out: Iterable[int]): return HashJoinP2PLayer( name=self.name, name_input_left=self.name_input_left, @@ -365,7 +365,7 @@ def cull(self, keys: Iterable[str], all_keys: Any) -> tuple[HashJoinP2PLayer, di else: return self, culled_deps - def _construct_graph(self) -> dict[tuple | str, tuple]: + def _construct_graph(self) -> _T_LowLevelGraph: token_left = tokenize( # Include self.name to ensure that shuffle IDs are unique for individual # merge operations. Reusing shuffles between merges is dangerous because of @@ -375,6 +375,7 @@ def _construct_graph(self) -> dict[tuple | str, tuple]: self.left_on, self.left_index, ) + shuffle_id_left = ShuffleId(token_left) token_right = tokenize( # Include self.name to ensure that shuffle IDs are unique for individual # merge operations. 
Reusing shuffles between merges is dangerous because of @@ -384,50 +385,79 @@ def _construct_graph(self) -> dict[tuple | str, tuple]: self.right_on, self.right_index, ) - dsk: dict[tuple | str, tuple] = {} + shuffle_id_right = ShuffleId(token_right) + dsk: _T_LowLevelGraph = {} name_left = "hash-join-transfer-" + token_left name_right = "hash-join-transfer-" + token_right transfer_keys_left = list() - transfer_keys_right = list() for i in range(self.n_partitions_left): - transfer_keys_left.append((name_left, i)) - dsk[(name_left, i)] = ( + t = Task( + (name_left, i), merge_transfer, - (self.name_input_left, i), - token_left, + TaskRef((self.name_input_left, i)), + shuffle_id_left, i, - self.npartitions, - self.meta_input_left, - self.parts_out, - self.disk, ) + dsk[t.key] = t + transfer_keys_left.append(t.ref()) + + transfer_keys_right = list() for i in range(self.n_partitions_right): - transfer_keys_right.append((name_right, i)) - dsk[(name_right, i)] = ( + t = Task( + (name_right, i), merge_transfer, - (self.name_input_right, i), - token_right, + TaskRef((self.name_input_right, i)), + shuffle_id_right, i, - self.npartitions, - self.meta_input_right, - self.parts_out, - self.disk, ) - - _barrier_key_left = barrier_key(ShuffleId(token_left)) - _barrier_key_right = barrier_key(ShuffleId(token_right)) - dsk[_barrier_key_left] = (p2p_barrier, token_left, transfer_keys_left) - dsk[_barrier_key_right] = (p2p_barrier, token_right, transfer_keys_right) + dsk[t.key] = t + transfer_keys_right.append(t.ref()) + + _barrier_key_left = barrier_key(shuffle_id_left) + barrier_left = P2PBarrierTask( + _barrier_key_left, + p2p_barrier, + token_left, + transfer_keys_left, + spec=DataFrameShuffleSpec( + id=shuffle_id_left, + npartitions=self.npartitions, + column=_HASH_COLUMN_NAME, + meta=self.meta_input_left, + parts_out=self.parts_out, + disk=self.disk, + drop_column=True, + ), + ) + dsk[barrier_left.key] = barrier_left + _barrier_key_right = barrier_key(shuffle_id_right) + barrier_right = P2PBarrierTask( + _barrier_key_right, + p2p_barrier, + token_right, + transfer_keys_right, + spec=DataFrameShuffleSpec( + id=shuffle_id_right, + npartitions=self.npartitions, + column=_HASH_COLUMN_NAME, + meta=self.meta_input_right, + parts_out=self.parts_out, + disk=self.disk, + drop_column=True, + ), + ) + dsk[barrier_right.key] = barrier_right name = self.name for part_out in self.parts_out: - dsk[(name, part_out)] = ( + t = Task( + (name, part_out), merge_unpack, token_left, token_right, part_out, - _barrier_key_left, - _barrier_key_right, + barrier_left.ref(), + barrier_right.ref(), self.how, self.left_on, self.right_on, @@ -437,4 +467,5 @@ def _construct_graph(self) -> dict[tuple | str, tuple]: self.right_index, self.indicator, ) + dsk[t.key] = t return dsk diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index c1c74954b6..354828415b 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -132,9 +132,11 @@ from distributed.metrics import context_meter from distributed.shuffle._core import ( NDIndex, + P2PBarrierTask, ShuffleId, ShuffleRun, ShuffleSpec, + barrier_key, get_worker_plugin, handle_transfer_errors, handle_unpack_errors, @@ -142,7 +144,6 @@ ) from distributed.shuffle._limiter import ResourceLimiter from distributed.shuffle._pickle import unpickle_bytestream -from distributed.shuffle._shuffle import barrier_key from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin from distributed.sizeof import sizeof @@ -164,15 +165,12 @@ def 
rechunk_transfer( input: np.ndarray, id: ShuffleId, input_chunk: NDIndex, - new: ChunkedAxes, - old: ChunkedAxes, - disk: bool, ) -> int: with handle_transfer_errors(id): return get_worker_plugin().add_partition( input, partition_id=input_chunk, - spec=ArrayRechunkSpec(id=id, new=new, old=old, disk=disk), + id=id, ) @@ -815,16 +813,19 @@ def partial_rechunk( key, rechunk_transfer, input_key, - partial_token, + ShuffleId(partial_token), partial_index, - partial_new, - partial_old, - disk, ) transfer_keys.append(t.ref()) - dsk[_barrier_key] = barrier = Task( - _barrier_key, p2p_barrier, partial_token, transfer_keys + dsk[_barrier_key] = barrier = P2PBarrierTask( + _barrier_key, + p2p_barrier, + partial_token, + transfer_keys, + spec=ArrayRechunkSpec( + id=ShuffleId(partial_token), new=partial_new, old=partial_old, disk=disk + ), ) new_partial_offset = tuple(axis.start for axis in ndpartial.new) @@ -835,7 +836,7 @@ def partial_rechunk( dsk[k] = Task( k, rechunk_unpack, - partial_token, + ShuffleId(partial_token), partial_index, barrier.ref(), ) diff --git a/distributed/shuffle/_scheduler_plugin.py b/distributed/shuffle/_scheduler_plugin.py index ec15e20bf4..ce20537ee7 100644 --- a/distributed/shuffle/_scheduler_plugin.py +++ b/distributed/shuffle/_scheduler_plugin.py @@ -15,6 +15,7 @@ from distributed.protocol.pickle import dumps from distributed.protocol.serialize import ToPickle from distributed.shuffle._core import ( + P2PBarrierTask, RunSpecMessage, SchedulerShuffleState, ShuffleId, @@ -184,23 +185,29 @@ def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec: state.participating_workers.add(worker) return state.run_spec - def _create(self, spec: ShuffleSpec, key: Key, worker: str) -> ShuffleRunSpec: + def _retrieve_spec(self, shuffle_id: ShuffleId) -> ShuffleSpec: + barrier_task_spec = self.scheduler.tasks[barrier_key(shuffle_id)].run_spec + assert isinstance(barrier_task_spec, P2PBarrierTask) + return barrier_task_spec.spec + + def _create(self, shuffle_id: ShuffleId, key: Key, worker: str) -> ShuffleRunSpec: # FIXME: The current implementation relies on the barrier task to be # known by its name. If the name has been mangled, we cannot guarantee # that the shuffle works as intended and should fail instead. 
- self._raise_if_barrier_unknown(spec.id) + self._raise_if_barrier_unknown(shuffle_id) self._raise_if_task_not_processing(key) + spec = self._retrieve_spec(shuffle_id) worker_for = self._calculate_worker_for(spec) self._ensure_output_tasks_are_non_rootish(spec) state = spec.create_new_run( worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id ) - self.active_shuffles[spec.id] = state - self._shuffles[spec.id].add(state) + self.active_shuffles[shuffle_id] = state + self._shuffles[shuffle_id].add(state) state.participating_workers.add(worker) logger.warning( "Shuffle %s initialized by task %r executed on worker %s", - spec.id, + shuffle_id, key, worker, ) @@ -208,17 +215,17 @@ def _create(self, spec: ShuffleSpec, key: Key, worker: str) -> ShuffleRunSpec: def get_or_create( self, - spec: ShuffleSpec, + shuffle_id: ShuffleId, key: Key, worker: str, ) -> RunSpecMessage | ErrorMessage: try: - run_spec = self._get(spec.id, worker) + run_spec = self._get(shuffle_id, worker) except P2PConsistencyError as e: return error_message(e) except KeyError: try: - run_spec = self._create(spec, key, worker) + run_spec = self._create(shuffle_id, key, worker) except P2PConsistencyError as e: return error_message(e) return {"status": "OK", "run_spec": ToPickle(run_spec)} diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index e6badc0902..912259e4a0 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -41,6 +41,7 @@ ) from distributed.shuffle._core import ( NDIndex, + P2PBarrierTask, ShuffleId, ShuffleRun, ShuffleSpec, @@ -70,26 +71,12 @@ def shuffle_transfer( input: pd.DataFrame, id: ShuffleId, input_partition: int, - npartitions: int, - column: str, - meta: pd.DataFrame, - parts_out: set[int], - disk: bool, - drop_column: bool, ) -> int: with handle_transfer_errors(id): return get_worker_plugin().add_partition( input, input_partition, - spec=DataFrameShuffleSpec( - id=id, - npartitions=npartitions, - column=column, - meta=meta, - parts_out=parts_out, - disk=disk, - drop_column=drop_column, - ), + id, ) @@ -157,7 +144,7 @@ class P2PShuffleLayer(Layer): name_input: str meta_input: pd.DataFrame disk: bool - parts_out: set[int] + parts_out: tuple[int, ...] 
drop_column: bool def __init__( @@ -181,9 +168,9 @@ def __init__( self.meta_input = meta_input self.disk = disk if parts_out: - self.parts_out = set(parts_out) + self.parts_out = tuple(parts_out) else: - self.parts_out = set(range(self.npartitions)) + self.parts_out = tuple(range(self.npartitions)) self.npartitions_input = npartitions_input self.drop_column = drop_column super().__init__(annotations=annotations) @@ -268,8 +255,9 @@ def cull( def _construct_graph(self) -> _T_LowLevelGraph: token = tokenize(self.name_input, self.column, self.npartitions, self.parts_out) + shuffle_id = ShuffleId(token) dsk: _T_LowLevelGraph = {} - _barrier_key = barrier_key(ShuffleId(token)) + _barrier_key = barrier_key(shuffle_id) name = "shuffle-transfer-" + token transfer_keys = list() for i in range(self.npartitions_input): @@ -279,17 +267,25 @@ def _construct_graph(self) -> _T_LowLevelGraph: TaskRef((self.name_input, i)), token, i, - self.npartitions, - self.column, - self.meta_input, - self.parts_out, - self.disk, - self.drop_column, ) dsk[t.key] = t transfer_keys.append(t.ref()) - barrier = Task(_barrier_key, p2p_barrier, token, transfer_keys) + barrier = P2PBarrierTask( + _barrier_key, + p2p_barrier, + token, + transfer_keys, + spec=DataFrameShuffleSpec( + id=shuffle_id, + npartitions=self.npartitions, + column=self.column, + meta=self.meta_input, + parts_out=self.parts_out, + disk=self.disk, + drop_column=self.drop_column, + ), + ) dsk[barrier.key] = barrier name = self.name @@ -565,12 +561,16 @@ class DataFrameShuffleSpec(ShuffleSpec[int]): npartitions: int column: str meta: pd.DataFrame - parts_out: set[int] + parts_out: Sequence[int] | int drop_column: bool @property def output_partitions(self) -> Generator[int]: - yield from self.parts_out + parts_out = self.parts_out + if isinstance(parts_out, int): + parts_out = range(parts_out) + + yield from parts_out def pick_worker(self, partition: int, workers: Sequence[str]) -> str: return _get_worker_for_range_sharding(self.npartitions, partition, workers) diff --git a/distributed/shuffle/_worker_plugin.py b/distributed/shuffle/_worker_plugin.py index 57d939a539..88f65af82a 100644 --- a/distributed/shuffle/_worker_plugin.py +++ b/distributed/shuffle/_worker_plugin.py @@ -14,14 +14,7 @@ from distributed.core import ErrorMessage, OKMessage, clean_exception, error_message from distributed.diagnostics.plugin import WorkerPlugin -from distributed.protocol.serialize import ToPickle -from distributed.shuffle._core import ( - NDIndex, - ShuffleId, - ShuffleRun, - ShuffleRunSpec, - ShuffleSpec, -) +from distributed.shuffle._core import NDIndex, ShuffleId, ShuffleRun, ShuffleRunSpec from distributed.shuffle._exceptions import P2PConsistencyError, ShuffleClosedError from distributed.shuffle._limiter import ResourceLimiter from distributed.utils import log_errors, sync @@ -136,24 +129,21 @@ async def get_with_run_id(self, shuffle_id: ShuffleId, run_id: int) -> ShuffleRu raise shuffle_run._exception return shuffle_run - async def get_or_create(self, spec: ShuffleSpec, key: Key) -> ShuffleRun: + async def get_or_create(self, shuffle_id: ShuffleId, key: Key) -> ShuffleRun: """Get or create a shuffle matching the ID and data spec. 
Parameters ---------- shuffle_id Unique identifier of the shuffle - type: - Type of the shuffle operation key: Task key triggering the function """ - async with self._refresh_locks[spec.id]: - shuffle_run = self._active_runs.get(spec.id, None) + async with self._refresh_locks[shuffle_id]: + shuffle_run = self._active_runs.get(shuffle_id, None) if shuffle_run is None: shuffle_run = await self._refresh( - shuffle_id=spec.id, - spec=spec, + shuffle_id=shuffle_id, key=key, ) @@ -189,17 +179,16 @@ async def get_most_recent( async def _fetch( self, shuffle_id: ShuffleId, - spec: ShuffleSpec | None = None, key: Key | None = None, ) -> ShuffleRunSpec: - if spec is None: + if key is None: response = await self._plugin.worker.scheduler.shuffle_get( id=shuffle_id, worker=self._plugin.worker.address, ) else: response = await self._plugin.worker.scheduler.shuffle_get_or_create( - spec=ToPickle(spec), + shuffle_id=shuffle_id, key=key, worker=self._plugin.worker.address, ) @@ -222,17 +211,15 @@ async def _refresh( async def _refresh( self, shuffle_id: ShuffleId, - spec: ShuffleSpec, key: Key, ) -> ShuffleRun: ... async def _refresh( self, shuffle_id: ShuffleId, - spec: ShuffleSpec | None = None, key: Key | None = None, ) -> ShuffleRun: - result = await self._fetch(shuffle_id=shuffle_id, spec=spec, key=key) + result = await self._fetch(shuffle_id=shuffle_id, key=key) if self.closed: raise ShuffleClosedError(f"{self} has already been closed") if existing := self._active_runs.get(shuffle_id, None): @@ -355,10 +342,10 @@ def add_partition( self, data: Any, partition_id: int | NDIndex, - spec: ShuffleSpec, + id: ShuffleId, **kwargs: Any, ) -> int: - shuffle_run = self.get_or_create_shuffle(spec) + shuffle_run = self.get_or_create_shuffle(id) return shuffle_run.add_partition( data=data, partition_id=partition_id, @@ -418,13 +405,13 @@ def get_shuffle_run( def get_or_create_shuffle( self, - spec: ShuffleSpec, + shuffle_id: ShuffleId, ) -> ShuffleRun: key = thread_state.key return sync( self.worker.loop, self.shuffle_runs.get_or_create, - spec, + shuffle_id, key, ) diff --git a/distributed/shuffle/tests/test_merge.py b/distributed/shuffle/tests/test_merge.py index 0cb4303397..f491ba7edd 100644 --- a/distributed/shuffle/tests/test_merge.py +++ b/distributed/shuffle/tests/test_merge.py @@ -10,7 +10,7 @@ from dask.typing import Key from distributed import Worker -from distributed.shuffle._core import ShuffleId, ShuffleSpec, id_from_key +from distributed.shuffle._core import ShuffleId, id_from_key from distributed.shuffle._worker_plugin import ShuffleRun, _ShuffleRunManager from distributed.utils_test import gen_cluster @@ -421,12 +421,12 @@ def __init__(self, *args: Any, **kwargs: Any): self.blocking_get_or_create = asyncio.Event() self.block_get_or_create = asyncio.Event() - async def get_or_create(self, spec: ShuffleSpec, key: Key) -> ShuffleRun: - if len(self.seen) >= self.limit and spec.id not in self.seen: + async def get_or_create(self, shuffle_id: ShuffleId, key: Key) -> ShuffleRun: + if len(self.seen) >= self.limit and shuffle_id not in self.seen: self.blocking_get_or_create.set() await self.block_get_or_create.wait() - self.seen.add(spec.id) - return await super().get_or_create(spec, key) + self.seen.add(shuffle_id) + return await super().get_or_create(shuffle_id, key) @mock.patch( diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 17d55c8d4e..701d18bf43 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ 
b/distributed/shuffle/tests/test_shuffle.py
@@ -2349,7 +2349,7 @@ async def test_fail_fetch_race(c, s, a):
     assert shuffle_id not in run_manager._active_runs

     with pytest.raises(RuntimeError, match="Received stale shuffle run"):
-        await run_manager.get_or_create(spec.spec, "test-key")
+        await run_manager.get_or_create(shuffle_id, "test-key")
     assert shuffle_id not in run_manager._active_runs

     worker_plugin.block_barrier.set()
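
For illustration, this is the graph shape the patch produces for a DataFrame shuffle: each transfer task now carries only a reference to its input partition, the ShuffleId, and the partition index, while the full DataFrameShuffleSpec (column, meta, parts_out, disk, drop_column) is attached to a single P2PBarrierTask. The sketch below condenses what P2PShuffleLayer._construct_graph emits after this change; the build_shuffle_graph helper and its parameter list are illustrative only, not part of the codebase.

from dask._task_spec import Task, TaskRef

from distributed.shuffle._core import (
    P2PBarrierTask,
    ShuffleId,
    barrier_key,
    p2p_barrier,
)
from distributed.shuffle._shuffle import DataFrameShuffleSpec, shuffle_transfer


def build_shuffle_graph(name_input, token, npartitions_input, npartitions, column, meta, disk):
    shuffle_id = ShuffleId(token)
    dsk = {}
    transfer_keys = []
    for i in range(npartitions_input):
        # Transfer tasks reference the input partition and the shuffle ID only;
        # they no longer embed meta, parts_out, or the hash column.
        t = Task(
            ("shuffle-transfer-" + token, i),
            shuffle_transfer,
            TaskRef((name_input, i)),
            shuffle_id,
            i,
        )
        dsk[t.key] = t
        transfer_keys.append(t.ref())

    # The spec appears exactly once, attached to the barrier task.
    barrier = P2PBarrierTask(
        barrier_key(shuffle_id),
        p2p_barrier,
        token,
        transfer_keys,
        spec=DataFrameShuffleSpec(
            id=shuffle_id,
            npartitions=npartitions,
            column=column,
            meta=meta,
            parts_out=npartitions,
            disk=disk,
            drop_column=False,
        ),
    )
    dsk[barrier.key] = barrier
    return dsk, barrier

Because meta and the output-partition bookkeeping no longer ride along with every transfer task, graph size and task serialization cost stop scaling with npartitions_input times the size of the spec, which is the per-task overhead the title refers to.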
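
On the first request for a given shuffle, the worker plugin now asks the scheduler for the run by ID alone (shuffle_get_or_create(shuffle_id=..., key=..., worker=...)) instead of shipping a pickled ShuffleSpec, and the scheduler plugin reconstructs the spec from the barrier task it already stores. Below is a sketch of that lookup, written as a free-function version of ShuffleSchedulerPlugin._retrieve_spec and assuming the patched P2PBarrierTask.

from distributed.scheduler import Scheduler
from distributed.shuffle._core import (
    P2PBarrierTask,
    ShuffleId,
    ShuffleSpec,
    barrier_key,
)


def retrieve_spec(scheduler: Scheduler, shuffle_id: ShuffleId) -> ShuffleSpec:
    # The barrier task's run_spec is the P2PBarrierTask built by the graph layer,
    # so the spec reaches the scheduler exactly once, embedded in the graph itself.
    barrier_task_spec = scheduler.tasks[barrier_key(shuffle_id)].run_spec
    assert isinstance(barrier_task_spec, P2PBarrierTask)
    return barrier_task_spec.spec

Since P2PBarrierTask overrides copy, inline, __getstate__, and __setstate__ to keep spec attached, the spec survives the transformations and round trips the scheduler applies to run_spec objects.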