diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 12add27f2d..baa7957334 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -399,10 +399,15 @@ class WorkerState: .. attribute:: processing: {TaskState: cost} A dictionary of tasks that have been submitted to this worker. - Each task state is asssociated with the expected cost in seconds + Each task state is associated with the expected cost in seconds of running that task, summing both the task's expected computation time and the expected communication time of its result. + If a task is already executing on the worker and the execution time is + twice the learned average TaskGroup duration, this will be set to twice + the current executing time. If the task is unknown, the default task + duration is used instead of the TaskGroup average. + Multiple tasks may be submitted to a worker in advance and the worker will run them eventually, depending on its execution resources (but see :doc:`work-stealing`). 
@@ -900,6 +905,18 @@ def name(self) -> str: def all_durations(self) -> "defaultdict[str, float]": return self._all_durations + @ccall + @exceptval(check=False) + def add_duration(self, action: str, start: double, stop: double): + duration = stop - start + self._all_durations[action] += duration + if action == "compute": + old = self._duration_average + if old < 0: + self._duration_average = duration + else: + self._duration_average = 0.5 * duration + 0.5 * old + @property def duration_average(self) -> double: return self._duration_average @@ -1062,6 +1079,18 @@ def nbytes_total(self): def duration(self) -> double: return self._duration + @ccall + @exceptval(check=False) + def add_duration(self, action: str, start: double, stop: double): + duration = stop - start + self._all_durations[action] += duration + if action == "compute": + if self._stop < stop: + self._stop = stop + self._start = self._start or start + self._duration += duration + self._prefix.add_duration(action, start, stop) + @property def types(self) -> set: return self._types @@ -2568,6 +2597,8 @@ def set_duration_estimate(self, ts: TaskState, ws: WorkerState) -> double: If a task takes longer than twice the current average duration we estimate the task duration to be 2x current-runtime, otherwise we set it to be the average duration. 
+ + See also ``_remove_from_processing`` """ exec_time: double = ws._executing.get(ts, 0) duration: double = self.get_task_duration(ts) @@ -2577,7 +2608,11 @@ def set_duration_estimate(self, ts: TaskState, ws: WorkerState) -> double: else: comm: double = self.get_comm_cost(ts, ws) total_duration = duration + comm + old = ws._processing.get(ts, 0) ws._processing[ts] = total_duration + self._total_occupancy += total_duration - old + ws._occupancy += total_duration - old + return total_duration def transition_waiting_processing(self, key): @@ -2602,10 +2637,8 @@ def transition_waiting_processing(self, key): return recommendations, client_msgs, worker_msgs worker = ws._address - duration_estimate = self.set_duration_estimate(ts, ws) + self.set_duration_estimate(ts, ws) ts._processing_on = ws - ws._occupancy += duration_estimate - self._total_occupancy += duration_estimate ts.state = "processing" self.consume_resources(ts, ws) self.check_idle_saturated(ws) @@ -2684,7 +2717,6 @@ def transition_processing_memory( worker_msgs: dict = {} try: ts: TaskState = self._tasks[key] - tg: TaskGroup = ts._group assert worker assert isinstance(worker, str) @@ -2719,57 +2751,26 @@ def transition_processing_memory( } ] - has_compute_startstop: bool = False - compute_start: double - compute_stop: double + ############################# + # Update Timing Information # + ############################# if startstops: startstop: dict for startstop in startstops: - stop = startstop["stop"] - start = startstop["start"] - action = startstop["action"] - if not has_compute_startstop and action == "compute": - compute_start = start - compute_stop = stop - has_compute_startstop = True - - # record timings of all actions -- a cheaper way of - # getting timing info compared with get_task_stream() - ts._prefix._all_durations[action] += stop - start - tg._all_durations[action] += stop - start + ts._group.add_duration( + stop=startstop["stop"], + start=startstop["start"], + action=startstop["action"], + 
) - ############################# - # Update Timing Information # - ############################# - if has_compute_startstop and ws._processing.get(ts, True): - # Update average task duration for worker - old_duration: double = ts._prefix._duration_average - new_duration: double = compute_stop - compute_start - avg_duration: double - if old_duration < 0: - avg_duration = new_duration - else: - avg_duration = 0.5 * old_duration + 0.5 * new_duration - - ts._prefix._duration_average = avg_duration - tg._duration += new_duration - tg._start = tg._start or compute_start - if tg._stop < compute_stop: - tg._stop = compute_stop - - s: set = self._unknown_durations.pop(ts._prefix._name, None) - tts: TaskState - if s: - for tts in s: - if tts._processing_on is not None: - wws = tts._processing_on - comm: double = self.get_comm_cost(tts, wws) - old: double = wws._processing[tts] - new: double = avg_duration + comm - diff: double = new - old - wws._processing[tts] = new - wws._occupancy += diff - self._total_occupancy += diff + s: set = self._unknown_durations.pop(ts._prefix._name, set()) + tts: TaskState + for tts in s: + if tts._processing_on: + self.set_duration_estimate(tts, tts._processing_on) + steal = self.extensions.get("stealing") + if steal: + steal.put_key_in_stealable(tts) ############################ # Update State Information # @@ -3314,10 +3315,14 @@ def get_comm_cost(self, ts: TaskState, ws: WorkerState) -> double: return nbytes / self._bandwidth @ccall - def get_task_duration(self, ts: TaskState, default: double = -1) -> double: - """ - Get the estimated computation cost of the given task - (not including any communication cost). + def get_task_duration(self, ts: TaskState) -> double: + """Get the estimated computation cost of the given task (not including + any communication cost). + + If no data has been observed, values of + `distributed.scheduler.default-task-durations` are used. 
If none is set + for this task, `distributed.scheduler.unknown-task-duration` is used + instead. """ duration: double = ts._prefix._duration_average if duration >= 0: @@ -3327,7 +3332,7 @@ def get_task_duration(self, ts: TaskState, default: double = -1) -> double: if s is None: self._unknown_durations[ts._prefix._name] = s = set() s.add(ts) - return default if default >= 0 else self.UNKNOWN_TASK_DURATION + return self.UNKNOWN_TASK_DURATION @ccall @exceptval(check=False) @@ -7586,7 +7591,6 @@ def reevaluate_occupancy(self, worker_index: Py_ssize_t = 0): try: if self.status == Status.closed: return - last = time() next_time = timedelta(seconds=0.1) @@ -7713,6 +7717,8 @@ def _remove_from_processing( ) -> str: # -> str | None """ Remove *ts* from the set of processing tasks. + + See also ``Scheduler.set_duration_estimate`` """ ws: WorkerState = ts._processing_on ts._processing_on = None # type: ignore @@ -7870,6 +7876,7 @@ def _task_to_msg(state: SchedulerState, ts: TaskState, duration: double = -1) -> ws: WorkerState dts: TaskState + # FIXME: The duration attribute is not used on worker. 
We could save ourselves the time to compute and submit this if duration < 0: duration = state.get_task_duration(ts) @@ -7942,27 +7949,19 @@ def _task_to_client_msgs(state: SchedulerState, ts: TaskState) -> dict: @exceptval(check=False) def _reevaluate_occupancy_worker(state: SchedulerState, ws: WorkerState): """See reevaluate_occupancy""" - old: double = ws._occupancy - new: double = 0 - diff: double ts: TaskState - est: double + old = ws._occupancy for ts in ws._processing: - est = state.set_duration_estimate(ts, ws) - new += est + state.set_duration_estimate(ts, ws) - ws._occupancy = new - diff = new - old - state._total_occupancy += diff state.check_idle_saturated(ws) - - # significant increase in duration - if new > old * 1.3: - steal = state._extensions.get("stealing") - if steal is not None: - for ts in ws._processing: - steal.remove_key_from_stealable(ts) - steal.put_key_in_stealable(ts) + steal = state.extensions.get("stealing") + if not steal: + return + if ws._occupancy > old * 1.3 or old > ws._occupancy * 1.3: + for ts in ws._processing: + steal.remove_key_from_stealable(ts) + steal.put_key_in_stealable(ts) @cfunc diff --git a/distributed/stealing.py b/distributed/stealing.py index 0297691f02..bacd62715e 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -114,13 +114,13 @@ def steal_time_ratio(self, ts): For example a result of zero implies a task without dependencies. 
level: The location within a stealable list to place this value """ - if not ts.dependencies: # no dependencies fast path - return 0, 0 - split = ts.prefix.name - if split in fast_tasks: + if split in fast_tasks or split in self.scheduler.unknown_durations: return None, None + if not ts.dependencies: # no dependencies fast path + return 0, 0 + ws = ts.processing_on compute_time = ws.processing[ts] if compute_time < 0.005: # 5ms, just give up diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 3517805451..70f51b7e75 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -110,7 +110,6 @@ async def test_worksteal_many_thieves(c, s, *workers): assert sum(map(len, s.has_what.values())) < 150 -@pytest.mark.flaky(reruns=10, reruns_delay=5, reason="GH#3574") @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) async def test_dont_steal_unknown_functions(c, s, a, b): futures = c.map(inc, range(100), workers=a.address, allow_other_workers=True) @@ -118,12 +117,17 @@ async def test_dont_steal_unknown_functions(c, s, a, b): assert len(a.data) >= 95, [len(a.data), len(b.data)] -@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) +@gen_cluster( + client=True, + nthreads=[("127.0.0.1", 1)] * 2, + config={"distributed.scheduler.work-stealing-interval": "10ms"}, +) async def test_eventually_steal_unknown_functions(c, s, a, b): futures = c.map( slowinc, range(10), delay=0.1, workers=a.address, allow_other_workers=True ) await wait(futures) + assert not s.unknown_durations assert len(a.data) >= 3, [len(a.data), len(b.data)] assert len(b.data) >= 3, [len(a.data), len(b.data)] @@ -597,11 +601,15 @@ async def test(*args, **kwargs): @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2, Worker=Nanny, timeout=60) async def test_restart(c, s, a, b): futures = c.map( - slowinc, range(100), delay=0.1, workers=a.address, allow_other_workers=True + slowinc, range(100), delay=0.01, workers=a.address, 
allow_other_workers=True ) while not s.processing[b.worker_address]: await asyncio.sleep(0.01) + # Unknown tasks are never stolen therefore wait for a measurement + while not any(s.tasks[f.key].state == "memory" for f in futures): + await asyncio.sleep(0.01) + steal = s.extensions["stealing"] assert any(st for st in steal.stealable_all) assert any(x for L in steal.stealable.values() for x in L) @@ -821,9 +829,13 @@ async def test_balance_with_longer_task(c, s, a, b): slowinc, 1, delay=5, workers=[a.address], priority=1 ) # a surprisingly long task z = c.submit( - inc, x, workers=[a.address], allow_other_workers=True, priority=0 + slowadd, x, 1, workers=[a.address], allow_other_workers=True, priority=0 ) # a task after y, suggesting a, but open to b + + # Allow task to be learned, otherwise it will not be stolen + _ = c.submit(slowadd, x, 2, workers=[b.address]) await z + assert not y.done() assert z.key in b.data