Skip to content

Commit

Permalink
Refactor find_missing and refresh_who_has (#6348)
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky authored Jun 1, 2022
1 parent bd11979 commit 6b6c0ed
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 48 deletions.
24 changes: 20 additions & 4 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3031,6 +3031,7 @@ def __init__(
"keep-alive": lambda *args, **kwargs: None,
"log-event": self.log_worker_event,
"worker-status-change": self.handle_worker_status_change,
"request-refresh-who-has": self.handle_request_refresh_who_has,
}

client_handlers = {
Expand Down Expand Up @@ -4782,6 +4783,21 @@ def handle_worker_status_change(
else:
self.running.discard(ws)

async def handle_request_refresh_who_has(
self, keys: Iterable[str], worker: str, stimulus_id: str
) -> None:
"""Asynchronous request (through bulk comms) from a Worker to refresh the
who_has for some keys. Not to be confused with scheduler.who_has, which is a
synchronous RPC request from a Client.
"""
self.stream_comms[worker].send(
{
"op": "refresh-who-has",
"who_has": self.get_who_has(keys),
"stimulus_id": stimulus_id,
},
)

async def handle_worker(self, comm=None, worker=None, stimulus_id=None):
"""
Listen to responses from a single worker
Expand Down Expand Up @@ -6230,13 +6246,13 @@ def get_processing(self, workers=None):
w: [ts.key for ts in ws.processing] for w, ws in self.workers.items()
}

def get_who_has(self, keys=None):
def get_who_has(self, keys: Iterable[str] | None = None) -> dict[str, list[str]]:
if keys is not None:
return {
k: [ws.address for ws in self.tasks[k].who_has]
if k in self.tasks
key: [ws.address for ws in self.tasks[key].who_has]
if key in self.tasks
else []
for k in keys
for key in keys
}
else:
return {
Expand Down
118 changes: 74 additions & 44 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@
Execute,
ExecuteFailureEvent,
ExecuteSuccessEvent,
FindMissingEvent,
GatherDep,
GatherDepDoneEvent,
Instructions,
Expand All @@ -123,7 +124,9 @@
MissingDataMsg,
Recs,
RecsInstrs,
RefreshWhoHasEvent,
ReleaseWorkerDataMsg,
RequestRefreshWhoHasMsg,
RescheduleEvent,
RescheduleMsg,
SendMessageToScheduler,
Expand Down Expand Up @@ -813,6 +816,7 @@ def __init__(
"free-keys": self.handle_free_keys,
"remove-replicas": self.handle_remove_replicas,
"steal-request": self.handle_steal_request,
"refresh-who-has": self.handle_refresh_who_has,
"worker-status-change": self.handle_worker_status_change,
}

Expand Down Expand Up @@ -840,9 +844,7 @@ def __init__(
)
self.periodic_callbacks["keep-alive"] = pc

# FIXME annotations: https://github.com/tornadoweb/tornado/issues/3117
pc = PeriodicCallback(self.find_missing, 1000) # type: ignore
self._find_missing_running = False
pc = PeriodicCallback(self.find_missing, 1000)
self.periodic_callbacks["find-missing"] = pc

self._address = contact_address
Expand Down Expand Up @@ -1839,6 +1841,13 @@ def handle_remove_replicas(self, keys: list[str], stimulus_id: str) -> str:

return "OK"

def handle_refresh_who_has(
self, who_has: dict[str, list[str]], stimulus_id: str
) -> None:
self.handle_stimulus(
RefreshWhoHasEvent(who_has=who_has, stimulus_id=stimulus_id)
)

async def set_resources(self, **resources) -> None:
for r, quantity in resources.items():
if r in self.total_resources:
Expand Down Expand Up @@ -2849,7 +2858,8 @@ def transitions(self, recommendations: Recs, *, stimulus_id: str) -> None:

@log_errors
def handle_stimulus(self, stim: StateMachineEvent) -> None:
self.stimulus_log.append(stim.to_loggable(handled=time()))
if not isinstance(stim, FindMissingEvent):
self.stimulus_log.append(stim.to_loggable(handled=time()))
recs, instructions = self.handle_event(stim)
self.transitions(recs, stimulus_id=stim.stimulus_id)
self._handle_instructions(instructions)
Expand Down Expand Up @@ -2991,11 +3001,8 @@ def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs:
if ts.state != "fetch" or ts.key in all_keys_to_gather:
continue

if not ts.who_has:
recommendations[ts] = "missing"
continue

if self.validate:
assert ts.who_has
assert self.address not in ts.who_has

workers = [
Expand Down Expand Up @@ -3348,7 +3355,7 @@ def done_event():
self.busy_workers.add(worker)
self.io_loop.call_later(0.15, self._readd_busy_worker, worker)

refresh_who_has = set()
refresh_who_has = []

for d in self.in_flight_workers.pop(worker):
ts = self.tasks[d]
Expand All @@ -3358,7 +3365,7 @@ def done_event():
elif busy:
recommendations[ts] = "fetch"
if not ts.who_has - self.busy_workers:
refresh_who_has.add(ts.key)
refresh_who_has.append(d)
elif ts not in recommendations:
ts.who_has.discard(worker)
self.has_what[worker].discard(ts.key)
Expand All @@ -3371,17 +3378,19 @@ def done_event():
)
)
recommendations[ts] = "fetch"
del data, response
self.transitions(recommendations, stimulus_id=stimulus_id)
self._handle_instructions(instructions)

if refresh_who_has:
# All workers that hold known replicas of our tasks are busy.
# Try querying the scheduler for unknown ones.
who_has = await retry_operation(
self.scheduler.who_has, keys=refresh_who_has
instructions.append(
RequestRefreshWhoHasMsg(
keys=refresh_who_has,
stimulus_id=f"gather-dep-busy-{time()}",
)
)
self._update_who_has(who_has)

self.transitions(recommendations, stimulus_id=stimulus_id)
self._handle_instructions(instructions)

@log_errors
def _readd_busy_worker(self, worker: str) -> None:
Expand All @@ -3391,33 +3400,13 @@ def _readd_busy_worker(self, worker: str) -> None:
)

@log_errors
async def find_missing(self) -> None:
if self._find_missing_running or not self._missing_dep_flight:
return
try:
self._find_missing_running = True
if self.validate:
for ts in self._missing_dep_flight:
assert not ts.who_has
def find_missing(self) -> None:
self.handle_stimulus(FindMissingEvent(stimulus_id=f"find-missing-{time()}"))

stimulus_id = f"find-missing-{time()}"
who_has = await retry_operation(
self.scheduler.who_has,
keys=[ts.key for ts in self._missing_dep_flight],
)
self._update_who_has(who_has)
recommendations: Recs = {}
for ts in self._missing_dep_flight:
if ts.who_has:
recommendations[ts] = "fetch"
self.transitions(recommendations, stimulus_id=stimulus_id)

finally:
self._find_missing_running = False
# This is quite arbitrary but the heartbeat has scaling implemented
self.periodic_callbacks[
"find-missing"
].callback_time = self.periodic_callbacks["heartbeat"].callback_time
# This is quite arbitrary but the heartbeat has scaling implemented
self.periodic_callbacks["find-missing"].callback_time = self.periodic_callbacks[
"heartbeat"
].callback_time

def _update_who_has(self, who_has: Mapping[str, Collection[str]]) -> None:
for key, workers in who_has.items():
Expand Down Expand Up @@ -3965,6 +3954,47 @@ def _(self, ev: RescheduleEvent) -> RecsInstrs:
assert ts, self.story(ev.key)
return {ts: "rescheduled"}, []

@handle_event.register
def _(self, ev: FindMissingEvent) -> RecsInstrs:
if not self._missing_dep_flight:
return {}, []

if self.validate:
assert not any(ts.who_has for ts in self._missing_dep_flight)

smsg = RequestRefreshWhoHasMsg(
keys=[ts.key for ts in self._missing_dep_flight],
stimulus_id=ev.stimulus_id,
)
return {}, [smsg]

@handle_event.register
def _(self, ev: RefreshWhoHasEvent) -> RecsInstrs:
self._update_who_has(ev.who_has)
recommendations: Recs = {}
instructions: Instructions = []

for key in ev.who_has:
ts = self.tasks.get(key)
if not ts:
continue

if ts.who_has and ts.state == "missing":
recommendations[ts] = "fetch"
elif ts.who_has and ts.state == "fetch":
# We potentially just acquired new replicas whereas all previously known
# workers are in flight or busy. We're deliberately not testing the
# minute use cases here for the sake of simplicity; instead we rely on
# _ensure_communicating to be a no-op when there's nothing to do.
recommendations, instructions = merge_recs_instructions(
(recommendations, instructions),
self._ensure_communicating(stimulus_id=ev.stimulus_id),
)
elif not ts.who_has and ts.state == "fetch":
recommendations[ts] = "missing"

return recommendations, instructions

def _prepare_args_for_execution(
self, ts: TaskState, args: tuple, kwargs: dict[str, Any]
) -> tuple[tuple, dict[str, Any]]:
Expand Down Expand Up @@ -4190,8 +4220,8 @@ def validate_task_fetch(self, ts):
assert self.address not in ts.who_has
assert not ts.done
assert ts in self.data_needed
# Note: ts.who_has may be have been emptied by _update_who_has, but the task
# won't transition to missing until it reaches the top of the data_needed heap.
assert ts.who_has

for w in ts.who_has:
assert ts.key in self.has_what[w]
assert ts in self.data_needed_per_worker[w]
Expand Down
39 changes: 39 additions & 0 deletions distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,26 @@ class AddKeysMsg(SendMessageToScheduler):
keys: list[str]


@dataclass
class RequestRefreshWhoHasMsg(SendMessageToScheduler):
"""Worker -> Scheduler asynchronous request for updated who_has information.
Not to be confused with the scheduler.who_has synchronous RPC call, which is used
by the Client.
See also
--------
RefreshWhoHasEvent
distributed.scheduler.Scheduler.request_refresh_who_has
distributed.client.Client.who_has
distributed.scheduler.Scheduler.get_who_has
"""

op = "request-refresh-who-has"

__slots__ = ("keys",)
keys: list[str]


@dataclass
class StateMachineEvent:
__slots__ = ("stimulus_id", "handled")
Expand Down Expand Up @@ -533,6 +553,25 @@ class RescheduleEvent(StateMachineEvent):
key: str


@dataclass
class FindMissingEvent(StateMachineEvent):
__slots__ = ()


@dataclass
class RefreshWhoHasEvent(StateMachineEvent):
"""Scheduler -> Worker message containing updated who_has information.
See also
--------
RequestRefreshWhoHasMsg
"""

__slots__ = ("who_has",)
# {key: [worker address, ...]}
who_has: dict[str, list[str]]


if TYPE_CHECKING:
# TODO remove quotes (requires Python >=3.9)
# TODO get out of TYPE_CHECKING (requires Python >=3.10)
Expand Down

0 comments on commit 6b6c0ed

Please sign in to comment.