From 45dc795dbc5d21b1abcbd1258cf3d927c8b5f6f7 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 25 May 2022 11:12:14 +0100 Subject: [PATCH] Refactor find_missing and refresh_who_has --- distributed/scheduler.py | 24 ++++++-- distributed/worker.py | 94 +++++++++++++++-------------- distributed/worker_state_machine.py | 39 ++++++++++++ 3 files changed, 109 insertions(+), 48 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index cb5318a8231..d222f7f8b94 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3024,6 +3024,7 @@ def __init__( "keep-alive": lambda *args, **kwargs: None, "log-event": self.log_worker_event, "worker-status-change": self.handle_worker_status_change, + "request-refresh-who-has": self.handle_request_refresh_who_has, } client_handlers = { @@ -4766,6 +4767,21 @@ def handle_worker_status_change( else: self.running.discard(ws) + async def handle_request_refresh_who_has( + self, keys: Iterable[str], worker: str, stimulus_id: str + ) -> None: + """Asynchronous request (through bulk comms) from a Worker to refresh the + who_has for some keys. Not to be confused with scheduler.who_has, which is a + synchronous RPC request from a Client. + """ + self.stream_comms[worker].send( + { + "op": "refresh-who-has", + "who_has": self.get_who_has(keys), + "stimulus_id": stimulus_id, + }, + ) + async def handle_worker(self, comm=None, worker=None, stimulus_id=None): """ Listen to responses from a single worker @@ -6214,13 +6230,13 @@ def get_processing(self, workers=None): w: [ts.key for ts in ws.processing] for w, ws in self.workers.items() } - def get_who_has(self, keys=None): + def get_who_has(self, keys: Iterable[str] | None = None) -> dict[str, list[str]]: if keys is not None: return { - k: [ws.address for ws in self.tasks[k].who_has] - if k in self.tasks + key: [ws.address for ws in self.tasks[key].who_has] + if key in self.tasks else [] - for k in keys + for key in keys } else: return { diff --git a/distributed/worker.py b/distributed/worker.py index 1accfa5f2dd..d8eb0645d75 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -115,6 +115,7 @@ Execute, ExecuteFailureEvent, ExecuteSuccessEvent, + FindMissingEvent, GatherDep, GatherDepDoneEvent, Instructions, @@ -123,7 +124,9 @@ MissingDataMsg, Recs, RecsInstrs, + RefreshWhoHasEvent, ReleaseWorkerDataMsg, + RequestRefreshWhoHasMsg, RescheduleEvent, RescheduleMsg, SendMessageToScheduler, @@ -812,6 +815,7 @@ def __init__( "compute-task": self.handle_compute_task, "free-keys": self.handle_free_keys, "remove-replicas": self.handle_remove_replicas, + "refresh-who-has": self.handle_refresh_who_has, "steal-request": self.handle_steal_request, "worker-status-change": self.handle_worker_status_change, } @@ -840,9 +844,7 @@ def __init__( ) self.periodic_callbacks["keep-alive"] = pc - # FIXME annotations: https://github.com/tornadoweb/tornado/issues/3117 - pc = PeriodicCallback(self.find_missing, 1000) # type: ignore - self._find_missing_running = False + pc = PeriodicCallback(self.find_missing, 1000) self.periodic_callbacks["find-missing"] = pc self._address = contact_address @@ -1836,6 +1838,13 @@ def handle_remove_replicas(self, keys: list[str], stimulus_id: str) -> str: return "OK" + def handle_refresh_who_has( + self, who_has: dict[str, list[str]], stimulus_id: str + ) -> None: + self.handle_stimulus( + RefreshWhoHasEvent(who_has=who_has, stimulus_id=stimulus_id) + ) + async def set_resources(self, **resources) -> None: for r, quantity in resources.items(): if r in self.total_resources: @@ -2854,7 +2863,8 @@ def transitions(self, recommendations: Recs, *, stimulus_id: str) -> None: @log_errors def handle_stimulus(self, stim: StateMachineEvent) -> None: - self.stimulus_log.append(stim.to_loggable(handled=time())) + if not isinstance(stim, FindMissingEvent): + self.stimulus_log.append(stim.to_loggable(handled=time())) recs, instructions = self.handle_event(stim) self.transitions(recs, stimulus_id=stim.stimulus_id) self._handle_instructions(instructions) @@ -3377,22 +3387,19 @@ def done_event(): ) ) recommendations[ts] = "fetch" - del data, response - self.transitions(recommendations, stimulus_id=stimulus_id) - self._handle_instructions(instructions) if refresh_who_has: # All workers that hold known replicas of our tasks are busy. # Try querying the scheduler for unknown ones. - who_has = await retry_operation( - self.scheduler.who_has, keys=refresh_who_has - ) - refresh_stimulus_id = f"gather-dep-busy-{time()}" - recommendations, instructions = self._update_who_has( - who_has, stimulus_id=refresh_stimulus_id + instructions.append( + RequestRefreshWhoHasMsg( + keys=list(refresh_who_has), + stimulus_id=f"gather-dep-busy-{time()}", + ) ) - self.transitions(recommendations, stimulus_id=refresh_stimulus_id) - self._handle_instructions(instructions) + + self.transitions(recommendations, stimulus_id=stimulus_id) + self._handle_instructions(instructions) @log_errors def _readd_busy_worker(self, worker: str) -> None: @@ -3402,36 +3409,13 @@ def _readd_busy_worker(self, worker: str) -> None: ) @log_errors - async def find_missing(self) -> None: - if self._find_missing_running or not self._missing_dep_flight: - return - try: - self._find_missing_running = True - if self.validate: - for ts in self._missing_dep_flight: - assert not ts.who_has + def find_missing(self) -> None: + self.handle_stimulus(FindMissingEvent(stimulus_id=f"find-missing-{time()}")) - stimulus_id = f"find-missing-{time()}" - who_has = await retry_operation( - self.scheduler.who_has, - keys=[ts.key for ts in self._missing_dep_flight], - ) - recommendations, instructions = self._update_who_has( - who_has, stimulus_id=stimulus_id - ) - for ts in self._missing_dep_flight: - if ts.who_has: - assert ts not in recommendations - recommendations[ts] = "fetch" - self.transitions(recommendations, stimulus_id=stimulus_id) - self._handle_instructions(instructions) - - finally: - self._find_missing_running = False - # This is quite arbitrary but the heartbeat has scaling implemented - self.periodic_callbacks[ - "find-missing" - ].callback_time = self.periodic_callbacks["heartbeat"].callback_time + # This is quite arbitrary but the heartbeat has scaling implemented + self.periodic_callbacks["find-missing"].callback_time = self.periodic_callbacks[ + "heartbeat" + ].callback_time def _update_who_has( self, who_has: Mapping[str, Collection[str]], *, stimulus_id: str @@ -3509,12 +3493,15 @@ def max2(workers: set[str]) -> list[str]: ts.who_has = workers # currently fetching -> can no longer be fetched -> transition to missing + # currently missing -> opportunity to be fetched -> transition to fetch # any other state -> eventually, possibly, the task may transition to fetch # or missing, at which point the relevant transitions will test who_has that # we just updated. e.g. see the various transitions to fetch, which # instead recommend transitioning to missing if who_has is empty. if not workers and ts.state == "fetch": recs[ts] = "missing" + elif workers and ts.state == "missing": + recs[ts] = "fetch" return recs, instructions @@ -4023,6 +4010,25 @@ def _(self, ev: RescheduleEvent) -> RecsInstrs: assert ts, self.story(ev.key) return {ts: "rescheduled"}, [] + @handle_event.register + def _(self, ev: FindMissingEvent) -> RecsInstrs: + if not self._missing_dep_flight: + return {}, [] + + if self.validate: + for ts in self._missing_dep_flight: + assert not ts.who_has + + smsg = RequestRefreshWhoHasMsg( + keys=[ts.key for ts in self._missing_dep_flight], + stimulus_id=ev.stimulus_id, + ) + return {}, [smsg] + + @handle_event.register + def _(self, ev: RefreshWhoHasEvent) -> RecsInstrs: + return self._update_who_has(ev.who_has, stimulus_id=ev.stimulus_id) + def _prepare_args_for_execution( self, ts: TaskState, args: tuple, kwargs: dict[str, Any] ) -> tuple[tuple, dict[str, Any]]: diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 060028aa24b..2f311cdb482 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -380,6 +380,26 @@ class AddKeysMsg(SendMessageToScheduler): keys: list[str] +@dataclass +class RequestRefreshWhoHasMsg(SendMessageToScheduler): + """Worker -> Scheduler asynchronous request for updated who_has information. + Not to be confused with the scheduler.who_has synchronous RPC call, which is used + by the Client. + + See also + -------- + RefreshWhoHasEvent + distributed.scheduler.Scheduler.request_refresh_who_has + distributed.client.Client.who_has + distributed.scheduler.Scheduler.get_who_has + """ + + op = "request-refresh-who-has" + + __slots__ = ("keys",) + keys: list[str] + + @dataclass class StateMachineEvent: __slots__ = ("stimulus_id", "handled") @@ -533,6 +553,25 @@ class RescheduleEvent(StateMachineEvent): key: str +@dataclass +class FindMissingEvent(StateMachineEvent): + __slots__ = () + + +@dataclass +class RefreshWhoHasEvent(StateMachineEvent): + """Scheduler -> Worker message containing updated who_has information. + + See also + -------- + RequestRefreshWhoHasMsg + """ + + __slots__ = ("who_has",) + # {key: [worker address, ...]} + who_has: dict[str, list[str]] + + if TYPE_CHECKING: # TODO remove quotes (requires Python >=3.9) # TODO get out of TYPE_CHECKING (requires Python >=3.10)