Skip to content

Commit

Permalink
_ensure_computing
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Mar 29, 2022
1 parent e27d281 commit 9f18059
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 52 deletions.
131 changes: 79 additions & 52 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@
READY,
AddKeysMsg,
CancelComputeEvent,
EnsureComputingEvent,
Execute,
ExecuteFailureEvent,
ExecuteSuccessEvent,
Expand All @@ -128,6 +129,8 @@
TaskState,
TaskStateState,
UniqueTaskHeap,
UnpauseEvent,
merge_recs_instructions,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -923,8 +926,7 @@ def status(self, value):
ServerNode.status.__set__(self, value)
self._send_worker_status_change()
if prev_status == Status.paused and value == Status.running:
self.ensure_computing()
self.ensure_communicating()
self.handle_stimulus(UnpauseEvent(stimulus_id=f"set-status-{time()}"))

def _send_worker_status_change(self) -> None:
if (
Expand Down Expand Up @@ -2053,9 +2055,10 @@ def transition_executing_rescheduled(
self.available_resources[resource] += quantity
self._executing.discard(ts)

recs: Recs = {ts: "released"}
smsg = RescheduleMsg(key=ts.key, worker=self.address)
return recs, [smsg]
return merge_recs_instructions(
({ts: "released"}, [RescheduleMsg(key=ts.key, worker=self.address)]),
self._ensure_computing(),
)

def transition_waiting_ready(
self, ts: TaskState, *, stimulus_id: str
Expand Down Expand Up @@ -2148,13 +2151,17 @@ def transition_executing_error(
for resource, quantity in ts.resource_restrictions.items():
self.available_resources[resource] += quantity
self._executing.discard(ts)
return self.transition_generic_error(
ts,
exception,
traceback,
exception_text,
traceback_text,
stimulus_id=stimulus_id,

return merge_recs_instructions(
self.transition_generic_error(
ts,
exception,
traceback,
exception_text,
traceback_text,
stimulus_id=stimulus_id,
),
self._ensure_computing(),
)

def _transition_from_resumed(
Expand Down Expand Up @@ -2284,7 +2291,7 @@ def transition_executing_released(
# See https://github.com/dask/distributed/pull/5046#discussion_r685093940
ts.state = "cancelled"
ts.done = False
return {}, []
return self._ensure_computing()

def transition_long_running_memory(
self, ts: TaskState, value=no_value, *, stimulus_id: str
Expand Down Expand Up @@ -2328,7 +2335,10 @@ def transition_executing_memory(

self._executing.discard(ts)
self.executed_count += 1
return self.transition_generic_memory(ts, value=value, stimulus_id=stimulus_id)
return merge_recs_instructions(
self.transition_generic_memory(ts, value=value, stimulus_id=stimulus_id),
self._ensure_computing(),
)

def transition_constrained_executing(
self, ts: TaskState, *, stimulus_id: str
Expand Down Expand Up @@ -2432,9 +2442,11 @@ def transition_executing_long_running(
ts.state = "long-running"
self._executing.discard(ts)
self.long_running.add(ts.key)
smsg = LongRunningMsg(key=ts.key, compute_duration=compute_duration)
self.io_loop.add_callback(self.ensure_computing)
return {}, [smsg]

return merge_recs_instructions(
({}, [LongRunningMsg(key=ts.key, compute_duration=compute_duration)]),
self._ensure_computing(),
)

def transition_released_memory(
self, ts: TaskState, value, *, stimulus_id: str
Expand Down Expand Up @@ -2607,8 +2619,6 @@ def handle_stimulus(self, stim: StateMachineEvent) -> None:
recs, instructions = self.handle_event(stim)
self.transitions(recs, stimulus_id=stim.stimulus_id)
self._handle_instructions(instructions)
self.ensure_computing()
self.ensure_communicating()

def _handle_stimulus_from_future(
self, future: asyncio.Future[StateMachineEvent | None]
Expand Down Expand Up @@ -3385,41 +3395,48 @@ async def _maybe_deserialize_task(
raise

def ensure_computing(self) -> None:
# Entry point for giving the state machine a "kick": instead of transitioning
# tasks itself, it wraps the request in an EnsureComputingEvent (tagged with a
# time-based stimulus_id) and hands it to handle_stimulus().
self.handle_stimulus(
EnsureComputingEvent(stimulus_id=f"ensure_computing-{time()}")
)

def _ensure_computing(self) -> RecsInstrs:
if self.status in (Status.paused, Status.closing_gracefully):
return
try:
stimulus_id = f"ensure-computing-{time()}"
while self.constrained and self.executing_count < self.nthreads:
key = self.constrained[0]
ts = self.tasks.get(key, None)
if ts is None or ts.state != "constrained":
self.constrained.popleft()
continue
if self.meets_resource_constraints(key):
self.constrained.popleft()
self.transition(ts, "executing", stimulus_id=stimulus_id)
else:
break
while self.ready and self.executing_count < self.nthreads:
priority, key = heapq.heappop(self.ready)
ts = self.tasks.get(key)
if ts is None:
# It is possible for tasks to be released while still remaining on
# `ready`. The scheduler might have re-routed to a new worker and
# told this worker to release. If the task has "disappeared", just
# continue through the heap.
continue
elif ts.key in self.data:
self.transition(ts, "memory", stimulus_id=stimulus_id)
elif ts.state in READY:
self.transition(ts, "executing", stimulus_id=stimulus_id)
except Exception as e: # pragma: no cover
logger.exception(e)
if LOG_PDB:
import pdb
return {}, []

pdb.set_trace()
raise
recs: Recs = {}
executing_count = self.executing_count
while self.constrained and executing_count < self.nthreads:
key = self.constrained[0]
ts = self.tasks.get(key, None)
if ts is None or ts.state != "constrained":
self.constrained.popleft()
continue
if self.meets_resource_constraints(key):
self.constrained.popleft()
assert ts not in recs
recs[ts] = "executing"
executing_count += 1
else:
break

while self.ready and executing_count < self.nthreads:
priority, key = heapq.heappop(self.ready)
ts = self.tasks.get(key)
if ts is None:
# It is possible for tasks to be released while still remaining on
# `ready`. The scheduler might have re-routed to a new worker and
# told this worker to release. If the task has "disappeared", just
# continue through the heap.
continue

assert ts not in recs
if ts.key in self.data:
recs[ts] = "memory"
elif ts.state in READY:
recs[ts] = "executing"
executing_count += 1

return recs, []

async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | None:
if self.status in {Status.closing, Status.closed, Status.closing_gracefully}:
Expand Down Expand Up @@ -3549,6 +3566,16 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No
def handle_event(self, ev: StateMachineEvent) -> RecsInstrs:
# Fallback implementation: an event type that has no handler attached via
# handle_event.register ends up here and is rejected with a TypeError that
# carries the offending event.
raise TypeError(ev) # pragma: nocover

@handle_event.register
def _(self, ev: EnsureComputingEvent) -> RecsInstrs:
# Nothing is read from the event itself; the result is whatever
# (recommendations, instructions) pair _ensure_computing() produces.
return self._ensure_computing()

@handle_event.register
def _(self, ev: UnpauseEvent) -> RecsInstrs:
# Handler for the event sent by the `status` setter on a paused -> running
# change; the assert documents that the new status must already be in place.
assert self.status == Status.running
# Communication is resumed through a direct call (a side effect of this
# handler), whereas computation is resumed through the returned value.
self.ensure_communicating()
return self._ensure_computing()

@handle_event.register
def _(self, ev: CancelComputeEvent) -> RecsInstrs:
ts = self.tasks.get(ev.key)
Expand Down
24 changes: 24 additions & 0 deletions distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from typing import Collection # TODO move to collections.abc (requires Python >=3.9)
from typing import TYPE_CHECKING, Any, ClassVar, Literal, NamedTuple, TypedDict

from tlz import concat, merge

import dask
from dask.utils import parse_bytes

Expand Down Expand Up @@ -356,6 +358,21 @@ class StateMachineEvent:
stimulus_id: str


@dataclass
class EnsureComputingEvent(StateMachineEvent):
"""Let various methods of worker give an artificial 'kick' to _ensure_computing.
This is a temporary hack to be removed as part of
https://github.com/dask/distributed/issues/5894.
"""

# Empty __slots__: this subclass declares no fields of its own.
__slots__ = ()


@dataclass
class UnpauseEvent(StateMachineEvent):
# Event sent to the state machine when the worker's status goes from paused
# to running. Its handler asserts the running status, calls
# ensure_communicating() and returns the result of _ensure_computing().
# Empty __slots__: this subclass declares no fields of its own.
__slots__ = ()


@dataclass
class ExecuteSuccessEvent(StateMachineEvent):
key: str
Expand Down Expand Up @@ -403,3 +420,10 @@ class RescheduleEvent(StateMachineEvent):
Recs = dict
Instructions = list
RecsInstrs = tuple


def merge_recs_instructions(*args: RecsInstrs) -> RecsInstrs:
    """Merge several ``(recommendations, instructions)`` tuples into one.

    Recommendations are combined into a single new dict, preserving the order
    in which keys are first seen. Instructions are concatenated into a single
    new list, in argument order. The inputs are never mutated.

    Parameters
    ----------
    *args
        Zero or more ``(recommendations, instructions)`` pairs, where
        recommendations is a mapping and instructions is an iterable.

    Returns
    -------
    tuple
        ``(merged_recommendations, concatenated_instructions)``;
        ``({}, [])`` when called without arguments.

    Raises
    ------
    ValueError
        If two inputs carry *different* recommendations for the same key.
        Identical duplicates are accepted.
    """
    recs = {}
    instructions = []
    for recs_i, instructions_i in args:
        for ts, finish in recs_i.items():
            # A plain dict merge would let the later recommendation silently
            # overwrite the earlier one, hiding a state machine inconsistency.
            if ts in recs and recs[ts] != finish:
                raise ValueError(
                    f"Mismatched recommendations for {ts}: "
                    f"{recs[ts]} vs. {finish}"
                )
            recs[ts] = finish
        instructions.extend(instructions_i)
    return recs, instructions

0 comments on commit 9f18059

Please sign in to comment.