From 9f180590004635d9967ed3e77df74279f1fb2550 Mon Sep 17 00:00:00 2001
From: crusaderky
Date: Wed, 30 Mar 2022 00:56:12 +0100
Subject: [PATCH] _ensure_computing

---
 distributed/worker.py               | 131 +++++++++++++++++-----------
 distributed/worker_state_machine.py |  24 +++++
 2 files changed, 103 insertions(+), 52 deletions(-)

diff --git a/distributed/worker.py b/distributed/worker.py
index 12e60442a2d..d8dfef89486 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -109,6 +109,7 @@
     READY,
     AddKeysMsg,
     CancelComputeEvent,
+    EnsureComputingEvent,
     Execute,
     ExecuteFailureEvent,
     ExecuteSuccessEvent,
@@ -128,6 +129,8 @@
     TaskState,
     TaskStateState,
     UniqueTaskHeap,
+    UnpauseEvent,
+    merge_recs_instructions,
 )

 if TYPE_CHECKING:
@@ -923,8 +926,7 @@ def status(self, value):
         ServerNode.status.__set__(self, value)
         self._send_worker_status_change()
         if prev_status == Status.paused and value == Status.running:
-            self.ensure_computing()
-            self.ensure_communicating()
+            self.handle_stimulus(UnpauseEvent(stimulus_id=f"set-status-{time()}"))

     def _send_worker_status_change(self) -> None:
         if (
@@ -2053,9 +2055,10 @@ def transition_executing_rescheduled(
             self.available_resources[resource] += quantity
         self._executing.discard(ts)

-        recs: Recs = {ts: "released"}
-        smsg = RescheduleMsg(key=ts.key, worker=self.address)
-        return recs, [smsg]
+        return merge_recs_instructions(
+            ({ts: "released"}, [RescheduleMsg(key=ts.key, worker=self.address)]),
+            self._ensure_computing(),
+        )

     def transition_waiting_ready(
         self, ts: TaskState, *, stimulus_id: str
@@ -2148,13 +2151,17 @@ def transition_executing_error(
         for resource, quantity in ts.resource_restrictions.items():
             self.available_resources[resource] += quantity
         self._executing.discard(ts)
-        return self.transition_generic_error(
-            ts,
-            exception,
-            traceback,
-            exception_text,
-            traceback_text,
-            stimulus_id=stimulus_id,
+
+        return merge_recs_instructions(
+            self.transition_generic_error(
+                ts,
+                exception,
+                traceback,
+                exception_text,
+                traceback_text,
+                stimulus_id=stimulus_id,
+            ),
+            self._ensure_computing(),
         )

     def _transition_from_resumed(
@@ -2284,7 +2291,7 @@ def transition_executing_released(
         # See https://github.com/dask/distributed/pull/5046#discussion_r685093940
         ts.state = "cancelled"
         ts.done = False
-        return {}, []
+        return self._ensure_computing()

     def transition_long_running_memory(
         self, ts: TaskState, value=no_value, *, stimulus_id: str
@@ -2328,7 +2335,10 @@ def transition_executing_memory(
         self._executing.discard(ts)
         self.executed_count += 1

-        return self.transition_generic_memory(ts, value=value, stimulus_id=stimulus_id)
+        return merge_recs_instructions(
+            self.transition_generic_memory(ts, value=value, stimulus_id=stimulus_id),
+            self._ensure_computing(),
+        )

     def transition_constrained_executing(
         self, ts: TaskState, *, stimulus_id: str
@@ -2432,9 +2442,11 @@ def transition_executing_long_running(
         ts.state = "long-running"
         self._executing.discard(ts)
         self.long_running.add(ts.key)
-        smsg = LongRunningMsg(key=ts.key, compute_duration=compute_duration)
-        self.io_loop.add_callback(self.ensure_computing)
-        return {}, [smsg]
+
+        return merge_recs_instructions(
+            ({}, [LongRunningMsg(key=ts.key, compute_duration=compute_duration)]),
+            self._ensure_computing(),
+        )

     def transition_released_memory(
         self, ts: TaskState, value, *, stimulus_id: str
@@ -2607,8 +2619,6 @@ def handle_stimulus(self, stim: StateMachineEvent) -> None:
         recs, instructions = self.handle_event(stim)
         self.transitions(recs, stimulus_id=stim.stimulus_id)
         self._handle_instructions(instructions)
-        self.ensure_computing()
-        self.ensure_communicating()

     def _handle_stimulus_from_future(
         self, future: asyncio.Future[StateMachineEvent | None]
@@ -3385,41 +3395,48 @@ async def _maybe_deserialize_task(
             raise

     def ensure_computing(self) -> None:
+        self.handle_stimulus(
+            EnsureComputingEvent(stimulus_id=f"ensure_computing-{time()}")
+        )
+
+    def _ensure_computing(self) -> RecsInstrs:
         if self.status in (Status.paused, Status.closing_gracefully):
-            return
-        try:
-            stimulus_id = f"ensure-computing-{time()}"
-            while self.constrained and self.executing_count < self.nthreads:
-                key = self.constrained[0]
-                ts = self.tasks.get(key, None)
-                if ts is None or ts.state != "constrained":
-                    self.constrained.popleft()
-                    continue
-                if self.meets_resource_constraints(key):
-                    self.constrained.popleft()
-                    self.transition(ts, "executing", stimulus_id=stimulus_id)
-                else:
-                    break
-            while self.ready and self.executing_count < self.nthreads:
-                priority, key = heapq.heappop(self.ready)
-                ts = self.tasks.get(key)
-                if ts is None:
-                    # It is possible for tasks to be released while still remaining on
-                    # `ready` The scheduler might have re-routed to a new worker and
-                    # told this worker to release. If the task has "disappeared" just
-                    # continue through the heap
-                    continue
-                elif ts.key in self.data:
-                    self.transition(ts, "memory", stimulus_id=stimulus_id)
-                elif ts.state in READY:
-                    self.transition(ts, "executing", stimulus_id=stimulus_id)
-        except Exception as e:  # pragma: no cover
-            logger.exception(e)
-            if LOG_PDB:
-                import pdb
+            return {}, []

-                pdb.set_trace()
-            raise
+        recs: Recs = {}
+        executing_count = self.executing_count
+        while self.constrained and executing_count < self.nthreads:
+            key = self.constrained[0]
+            ts = self.tasks.get(key, None)
+            if ts is None or ts.state != "constrained":
+                self.constrained.popleft()
+                continue
+            if self.meets_resource_constraints(key):
+                self.constrained.popleft()
+                assert ts not in recs
+                recs[ts] = "executing"
+                executing_count += 1
+            else:
+                break
+
+        while self.ready and executing_count < self.nthreads:
+            priority, key = heapq.heappop(self.ready)
+            ts = self.tasks.get(key)
+            if ts is None:
+                # It is possible for tasks to be released while still remaining on
+                # `ready`. The scheduler might have re-routed to a new worker and
+                # told this worker to release. If the task has "disappeared", just
+                # continue through the heap.
+                continue
+
+            assert ts not in recs
+            if ts.key in self.data:
+                recs[ts] = "memory"
+            elif ts.state in READY:
+                recs[ts] = "executing"
+                executing_count += 1
+
+        return recs, []

     async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | None:
         if self.status in {Status.closing, Status.closed, Status.closing_gracefully}:
@@ -3549,6 +3566,16 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No
     def handle_event(self, ev: StateMachineEvent) -> RecsInstrs:
         raise TypeError(ev)  # pragma: nocover

+    @handle_event.register
+    def _(self, ev: EnsureComputingEvent) -> RecsInstrs:
+        return self._ensure_computing()
+
+    @handle_event.register
+    def _(self, ev: UnpauseEvent) -> RecsInstrs:
+        assert self.status == Status.running
+        self.ensure_communicating()
+        return self._ensure_computing()
+
     @handle_event.register
     def _(self, ev: CancelComputeEvent) -> RecsInstrs:
         ts = self.tasks.get(ev.key)
diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py
index 178f178edad..a4e248f0f42 100644
--- a/distributed/worker_state_machine.py
+++ b/distributed/worker_state_machine.py
@@ -8,6 +8,8 @@
 from typing import Collection  # TODO move to collections.abc (requires Python >=3.9)
 from typing import TYPE_CHECKING, Any, ClassVar, Literal, NamedTuple, TypedDict

+from tlz import concat, merge
+
 import dask
 from dask.utils import parse_bytes

@@ -356,6 +358,21 @@ class StateMachineEvent:
     stimulus_id: str


+@dataclass
+class EnsureComputingEvent(StateMachineEvent):
+    """Let various methods of the worker give an artificial 'kick' to _ensure_computing.
+    This is a temporary hack to be removed as part of
+    https://github.com/dask/distributed/issues/5894.
+    """
+
+    __slots__ = ()
+
+
+@dataclass
+class UnpauseEvent(StateMachineEvent):
+    __slots__ = ()
+
+
 @dataclass
 class ExecuteSuccessEvent(StateMachineEvent):
     key: str
@@ -403,3 +420,10 @@ class RescheduleEvent(StateMachineEvent):
 Recs = dict
 Instructions = list
 RecsInstrs = tuple
+
+
+def merge_recs_instructions(*args: RecsInstrs) -> RecsInstrs:
+    return (
+        merge(e[0] for e in args),
+        list(concat([e[1] for e in args])),
+    )
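
For readers unfamiliar with the RecsInstrs plumbing above: each transition function returns a (recommendations, instructions) tuple, and merge_recs_instructions combines several such tuples so that a transition can splice in the output of _ensure_computing(). The following standalone sketch (not part of the patch) illustrates the merge semantics using the same tlz helpers the patch imports; the string keys and the RescheduleMsg placeholder are stand-ins for the real TaskState and Instruction objects.

from tlz import concat, merge


def merge_recs_instructions(*args):
    # Recommendation dicts are merged left to right (later tuples win on key
    # collisions); instruction lists are concatenated in order.
    return (
        merge(e[0] for e in args),
        list(concat([e[1] for e in args])),
    )


# Mirror transition_executing_rescheduled: combine the transition's own
# output with whatever _ensure_computing() recommends.
recs1, instrs1 = {"task-x": "released"}, ["RescheduleMsg(task-x)"]
recs2, instrs2 = {"task-y": "executing"}, []  # e.g. from _ensure_computing()

recs, instructions = merge_recs_instructions((recs1, instrs1), (recs2, instrs2))
assert recs == {"task-x": "released", "task-y": "executing"}
assert instructions == ["RescheduleMsg(task-x)"]

Note that the dict merge means a later tuple can silently override an earlier recommendation for the same task; the asserts inside _ensure_computing only guard against collisions within its own recs dict.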