Skip to content

Commit

Permalink
_ensure_computing
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Mar 30, 2022
1 parent b15d5dc commit 9d28bde
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 62 deletions.
16 changes: 16 additions & 0 deletions distributed/tests/test_worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
from distributed.worker_state_machine import (
Instruction,
ReleaseWorkerDataMsg,
RescheduleMsg,
SendMessageToScheduler,
StateMachineEvent,
TaskState,
UniqueTaskHeap,
merge_recs_instructions,
)


Expand Down Expand Up @@ -115,3 +117,17 @@ def test_sendmsg_to_dict():
def test_event_slots(cls):
smsg = cls(**dict.fromkeys(cls.__annotations__), stimulus_id="test")
assert not hasattr(smsg, "__dict__")


def test_merge_recs_instructions():
x = TaskState("x")
y = TaskState("y")
instr1 = RescheduleMsg(key="foo", worker="a")
instr2 = RescheduleMsg(key="bar", worker="b")
assert merge_recs_instructions(
({x: "memory"}, [instr1]),
({y: "released"}, [instr2]),
) == (
{x: "memory", y: "released"},
[instr1, instr2],
)
166 changes: 104 additions & 62 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@
AddKeysMsg,
AlreadyCancelledEvent,
CancelComputeEvent,
EnsureComputingEvent,
Execute,
ExecuteFailureEvent,
ExecuteSuccessEvent,
Expand All @@ -129,6 +130,8 @@
TaskState,
TaskStateState,
UniqueTaskHeap,
UnpauseEvent,
merge_recs_instructions,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -924,8 +927,7 @@ def status(self, value):
ServerNode.status.__set__(self, value)
self._send_worker_status_change()
if prev_status == Status.paused and value == Status.running:
self.ensure_computing()
self.ensure_communicating()
self.handle_stimulus(UnpauseEvent(stimulus_id=f"set-status-{time()}"))

def _send_worker_status_change(self) -> None:
if (
Expand Down Expand Up @@ -2054,9 +2056,10 @@ def transition_executing_rescheduled(
self.available_resources[resource] += quantity
self._executing.discard(ts)

recs: Recs = {ts: "released"}
smsg = RescheduleMsg(key=ts.key, worker=self.address)
return recs, [smsg]
return merge_recs_instructions(
({ts: "released"}, [RescheduleMsg(key=ts.key, worker=self.address)]),
self._ensure_computing(),
)

def transition_waiting_ready(
self, ts: TaskState, *, stimulus_id: str
Expand Down Expand Up @@ -2149,13 +2152,17 @@ def transition_executing_error(
for resource, quantity in ts.resource_restrictions.items():
self.available_resources[resource] += quantity
self._executing.discard(ts)
return self.transition_generic_error(
ts,
exception,
traceback,
exception_text,
traceback_text,
stimulus_id=stimulus_id,

return merge_recs_instructions(
self.transition_generic_error(
ts,
exception,
traceback,
exception_text,
traceback_text,
stimulus_id=stimulus_id,
),
self._ensure_computing(),
)

def _transition_from_resumed(
Expand Down Expand Up @@ -2270,12 +2277,12 @@ def transition_cancelled_released(

for resource, quantity in ts.resource_restrictions.items():
self.available_resources[resource] += quantity
recs, instructions = self.transition_generic_released(
ts, stimulus_id=stimulus_id

return merge_recs_instructions(
self.transition_generic_released(ts, stimulus_id=stimulus_id),
({ts: next_state} if next_state != "released" else {}, []),
self._ensure_computing(),
)
if next_state != "released":
recs[ts] = next_state
return recs, instructions

def transition_executing_released(
self, ts: TaskState, *, stimulus_id: str
Expand All @@ -2285,7 +2292,7 @@ def transition_executing_released(
# See https://github.com/dask/distributed/pull/5046#discussion_r685093940
ts.state = "cancelled"
ts.done = False
return {}, []
return self._ensure_computing()

def transition_long_running_memory(
self, ts: TaskState, value=no_value, *, stimulus_id: str
Expand All @@ -2308,16 +2315,22 @@ def transition_generic_memory(
self._executing.discard(ts)
self._in_flight_tasks.discard(ts)
ts.coming_from = None

instructions: Instructions = []
try:
recs = self._put_key_in_memory(ts, value, stimulus_id=stimulus_id)
except Exception as e:
msg = error_message(e)
recs = {ts: tuple(msg.values())}
return recs, []
if self.validate:
assert ts.key in self.data or ts.key in self.actors
smsg = self._get_task_finished_msg(ts)
return recs, [smsg]
else:
if self.validate:
assert ts.key in self.data or ts.key in self.actors
instructions.append(self._get_task_finished_msg(ts))

return merge_recs_instructions(
(recs, instructions),
self._ensure_computing(),
)

def transition_executing_memory(
self, ts: TaskState, value=no_value, *, stimulus_id: str
Expand All @@ -2329,7 +2342,10 @@ def transition_executing_memory(

self._executing.discard(ts)
self.executed_count += 1
return self.transition_generic_memory(ts, value=value, stimulus_id=stimulus_id)
return merge_recs_instructions(
self.transition_generic_memory(ts, value=value, stimulus_id=stimulus_id),
self._ensure_computing(),
)

def transition_constrained_executing(
self, ts: TaskState, *, stimulus_id: str
Expand Down Expand Up @@ -2433,9 +2449,11 @@ def transition_executing_long_running(
ts.state = "long-running"
self._executing.discard(ts)
self.long_running.add(ts.key)
smsg = LongRunningMsg(key=ts.key, compute_duration=compute_duration)
self.io_loop.add_callback(self.ensure_computing)
return {}, [smsg]

return merge_recs_instructions(
({}, [LongRunningMsg(key=ts.key, compute_duration=compute_duration)]),
self._ensure_computing(),
)

def transition_released_memory(
self, ts: TaskState, value, *, stimulus_id: str
Expand Down Expand Up @@ -2608,8 +2626,6 @@ def handle_stimulus(self, stim: StateMachineEvent) -> None:
recs, instructions = self.handle_event(stim)
self.transitions(recs, stimulus_id=stim.stimulus_id)
self._handle_instructions(instructions)
self.ensure_computing()
self.ensure_communicating()

def _handle_stimulus_from_future(
self, future: asyncio.Future[StateMachineEvent | None]
Expand Down Expand Up @@ -3225,6 +3241,8 @@ def release_key(

self._executing.discard(ts)
self._in_flight_tasks.discard(ts)
self.ensure_computing()
self.ensure_communicating()

self._notify_plugins(
"release_key", key, state_before, cause, stimulus_id, report
Expand Down Expand Up @@ -3386,41 +3404,48 @@ async def _maybe_deserialize_task(
raise

def ensure_computing(self) -> None:
self.handle_stimulus(
EnsureComputingEvent(stimulus_id=f"ensure_computing-{time()}")
)

def _ensure_computing(self) -> RecsInstrs:
if self.status in (Status.paused, Status.closing_gracefully):
return
try:
stimulus_id = f"ensure-computing-{time()}"
while self.constrained and self.executing_count < self.nthreads:
key = self.constrained[0]
ts = self.tasks.get(key, None)
if ts is None or ts.state != "constrained":
self.constrained.popleft()
continue
if self.meets_resource_constraints(key):
self.constrained.popleft()
self.transition(ts, "executing", stimulus_id=stimulus_id)
else:
break
while self.ready and self.executing_count < self.nthreads:
priority, key = heapq.heappop(self.ready)
ts = self.tasks.get(key)
if ts is None:
# It is possible for tasks to be released while still remaining on
# `ready` The scheduler might have re-routed to a new worker and
# told this worker to release. If the task has "disappeared" just
# continue through the heap
continue
elif ts.key in self.data:
self.transition(ts, "memory", stimulus_id=stimulus_id)
elif ts.state in READY:
self.transition(ts, "executing", stimulus_id=stimulus_id)
except Exception as e: # pragma: no cover
logger.exception(e)
if LOG_PDB:
import pdb
return {}, []

pdb.set_trace()
raise
recs: Recs = {}
executing_count = self.executing_count
while self.constrained and executing_count < self.nthreads:
key = self.constrained[0]
ts = self.tasks.get(key, None)
if ts is None or ts.state != "constrained":
self.constrained.popleft()
continue
if self.meets_resource_constraints(key):
self.constrained.popleft()
assert ts not in recs
recs[ts] = "executing"
executing_count += 1
else:
break

while self.ready and executing_count < self.nthreads:
priority, key = heapq.heappop(self.ready)
ts = self.tasks.get(key)
if ts is None:
# It is possible for tasks to be released while still remaining on
# `ready` The scheduler might have re-routed to a new worker and
# told this worker to release. If the task has "disappeared" just
# continue through the heap
continue

assert ts not in recs
if ts.key in self.data:
recs[ts] = "memory"
elif ts.state in READY:
recs[ts] = "executing"
executing_count += 1

return recs, []

async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | None:
if self.status in {Status.closing, Status.closed, Status.closing_gracefully}:
Expand Down Expand Up @@ -3554,6 +3579,23 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No
def handle_event(self, ev: StateMachineEvent) -> RecsInstrs:
raise TypeError(ev) # pragma: nocover

@handle_event.register
def _(self, ev: EnsureComputingEvent) -> RecsInstrs:
"""Let various methods of worker give an artificial 'kick' to _ensure_computing.
This is a temporary hack to be removed as part of
https://github.com/dask/distributed/issues/5894.
"""
return self._ensure_computing()

@handle_event.register
def _(self, ev: UnpauseEvent) -> RecsInstrs:
"""Emerge from paused status. Do not send this event directly. Instead, just set
Worker.status back to running.
"""
assert self.status == Status.running
self.ensure_communicating()
return self._ensure_computing()

@handle_event.register
def _(self, ev: CancelComputeEvent) -> RecsInstrs:
"""Scheduler requested to cancel a task"""
Expand Down
24 changes: 24 additions & 0 deletions distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from typing import Collection # TODO move to collections.abc (requires Python >=3.9)
from typing import TYPE_CHECKING, Any, ClassVar, Literal, NamedTuple, TypedDict

from tlz import concat, merge

import dask
from dask.utils import parse_bytes

Expand Down Expand Up @@ -356,6 +358,21 @@ class StateMachineEvent:
stimulus_id: str


@dataclass
class EnsureComputingEvent(StateMachineEvent):
"""Let various methods of worker give an artificial 'kick' to _ensure_computing.
This is a temporary hack to be removed as part of
https://github.com/dask/distributed/issues/5894.
"""

__slots__ = ()


@dataclass
class UnpauseEvent(StateMachineEvent):
__slots__ = ()


@dataclass
class ExecuteSuccessEvent(StateMachineEvent):
key: str
Expand Down Expand Up @@ -409,3 +426,10 @@ class RescheduleEvent(StateMachineEvent):
Recs = dict
Instructions = list
RecsInstrs = tuple


def merge_recs_instructions(*args: RecsInstrs) -> RecsInstrs:
return (
merge(e[0] for e in args),
list(concat([e[1] for e in args])),
)

0 comments on commit 9d28bde

Please sign in to comment.