diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index 6663ebacd9..d89483472d 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -2162,8 +2162,8 @@ def __init__(self, scheduler, **kwargs): node_colors = factor_cmap( "state", - factors=["waiting", "processing", "memory", "released", "erred"], - palette=["gray", "green", "red", "blue", "black"], + factors=["waiting", "queued", "processing", "memory", "released", "erred"], + palette=["gray", "yellow", "green", "red", "blue", "black"], ) self.root = figure(title="Task Graph", **kwargs) @@ -3051,7 +3051,7 @@ def __init__(self, scheduler, **kwargs): self.scheduler = scheduler data = progress_quads( - dict(all={}, memory={}, erred={}, released={}, processing={}) + dict(all={}, memory={}, erred={}, released={}, processing={}, queued={}) ) self.source = ColumnDataSource(data=data) @@ -3123,6 +3123,18 @@ def __init__(self, scheduler, **kwargs): fill_alpha=0.35, line_alpha=0, ) + self.root.quad( + source=self.source, + top="top", + bottom="bottom", + left="processing-loc", + right="queued-loc", + fill_color="gray", + hatch_pattern="/", + hatch_color="white", + fill_alpha=0.35, + line_alpha=0, + ) self.root.text( source=self.source, text="show-name", @@ -3158,6 +3170,14 @@ def __init__(self, scheduler, **kwargs): All:  @all +
+ Queued:  + @queued +
+
+ Processing:  + @processing +
Memory:  @memory @@ -3166,10 +3186,6 @@ def __init__(self, scheduler, **kwargs): Erred:  @erred
-
- Ready:  - @processing -
""", ) self.root.add_tools(hover) @@ -3183,6 +3199,7 @@ def update(self): "released": {}, "processing": {}, "waiting": {}, + "queued": {}, } for tp in self.scheduler.task_prefixes.values(): @@ -3193,6 +3210,7 @@ def update(self): state["released"][tp.name] = active_states["released"] state["processing"][tp.name] = active_states["processing"] state["waiting"][tp.name] = active_states["waiting"] + state["queued"][tp.name] = active_states["queued"] state["all"] = {k: sum(v[k] for v in state.values()) for k in state["memory"]} @@ -3205,7 +3223,7 @@ def update(self): totals = { k: sum(state[k].values()) - for k in ["all", "memory", "erred", "released", "waiting"] + for k in ["all", "memory", "erred", "released", "waiting", "queued"] } totals["processing"] = totals["all"] - sum( v for k, v in totals.items() if k != "all" @@ -3213,8 +3231,10 @@ def update(self): self.root.title.text = ( "Progress -- total: %(all)s, " - "in-memory: %(memory)s, processing: %(processing)s, " "waiting: %(waiting)s, " + "queued: %(queued)s, " + "processing: %(processing)s, " + "in-memory: %(memory)s, " "erred: %(erred)s" % totals ) diff --git a/distributed/diagnostics/progress_stream.py b/distributed/diagnostics/progress_stream.py index 03aedde8f6..764aef790e 100644 --- a/distributed/diagnostics/progress_stream.py +++ b/distributed/diagnostics/progress_stream.py @@ -17,7 +17,7 @@ def counts(scheduler, allprogress): {"all": valmap(len, allprogress.all), "nbytes": allprogress.nbytes}, { state: valmap(len, allprogress.state[state]) - for state in ["memory", "erred", "released", "processing"] + for state in ["memory", "erred", "released", "processing", "queued"] }, ) @@ -66,23 +66,29 @@ def progress_quads(msg, nrows=8, ncols=3): ... 'memory': {'inc': 2, 'dec': 0, 'add': 1}, ... 'erred': {'inc': 0, 'dec': 1, 'add': 0}, ... 'released': {'inc': 1, 'dec': 0, 'add': 1}, - ... 'processing': {'inc': 1, 'dec': 0, 'add': 2}} + ... 'processing': {'inc': 1, 'dec': 0, 'add': 2}, + ... 'queued': {'inc': 1, 'dec': 0, 'add': 2}} >>> progress_quads(msg, nrows=2) # doctest: +SKIP - {'name': ['inc', 'add', 'dec'], - 'left': [0, 0, 1], - 'right': [0.9, 0.9, 1.9], - 'top': [0, -1, 0], - 'bottom': [-.8, -1.8, -.8], - 'released': [1, 1, 0], - 'memory': [2, 1, 0], - 'erred': [0, 0, 1], - 'processing': [1, 0, 2], - 'done': ['3 / 5', '2 / 4', '1 / 1'], - 'released-loc': [.2/.9, .25 / 0.9, 1], - 'memory-loc': [3 / 5 / .9, .5 / 0.9, 1], - 'erred-loc': [3 / 5 / .9, .5 / 0.9, 1.9], - 'processing-loc': [4 / 5, 1 / 1, 1]}} + {'all': [5, 4, 1], + 'memory': [2, 1, 0], + 'erred': [0, 0, 1], + 'released': [1, 1, 0], + 'processing': [1, 2, 0], + 'queued': [1, 2, 0], + 'name': ['inc', 'add', 'dec'], + 'show-name': ['inc', 'add', 'dec'], + 'left': [0, 0, 1], + 'right': [0.9, 0.9, 1.9], + 'top': [0, -1, 0], + 'bottom': [-0.8, -1.8, -0.8], + 'color': ['#45BF6F', '#2E6C8E', '#440154'], + 'released-loc': [0.18, 0.225, 1.0], + 'memory-loc': [0.54, 0.45, 1.0], + 'erred-loc': [0.54, 0.45, 1.9], + 'processing-loc': [0.72, 0.9, 1.9], + 'queued-loc': [0.9, 1.35, 1.9], + 'done': ['3 / 5', '2 / 4', '1 / 1']} """ width = 0.9 names = sorted(msg["all"], key=msg["all"].get, reverse=True) @@ -102,19 +108,28 @@ def progress_quads(msg, nrows=8, ncols=3): d["memory-loc"] = [] d["erred-loc"] = [] d["processing-loc"] = [] + d["queued-loc"] = [] d["done"] = [] - for r, m, e, p, a, l in zip( - d["released"], d["memory"], d["erred"], d["processing"], d["all"], d["left"] + for r, m, e, p, q, a, l in zip( + d["released"], + d["memory"], + d["erred"], + d["processing"], + d["queued"], + d["all"], + d["left"], ): rl = width * r / a + l ml = width * (r + m) / a + l el = width * (r + m + e) / a + l pl = width * (p + r + m + e) / a + l + ql = width * (p + r + m + e + q) / a + l done = "%d / %d" % (r + m + e, a) d["released-loc"].append(rl) d["memory-loc"].append(ml) d["erred-loc"].append(el) d["processing-loc"].append(pl) + d["queued-loc"].append(ql) d["done"].append(done) return d diff --git a/distributed/diagnostics/tests/test_progress_stream.py b/distributed/diagnostics/tests/test_progress_stream.py index 73a5be81ab..49c93b212e 100644 --- a/distributed/diagnostics/tests/test_progress_stream.py +++ b/distributed/diagnostics/tests/test_progress_stream.py @@ -18,6 +18,7 @@ def test_progress_quads(): "erred": {"inc": 0, "dec": 1, "add": 0}, "released": {"inc": 1, "dec": 0, "add": 1}, "processing": {"inc": 1, "dec": 0, "add": 2}, + "queued": {"inc": 1, "dec": 0, "add": 2}, } d = progress_quads(msg, nrows=2) @@ -35,11 +36,13 @@ def test_progress_quads(): "memory": [2, 1, 0], "erred": [0, 0, 1], "processing": [1, 2, 0], + "queued": [1, 2, 0], "done": ["3 / 5", "2 / 4", "1 / 1"], "released-loc": [0.9 * 1 / 5, 0.25 * 0.9, 1.0], "memory-loc": [0.9 * 3 / 5, 0.5 * 0.9, 1.0], "erred-loc": [0.9 * 3 / 5, 0.5 * 0.9, 1.9], "processing-loc": [0.9 * 4 / 5, 1 * 0.9, 1 * 0.9 + 1], + "queued-loc": [1 * 0.9, 1.5 * 0.9, 1 * 0.9 + 1], } assert d == expected @@ -52,6 +55,7 @@ def test_progress_quads_too_many(): "erred": {k: 0 for k in keys}, "released": {k: 0 for k in keys}, "processing": {k: 0 for k in keys}, + "queued": {k: 0 for k in keys}, } d = progress_quads(msg, nrows=6, ncols=3) @@ -78,6 +82,7 @@ async def test_progress_stream(c, s, a, b): "memory": {"div": 9, "inc": 1}, "released": {"inc": 4}, "processing": {}, + "queued": {}, } assert set(nbytes) == set(msg["all"]) assert all(v > 0 for v in nbytes.values()) @@ -95,6 +100,7 @@ def test_progress_quads_many_functions(): "erred": {fn: 0 for fn in funcnames}, "released": {fn: 0 for fn in funcnames}, "processing": {fn: 0 for fn in funcnames}, + "queued": {fn: 0 for fn in funcnames}, } d = progress_quads(msg, nrows=2) diff --git a/distributed/distributed-schema.yaml b/distributed/distributed-schema.yaml index 7b5a7ed53e..a110aaa807 100644 --- a/distributed/distributed-schema.yaml +++ b/distributed/distributed-schema.yaml @@ -117,6 +117,28 @@ properties: description: | How frequently to balance worker loads + worker-saturation: + type: number + exclusiveMinimum: 0 + description: | + Controls how many root tasks are sent to workers (like a `readahead`). + + Up to worker-saturation * nthreads root tasks are sent to a + worker at a time. If `.inf`, all runnable tasks are immediately sent to workers. + + Allowing oversaturation (> 1.0) means a worker may start running a new root task as + soon as it completes the previous, even if there is a higher-priority downstream task + to run. This reduces worker idleness, by letting workers do something while waiting for + further instructions from the scheduler, even if it's not the most efficient + thing to do. + + This generally comes at the expense of increased memory usage. It leads to "wider" + (more breadth-first) execution of the graph. + + Compute-bound workloads may benefit from oversaturation. Memory-bound workloads should + generally leave `worker-saturation` at 1.0, though 1.25-1.5 could slightly improve + performance if ample memory is available. + worker-ttl: type: - string diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index 400058148f..ca8292146f 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -22,6 +22,7 @@ distributed: events-log-length: 100000 work-stealing: True # workers should steal tasks from each other work-stealing-interval: 100ms # Callback time for work stealing + worker-saturation: .inf # Send this fraction of nthreads root tasks to workers worker-ttl: "5 minutes" # like '60s'. Time to live for workers. They must heartbeat faster than this pickle: True # Is the scheduler allowed to deserialize arbitrary bytestrings preload: [] # Run custom modules with Scheduler diff --git a/distributed/http/scheduler/tests/test_scheduler_http.py b/distributed/http/scheduler/tests/test_scheduler_http.py index c2d75d8206..158384b758 100644 --- a/distributed/http/scheduler/tests/test_scheduler_http.py +++ b/distributed/http/scheduler/tests/test_scheduler_http.py @@ -137,7 +137,15 @@ async def fetch_metrics(): ] return active_metrics, forgotten_tasks - expected = {"memory", "released", "processing", "waiting", "no-worker", "erred"} + expected = { + "memory", + "released", + "queued", + "processing", + "waiting", + "no-worker", + "erred", + } # Ensure that we get full zero metrics for all states even though the # scheduler did nothing, yet diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 10ca56e446..8ffb215025 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -58,6 +58,7 @@ from distributed._stories import scheduler_story from distributed.active_memory_manager import ActiveMemoryManagerExtension, RetireWorker from distributed.batched import BatchedSend +from distributed.collections import HeapSet from distributed.comm import ( Comm, CommClosedError, @@ -116,6 +117,7 @@ "released", "waiting", "no-worker", + "queued", "processing", "memory", "erred", @@ -133,6 +135,7 @@ "released", "waiting", "no-worker", + "queued", "processing", "memory", "erred", @@ -844,14 +847,6 @@ class TaskGroup: #: The result types of this TaskGroup types: set[str] - #: The worker most recently assigned a task from this group, or None when the group - #: is not identified to be root-like by `SchedulerState.decide_worker`. - last_worker: WorkerState | None - - #: If `last_worker` is not None, the number of times that worker should be assigned - #: subsequent tasks until a new worker is chosen. - last_worker_tasks_left: int - prefix: TaskPrefix | None start: float stop: float @@ -870,8 +865,6 @@ def __init__(self, name: str): self.start = 0.0 self.stop = 0.0 self.all_durations = defaultdict(float) - self.last_worker = None - self.last_worker_tasks_left = 0 def add_duration(self, action: str, start: float, stop: float) -> None: duration = stop - start @@ -1281,13 +1274,15 @@ class SchedulerState: Tasks currently known to the scheduler * **unrunnable:** ``{TaskState}`` Tasks in the "no-worker" state + * **queued:** ``HeapSet[TaskState]`` + Tasks in the "queued" state, ordered by priority * **workers:** ``{worker key: WorkerState}`` Workers currently connected to the scheduler * **idle:** ``{WorkerState}``: - Set of workers that are not fully utilized + Set of workers that are currently in running state and not fully utilized * **saturated:** ``{WorkerState}``: - Set of workers that are not over-utilized + Set of workers that are fully utilized. May include non-running workers. * **running:** ``{WorkerState}``: Set of workers that are currently in running state @@ -1307,7 +1302,10 @@ class SchedulerState: "extensions", "host_info", "idle", + "last_root_worker", + "last_root_worker_tasks_left", "n_tasks", + "queued", "resources", "saturated", "running", @@ -1332,6 +1330,7 @@ class SchedulerState: "MEMORY_REBALANCE_SENDER_MIN", "MEMORY_REBALANCE_RECIPIENT_MAX", "MEMORY_REBALANCE_HALF_GAP", + "WORKER_SATURATION", } def __init__( @@ -1343,6 +1342,7 @@ def __init__( resources: dict, tasks: dict, unrunnable: set, + queued: HeapSet[TaskState], validate: bool, plugins: Iterable[SchedulerPlugin] = (), transition_counter_max: int | Literal[False] = False, @@ -1375,10 +1375,13 @@ def __init__( self.total_nthreads = 0 self.total_occupancy = 0.0 self.unknown_durations: dict[str, set[TaskState]] = {} + self.last_root_worker: WorkerState | None = None + self.last_root_worker_tasks_left: int = 0 + self.queued = queued self.unrunnable = unrunnable self.validate = validate self.workers = workers - self.running = { + self.running: set[WorkerState] = { ws for ws in self.workers.values() if ws.status == Status.running } self.plugins = {} if not plugins else {_get_plugin_name(p): p for p in plugins} @@ -1403,6 +1406,9 @@ def __init__( dask.config.get("distributed.worker.memory.rebalance.sender-recipient-gap") / 2.0 ) + self.WORKER_SATURATION = dask.config.get( + "distributed.scheduler.worker-saturation" + ) self.transition_counter = 0 self._idle_transition_counter = 0 self.transition_counter_max = transition_counter_max @@ -1418,6 +1424,7 @@ def __pdict__(self): "resources": self.resources, "saturated": self.saturated, "unrunnable": self.unrunnable, + "queued": self.queued, "n_tasks": self.n_tasks, "unknown_durations": self.unknown_durations, "validate": self.validate, @@ -1576,11 +1583,11 @@ def _transition( if not stimulus_id: stimulus_id = STIMULUS_ID_UNSET - finish2 = ts._state + actual_finish = ts._state # FIXME downcast antipattern scheduler = cast(Scheduler, self) scheduler.transition_log.append( - (key, start, finish2, recommendations, stimulus_id, time()) + (key, start, actual_finish, recommendations, stimulus_id, time()) ) if self.validate: if stimulus_id == STIMULUS_ID_UNSET: @@ -1591,8 +1598,8 @@ def _transition( "Transitioned %r %s->%s (actual: %s). Consequence: %s", key, start, - finish2, - ts.state, + finish, + actual_finish, dict(recommendations), ) if self.plugins: @@ -1603,7 +1610,7 @@ def _transition( self.tasks[ts.key] = ts for plugin in list(self.plugins.values()): try: - plugin.transition(key, start, finish2, *args, **kwargs) + plugin.transition(key, start, actual_finish, *args, **kwargs) except Exception: logger.info("Plugin failed with exception", exc_info=True) if ts.state == "forgotten": @@ -1707,11 +1714,8 @@ def transition_released_waiting(self, key, stimulus_id): ts.waiters = {dts for dts in ts.dependents if dts.state == "waiting"} if not ts.waiting_on: - if self.workers: - recommendations[key] = "processing" - else: - self.unrunnable.add(ts) - ts.state = "no-worker" + # NOTE: waiting->processing will send tasks to queued or no-worker as necessary + recommendations[key] = "processing" return recommendations, client_msgs, worker_msgs except Exception as e: @@ -1722,43 +1726,21 @@ def transition_released_waiting(self, key, stimulus_id): pdb.set_trace() raise - def transition_no_worker_waiting(self, key, stimulus_id): + def transition_no_worker_processing(self, key, stimulus_id): try: ts: TaskState = self.tasks[key] - dts: TaskState - recommendations: dict = {} + recommendations: Recs = {} client_msgs: dict = {} worker_msgs: dict = {} if self.validate: + assert not ts.actor, f"Actors can't be in `no-worker`: {ts}" assert ts in self.unrunnable - assert not ts.waiting_on - assert not ts.who_has - assert not ts.processing_on - - self.unrunnable.remove(ts) - - if ts.has_lost_dependencies: - recommendations[key] = "forgotten" - return recommendations, client_msgs, worker_msgs - for dts in ts.dependencies: - dep = dts.key - if not dts.who_has: - ts.waiting_on.add(dts) - if dts.state == "released": - recommendations[dep] = "waiting" - else: - dts.waiters.add(ts) - - ts.state = "waiting" - - if not ts.waiting_on: - if self.workers: - recommendations[key] = "processing" - else: - self.unrunnable.add(ts) - ts.state = "no-worker" + if ws := self.decide_worker_non_rootish(ts): + self.unrunnable.discard(ts) + worker_msgs = _add_to_processing(self, ts, ws) + # If no worker, task just stays in `no-worker` return recommendations, client_msgs, worker_msgs except Exception as e: @@ -1811,70 +1793,153 @@ def transition_no_worker_memory( pdb.set_trace() raise - def decide_worker(self, ts: TaskState) -> WorkerState | None: - """ - Decide on a worker for task *ts*. Return a WorkerState. + def decide_worker_rootish_queuing_disabled( + self, ts: TaskState + ) -> WorkerState | None: + """Pick a worker for a runnable root-ish task, without queuing. - If it's a root or root-like task, we place it with its relatives to - reduce future data tansfer. + This attempts to schedule sibling tasks on the same worker, reducing future data + transfer. It does not consider the location of dependencies, since they'll end + up on every worker anyway. - If it has dependencies or restrictions, we use - `decide_worker_from_deps_and_restrictions`. + It assumes it's being called on a batch of tasks in priority order, and + maintains state in `SchedulerState.last_root_worker` and + `SchedulerState.last_root_worker_tasks_left` to achieve this. - Otherwise, we pick the least occupied worker, or pick from all workers - in a round-robin fashion. - """ - if not self.workers: - return None + This will send every runnable task to a worker, often causing root task + overproduction. - tg = ts.group - valid_workers = self.valid_workers(ts) + Returns + ------- + ws: WorkerState | None + The worker to assign the task to. If there are no workers in the cluster, + returns None, in which case the task should be transitioned to + ``no-worker``. + """ + if self.validate: + # See root-ish-ness note below in `decide_worker_rootish_queuing_enabled` + assert math.isinf(self.WORKER_SATURATION) - if ( - valid_workers is not None - and not valid_workers - and not ts.loose_restrictions - ): - self.unrunnable.add(ts) - ts.state = "no-worker" + pool = self.idle.values() if self.idle else self.running + if not pool: return None - # Group is larger than cluster with few dependencies? - # Minimize future data transfers. - if ( - valid_workers is None - and len(tg) > self.total_nthreads * 2 - and len(tg.dependencies) < 5 - and sum(map(len, tg.dependencies)) < 5 + lws = self.last_root_worker + if not ( + lws + and self.last_root_worker_tasks_left + and self.workers.get(lws.address) is lws ): - ws = tg.last_worker + # Last-used worker is full or unknown; pick a new worker for the next few tasks + ws = self.last_root_worker = min( + pool, key=lambda ws: len(ws.processing) / ws.nthreads + ) + # TODO better batching metric (`len(tg)` is not necessarily the total number of root tasks!) + self.last_root_worker_tasks_left = math.floor( + (len(ts.group) / self.total_nthreads) * ws.nthreads + ) + else: + ws = lws - if not (ws and tg.last_worker_tasks_left and ws.address in self.workers): - # Last-used worker is full or unknown; pick a new worker for the next few tasks - ws = min( - (self.idle or self.workers).values(), - key=partial(self.worker_objective, ts), - ) - assert ws - tg.last_worker_tasks_left = math.floor( - (len(tg) / self.total_nthreads) * ws.nthreads - ) + self.last_root_worker_tasks_left -= 1 + + if self.validate and ws is not None: + assert self.workers.get(ws.address) is ws + assert ws in self.running, (ws, self.running) + + return ws + + def decide_worker_rootish_queuing_enabled(self) -> WorkerState | None: + """Pick a worker for a runnable root-ish task, if not all are busy. + + Picks the least-busy worker out of the ``idle`` workers (idle workers have fewer + tasks running than threads, as set by ``distributed.scheduler.worker-saturation``). + It does not consider the location of dependencies, since they'll end up on every + worker anyway. + + If all workers are full, returns None, meaning the task should transition to + ``queued``. The scheduler will wait to send it to a worker until a thread opens + up. This ensures that downstream tasks always run before new root tasks are + started. + + This does not try to schedule sibling tasks on the same worker; in fact, it + usually does the opposite. Even though this increases subsequent data transfer, + it typically reduces overall memory use by eliminating root task overproduction. + + Returns + ------- + ws: WorkerState | None + The worker to assign the task to. If there are no idle workers, returns + None, in which case the task should be transitioned to ``queued``. + + """ + if self.validate: + # We don't `assert self.is_rootish(ts)` here, because that check is dependent on + # cluster size. It's possible a task looked root-ish when it was queued, but the + # cluster has since scaled up and it no longer does when coming out of the queue. + # If `is_rootish` changes to a static definition, then add that assertion here + # (and actually pass in the task). + assert not math.isinf(self.WORKER_SATURATION) + + if not self.idle: + # All workers busy? Task gets/stays queued. + return None - # Record `last_worker`, or clear it on the final task - tg.last_worker = ( - ws if tg.states["released"] + tg.states["waiting"] > 1 else None + # Just pick the least busy worker. + # NOTE: this will lead to worst-case scheduling with regards to co-assignment. + ws = min(self.idle.values(), key=lambda ws: len(ws.processing) / ws.nthreads) + if self.validate: + assert not _worker_full(ws, self.WORKER_SATURATION), ( + ws, + _task_slots_available(ws, self.WORKER_SATURATION), ) - tg.last_worker_tasks_left -= 1 - return ws + assert ws in self.running, (ws, self.running) + + if self.validate and ws is not None: + assert self.workers.get(ws.address) is ws + assert ws in self.running, (ws, self.running) + + return ws + + def decide_worker_non_rootish(self, ts: TaskState) -> WorkerState | None: + """Pick a worker for a runnable non-root task, considering dependencies and restrictions. + + Out of eligible workers holding dependencies of ``ts``, selects the worker + where, considering worker backlong and data-transfer costs, the task is + estimated to start running the soonest. + + Returns + ------- + ws: WorkerState | None + The worker to assign the task to. If no workers satisfy the restrictions of + ``ts`` or there are no running workers, returns None, in which case the task + should be transitioned to ``no-worker``. + """ + if not self.running: + return None + + valid_workers = self.valid_workers(ts) + if valid_workers is None and len(self.running) < len(self.workers): + if not self.running: + return None + + # If there were no restrictions, `valid_workers()` didn't subset by `running`. + valid_workers = self.running if ts.dependencies or valid_workers is not None: ws = decide_worker( ts, - self.workers.values(), + self.running, valid_workers, partial(self.worker_objective, ts), ) else: + # TODO if `is_rootish` would always return True for tasks without dependencies, + # we could remove all this logic. The rootish assignment logic would behave + # more or less the same as this, maybe without gauranteed round-robin though? + # This path is only reachable when `ts` doesn't have dependencies, but its + # group is also smaller than the cluster. + # Fastpath when there are no related tasks or restrictions worker_pool = self.idle or self.workers wp_vals = worker_pool.values() @@ -1898,46 +1963,36 @@ def decide_worker(self, ts: TaskState) -> WorkerState | None: ws = wp_vals[self.n_tasks % n_workers] if self.validate and ws is not None: - assert ws.address in self.workers + assert self.workers.get(ws.address) is ws + assert ws in self.running, (ws, self.running) return ws def transition_waiting_processing(self, key, stimulus_id): + """Possibly schedule a ready task. This is the primary dispatch for ready tasks. + + If there's no appropriate worker for the task (but the task is otherwise runnable), + it will be recommended to ``no-worker`` or ``queued``. + """ try: ts: TaskState = self.tasks[key] - dts: TaskState - recommendations: dict = {} - client_msgs: dict = {} - worker_msgs: dict = {} - if self.validate: - assert not ts.waiting_on - assert not ts.who_has - assert not ts.exception_blame - assert not ts.processing_on - assert not ts.has_lost_dependencies - assert ts not in self.unrunnable - assert all(dts.who_has for dts in ts.dependencies) - - ws = self.decide_worker(ts) - if ws is None: - return recommendations, client_msgs, worker_msgs - worker = ws.address - - self._set_duration_estimate(ts, ws) - ts.processing_on = ws - ts.state = "processing" - self.acquire_resources(ts, ws) - self.check_idle_saturated(ws) - self.n_tasks += 1 - if ts.actor: - ws.actors.add(ts) - - # logger.debug("Send job to worker: %s, %s", worker, key) - - worker_msgs[worker] = [_task_to_msg(self, ts)] + if self.is_rootish(ts): + # NOTE: having two root-ish methods is temporary. When the feature flag is removed, + # there should only be one, which combines co-assignment and queuing. + # Eventually, special-casing root tasks might be removed entirely, with better heuristics. + if math.isinf(self.WORKER_SATURATION): + if not (ws := self.decide_worker_rootish_queuing_disabled(ts)): + return {ts.key: "no-worker"}, {}, {} + else: + if not (ws := self.decide_worker_rootish_queuing_enabled()): + return {ts.key: "queued"}, {}, {} + else: + if not (ws := self.decide_worker_non_rootish(ts)): + return {ts.key: "no-worker"}, {}, {} - return recommendations, client_msgs, worker_msgs + worker_msgs = _add_to_processing(self, ts, ws) + return {}, {}, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -2070,7 +2125,10 @@ def transition_processing_memory( if nbytes is not None: ts.set_nbytes(nbytes) - _remove_from_processing(self, ts) + # NOTE: recommendations for queued tasks are added first, so they'll be popped last, + # allowing higher-priority downstream tasks to be transitioned first. + # FIXME: this would be incorrect if queued tasks are user-annotated as higher priority. + _exit_processing_common(self, ts, recommendations) _add_to_memory( self, ts, ws, recommendations, client_msgs, type=type, typename=typename @@ -2292,7 +2350,7 @@ def transition_waiting_released(self, key, stimulus_id): def transition_processing_released(self, key: str, stimulus_id: str): try: ts = self.tasks[key] - recommendations = {} + recommendations: Recs = {} worker_msgs = {} if self.validate: @@ -2301,7 +2359,7 @@ def transition_processing_released(self, key: str, stimulus_id: str): assert not ts.waiting_on assert ts.state == "processing" - ws = _remove_from_processing(self, ts) + ws = _exit_processing_common(self, ts, recommendations) if ws: worker_msgs[ws.address] = [ { @@ -2311,24 +2369,7 @@ def transition_processing_released(self, key: str, stimulus_id: str): } ] - ts.state = "released" - - if ts.has_lost_dependencies: - recommendations[key] = "forgotten" - elif ts.waiters or ts.who_wants: - recommendations[key] = "waiting" - - if recommendations.get(key) != "waiting": - for dts in ts.dependencies: - if dts.state != "released": - dts.waiters.discard(ts) - if not dts.waiters and not dts.who_wants: - recommendations[dts.key] = "released" - ts.waiters.clear() - - if self.validate: - assert not ts.processing_on - + _propagage_released(self, ts, recommendations) return recommendations, {}, worker_msgs except Exception as e: logger.exception(e) @@ -2395,7 +2436,7 @@ def transition_processing_erred( ws = ts.processing_on ws.actors.remove(ts) - _remove_from_processing(self, ts) + _exit_processing_common(self, ts, recommendations) ts.erred_on.add(worker) if exception is not None: @@ -2494,6 +2535,99 @@ def transition_no_worker_released(self, key, stimulus_id): pdb.set_trace() raise + def transition_waiting_queued(self, key, stimulus_id): + try: + ts: TaskState = self.tasks[key] + recommendations: Recs = {} + client_msgs: dict = {} + worker_msgs: dict = {} + + if self.validate: + assert not self.idle, (ts, self.idle) + _validate_ready(self, ts) + + ts.state = "queued" + self.queued.add(ts) + + return recommendations, client_msgs, worker_msgs + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb + + pdb.set_trace() + raise + + def transition_waiting_no_worker(self, key, stimulus_id): + try: + ts: TaskState = self.tasks[key] + recommendations: Recs = {} + client_msgs: dict = {} + worker_msgs: dict = {} + + if self.validate: + _validate_ready(self, ts) + + ts.state = "no-worker" + self.unrunnable.add(ts) + + return recommendations, client_msgs, worker_msgs + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb + + pdb.set_trace() + raise + + def transition_queued_released(self, key, stimulus_id): + try: + ts: TaskState = self.tasks[key] + recommendations: Recs = {} + client_msgs: dict = {} + worker_msgs: dict = {} + + if self.validate: + assert ts in self.queued + assert not ts.processing_on + + self.queued.remove(ts) + + _propagage_released(self, ts, recommendations) + return recommendations, client_msgs, worker_msgs + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb + + pdb.set_trace() + raise + + def transition_queued_processing(self, key, stimulus_id): + try: + ts: TaskState = self.tasks[key] + recommendations: Recs = {} + client_msgs: dict = {} + worker_msgs: dict = {} + + if self.validate: + assert not ts.actor, f"Actors can't be queued: {ts}" + assert ts in self.queued + + if ws := self.decide_worker_rootish_queuing_enabled(): + self.queued.discard(ts) + worker_msgs = _add_to_processing(self, ts, ws) + # If no worker, task just stays `queued` + + return recommendations, client_msgs, worker_msgs + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb + + pdb.set_trace() + raise + def _remove_key(self, key): ts: TaskState = self.tasks.pop(key) assert ts.state == "forgotten" @@ -2559,6 +2693,7 @@ def transition_released_forgotten(self, key, stimulus_id): assert ts.state in ("released", "erred") assert not ts.who_has assert not ts.processing_on + assert ts not in self.queued assert not ts.waiting_on, (ts, ts.waiting_on) if not ts.run_spec: # It's ok to forget a pure data task @@ -2601,12 +2736,16 @@ def transition_released_forgotten(self, key, stimulus_id): ("released", "waiting"): transition_released_waiting, ("waiting", "released"): transition_waiting_released, ("waiting", "processing"): transition_waiting_processing, + ("waiting", "no-worker"): transition_waiting_no_worker, + ("waiting", "queued"): transition_waiting_queued, ("waiting", "memory"): transition_waiting_memory, + ("queued", "released"): transition_queued_released, + ("queued", "processing"): transition_queued_processing, ("processing", "released"): transition_processing_released, ("processing", "memory"): transition_processing_memory, ("processing", "erred"): transition_processing_erred, ("no-worker", "released"): transition_no_worker_released, - ("no-worker", "waiting"): transition_no_worker_waiting, + ("no-worker", "processing"): transition_no_worker_processing, ("no-worker", "memory"): transition_no_worker_memory, ("released", "forgotten"): transition_released_forgotten, ("memory", "forgotten"): transition_memory_forgotten, @@ -2619,6 +2758,23 @@ def transition_released_forgotten(self, key, stimulus_id): # Assigning Tasks to Workers # ############################## + def is_rootish(self, ts: TaskState) -> bool: + """ + Whether ``ts`` is a root or root-like task. + + Root-ish tasks are part of a group that's much larger than the cluster, + and have few or no dependencies. + """ + if ts.resource_restrictions or ts.worker_restrictions or ts.host_restrictions: + return False + tg = ts.group + # TODO short-circuit to True if `not ts.dependencies`? + return ( + len(tg) > self.total_nthreads * 2 + and len(tg.dependencies) < 5 + and sum(map(len, tg.dependencies)) < 5 + ) + def _set_duration_estimate(self, ts: TaskState, ws: WorkerState) -> None: """Estimate task duration using worker state and task state. @@ -2659,6 +2815,15 @@ def check_idle_saturated(self, ws: WorkerState, occ: float = -1.0): all of their threads, and if the expected runtime of those tasks is large enough. + If ``distributed.scheduler.worker-saturation`` is not ``inf`` + (scheduler-side queuing is enabled), they are considered idle + if they have fewer tasks processing than the ``worker-saturation`` + threshold dictates. + + Otherwise, they are considered idle if they have fewer tasks processing + than threads, or if their tasks' total expected runtime is less than half + the expected runtime of the same number of average tasks. + This is useful for load balancing and adaptivity. """ if self.total_nthreads == 0 or ws.status == Status.closed: @@ -2672,8 +2837,13 @@ def check_idle_saturated(self, ws: WorkerState, occ: float = -1.0): idle = self.idle saturated = self.saturated - if p < nc or occ < nc * avg / 2: - idle[ws.address] = ws + if ( + (p < nc or occ < nc * avg / 2) + if math.isinf(self.WORKER_SATURATION) + else not _worker_full(ws, self.WORKER_SATURATION) + ): + if ws.status == Status.running: + idle[ws.address] = ws saturated.discard(ws) else: idle.pop(ws.address, None) @@ -2730,7 +2900,10 @@ def get_task_duration(self, ts: TaskState) -> float: def valid_workers(self, ts: TaskState) -> set[WorkerState] | None: """Return set of currently valid workers for key - If all workers are valid then this returns ``None``. + If all workers are valid then this returns ``None``, in which case + any *running* worker can be used. + Otherwise, the subset of running workers valid for this task + is returned. This checks tracks the following state: * worker_restrictions @@ -2779,10 +2952,7 @@ def valid_workers(self, ts: TaskState) -> set[WorkerState] | None: else: s &= ww - if s is None: - if len(self.running) < len(self.workers): - return self.running.copy() - else: + if s: s = {self.workers[addr] for addr in s} if len(self.running) < len(self.workers): s &= self.running @@ -2875,21 +3045,36 @@ def _reevaluate_occupancy_worker(self, ws: WorkerState): for ts in ws.processing: steal.recalculate_cost(ts) - def bulk_schedule_after_adding_worker(self, ws: WorkerState): - """Send tasks with ts.state=='no-worker' in bulk to a worker that just joined. - Return recommendations. As the worker will start executing the new tasks - immediately, without waiting for the batch to end, we can't rely on worker-side - ordering, so the recommendations are sorted by priority order here. + def bulk_schedule_after_adding_worker(self, ws: WorkerState) -> Recs: + """Send ``queued`` or ``no-worker`` tasks to ``processing`` that this worker can handle. + + Returns priority-ordered recommendations. """ - tasks = [] + maybe_runnable: list[TaskState] = [] + # Schedule any queued tasks onto the new worker + if not math.isinf(self.WORKER_SATURATION) and self.queued: + for qts in reversed( + list( + self.queued.peekn(_task_slots_available(ws, self.WORKER_SATURATION)) + ) + ): + if self.validate: + assert qts.state == "queued" + assert not qts.processing_on + assert not qts.waiting_on + + maybe_runnable.append(qts) + + # Schedule any restricted tasks onto the new worker, if the worker can run them for ts in self.unrunnable: valid = self.valid_workers(ts) if valid is None or ws in valid: - tasks.append(ts) - # These recommendations will generate {"op": "compute-task"} messages - # to the worker in reversed order - tasks.sort(key=operator.attrgetter("priority"), reverse=True) - return {ts.key: "waiting" for ts in tasks} + maybe_runnable.append(ts) + + # Recommendations are processed LIFO, hence the reversed order + maybe_runnable.sort(key=operator.attrgetter("priority"), reverse=True) + # Note not all will necessarily be run; transition->processing will decide + return {ts.key: "processing" for ts in maybe_runnable} class Scheduler(SchedulerState, ServerNode): @@ -3129,6 +3314,7 @@ def __init__( self._last_client = None self._last_time = 0 unrunnable = set() + queued: HeapSet[TaskState] = HeapSet(key=operator.attrgetter("priority")) self.datasets = {} @@ -3263,6 +3449,7 @@ def __init__( resources=resources, tasks=tasks, unrunnable=unrunnable, + queued=queued, validate=validate, plugins=plugins, transition_counter_max=transition_counter_max, @@ -3806,9 +3993,6 @@ async def add_worker( self.stream_comms[address] = BatchedSend(interval="5ms", loop=self.loop) - if ws.nthreads > len(ws.processing): - self.idle[ws.address] = ws - for plugin in list(self.plugins.values()): try: result = plugin.add_worker(scheduler=self, worker=address) @@ -4536,6 +4720,7 @@ def validate_released(self, key): assert not ts.processing_on assert not any([ts in dts.waiters for dts in ts.dependencies]) assert ts not in self.unrunnable + assert ts not in self.queued def validate_waiting(self, key): ts: TaskState = self.tasks[key] @@ -4543,19 +4728,35 @@ def validate_waiting(self, key): assert not ts.who_has assert not ts.processing_on assert ts not in self.unrunnable + assert ts not in self.queued for dts in ts.dependencies: # We are waiting on a dependency iff it's not stored assert bool(dts.who_has) != (dts in ts.waiting_on) assert ts in dts.waiters # XXX even if dts._who_has? + def validate_queued(self, key): + ts: TaskState = self.tasks[key] + dts: TaskState + assert ts in self.queued + assert not ts.waiting_on + assert not ts.who_has + assert not ts.processing_on + assert not ( + ts.worker_restrictions or ts.host_restrictions or ts.resource_restrictions + ) + for dts in ts.dependencies: + assert dts.who_has + assert ts in dts.waiters + def validate_processing(self, key): ts: TaskState = self.tasks[key] dts: TaskState assert not ts.waiting_on - ws: WorkerState = ts.processing_on + ws = ts.processing_on assert ws assert ts in ws.processing assert not ts.who_has + assert ts not in self.queued for dts in ts.dependencies: assert dts.who_has assert ts in dts.waiters @@ -4568,9 +4769,10 @@ def validate_memory(self, key): assert not ts.processing_on assert not ts.waiting_on assert ts not in self.unrunnable + assert ts not in self.queued for dts in ts.dependents: assert (dts in ts.waiters) == ( - dts.state in ("waiting", "processing", "no-worker") + dts.state in ("waiting", "queued", "processing", "no-worker") ) assert ts not in dts.waiting_on @@ -4581,6 +4783,7 @@ def validate_no_worker(self, key): assert ts in self.unrunnable assert not ts.processing_on assert not ts.who_has + assert ts not in self.queued for dts in ts.dependencies: assert dts.who_has @@ -4588,6 +4791,7 @@ def validate_erred(self, key): ts: TaskState = self.tasks[key] assert ts.exception_blame assert not ts.who_has + assert ts not in self.queued def validate_key(self, key, ts: TaskState | None = None): try: @@ -4619,13 +4823,21 @@ def validate_state(self, allow_overlap: bool = False) -> None: if not (set(self.workers) == set(self.stream_comms)): raise ValueError("Workers not the same in all collections") + assert self.running.issuperset(self.idle.values()), ( + self.running, + list(self.idle.values()), + ) for w, ws in self.workers.items(): assert isinstance(w, str), (type(w), w) assert isinstance(ws, WorkerState), (type(ws), ws) assert ws.address == w + if ws.status != Status.running: + assert ws.address not in self.idle + assert ws.long_running.issubset(ws.processing) if not ws.processing: assert not ws.occupancy - assert ws.address in self.idle + if ws.status == Status.running: + assert ws.address in self.idle assert (ws.status == Status.running) == (ws in self.running) for ws in self.running: @@ -4892,13 +5104,13 @@ def handle_long_running( self.check_idle_saturated(ws) def handle_worker_status_change( - self, status: str, worker: str, stimulus_id: str + self, status: str | Status, worker: str | WorkerState, stimulus_id: str ) -> None: - ws = self.workers.get(worker) + ws = self.workers.get(worker) if isinstance(worker, str) else worker if not ws: return prev_status = ws.status - ws.status = Status.lookup[status] # type: ignore + ws.status = Status[status] if isinstance(status, str) else status if ws.status == prev_status: return @@ -4907,12 +5119,14 @@ def handle_worker_status_change( { "action": "worker-status-change", "prev-status": prev_status.name, - "status": status, + "status": ws.status.name, }, ) + logger.debug(f"Worker status {prev_status.name} -> {status} - {ws}") if ws.status == Status.running: self.running.add(ws) + self.check_idle_saturated(ws) recs = self.bulk_schedule_after_adding_worker(ws) if recs: client_msgs: dict = {} @@ -4921,6 +5135,7 @@ def handle_worker_status_change( self.send_all(client_msgs, worker_msgs) else: self.running.discard(ws) + self.idle.pop(ws.address, None) async def handle_request_refresh_who_has( self, keys: Iterable[str], worker: str, stimulus_id: str @@ -6235,8 +6450,9 @@ async def retire_workers( # Change Worker.status to closing_gracefully. Immediately set # the same on the scheduler to prevent race conditions. prev_status = ws.status - ws.status = Status.closing_gracefully - self.running.discard(ws) + self.handle_worker_status_change( + Status.closing_gracefully, ws, stimulus_id + ) # FIXME: We should send a message to the nanny first; # eventually workers won't be able to close their own nannies. self.stream_comms[ws.address].send( @@ -7272,7 +7488,11 @@ def check_idle(self): self.idle_since = None return - if any([ws.processing for ws in self.workers.values()]) or self.unrunnable: + if ( + self.queued + or self.unrunnable + or any([ws.processing for ws in self.workers.values()]) + ): self.idle_since = None return @@ -7308,21 +7528,24 @@ def adaptive_target(self, target_duration=None): target_duration = parse_timedelta(target_duration) # CPU + + # TODO consider any user-specified default task durations for queued tasks + queued_occupancy = len(self.queued) * self.UNKNOWN_TASK_DURATION cpu = math.ceil( - self.total_occupancy / target_duration + (self.total_occupancy + queued_occupancy) / target_duration ) # TODO: threads per worker # Avoid a few long tasks from asking for many cores - tasks_processing = 0 + tasks_ready = len(self.queued) for ws in self.workers.values(): - tasks_processing += len(ws.processing) + tasks_ready += len(ws.processing) - if tasks_processing > cpu: + if tasks_ready > cpu: break else: - cpu = min(tasks_processing, cpu) + cpu = min(tasks_ready, cpu) - if self.unrunnable and not self.workers: + if (self.unrunnable or self.queued) and not self.workers: cpu = max(1, cpu) # add more workers if more than 60% of memory is used @@ -7397,7 +7620,44 @@ def request_remove_replicas( ) -def _remove_from_processing(state: SchedulerState, ts: TaskState) -> WorkerState | None: +def _validate_ready(state: SchedulerState, ts: TaskState): + "Validation for ready states (processing, queued, no-worker)" + assert not ts.waiting_on + assert not ts.who_has + assert not ts.exception_blame + assert not ts.processing_on + assert not ts.has_lost_dependencies + assert ts not in state.unrunnable + assert ts not in state.queued + assert all(dts.who_has for dts in ts.dependencies) + + +def _add_to_processing( + state: SchedulerState, ts: TaskState, ws: WorkerState +) -> dict[str, list]: + "Set a task as processing on a worker, and return the worker messages to send." + if state.validate: + _validate_ready(state, ts) + assert ts not in ws.processing + assert ws in state.running, state.running + assert (o := state.workers.get(ws.address)) is ws, (ws, o) + + state._set_duration_estimate(ts, ws) + ts.processing_on = ws + ts.state = "processing" + state.acquire_resources(ts, ws) + state.check_idle_saturated(ws) + state.n_tasks += 1 + + if ts.actor: + ws.actors.add(ts) + + return {ws.address: [_task_to_msg(state, ts)]} + + +def _exit_processing_common( + state: SchedulerState, ts: TaskState, recommendations: Recs +) -> WorkerState | None: """Remove *ts* from the set of processing tasks. Returns @@ -7428,6 +7688,18 @@ def _remove_from_processing(state: SchedulerState, ts: TaskState) -> WorkerState state.check_idle_saturated(ws) state.release_resources(ts, ws) + # If a slot has opened up for a queued task, schedule it. + if state.queued and not _worker_full(ws, state.WORKER_SATURATION): + qts = state.queued.peek() + if state.validate: + assert qts.state == "queued", qts.state + assert qts.key not in recommendations, recommendations[qts.key] + + # NOTE: we don't need to schedule more than one task at once here. Since this is + # called each time 1 task completes, multiple tasks must complete for multiple + # slots to open up. + recommendations[qts.key] = "processing" + return ws @@ -7489,6 +7761,32 @@ def _add_to_memory( ) +def _propagage_released( + state: SchedulerState, + ts: TaskState, + recommendations: Recs, +) -> None: + ts.state = "released" + key = ts.key + + if ts.has_lost_dependencies: + recommendations[key] = "forgotten" + elif ts.waiters or ts.who_wants: + recommendations[key] = "waiting" + + if recommendations.get(key) != "waiting": + for dts in ts.dependencies: + if dts.state != "released": + dts.waiters.discard(ts) + if not dts.waiters and not dts.who_wants: + recommendations[dts.key] = "released" + ts.waiters.clear() + + if state.validate: + assert not ts.processing_on + assert ts not in state.queued + + def _propagate_forgotten( state: SchedulerState, ts: TaskState, @@ -7686,7 +7984,7 @@ def validate_task_state(ts: TaskState) -> None: str(dts), str(dts.dependents), ) - if ts.state in ("waiting", "processing", "no-worker"): + if ts.state in ("waiting", "queued", "processing", "no-worker"): assert dts in ts.waiting_on or dts.who_has, ( "dep missing", str(ts), @@ -7695,7 +7993,7 @@ def validate_task_state(ts: TaskState) -> None: assert dts.state != "forgotten" for dts in ts.waiters: - assert dts.state in ("waiting", "processing", "no-worker"), ( + assert dts.state in ("waiting", "queued", "processing", "no-worker"), ( "waiter not in play", str(ts), str(dts), @@ -7712,6 +8010,15 @@ def validate_task_state(ts: TaskState) -> None: assert (ts.processing_on is not None) == (ts.state == "processing") assert bool(ts.who_has) == (ts.state == "memory"), (ts, ts.who_has, ts.state) + if ts.state == "queued": + assert not ts.processing_on + assert not ts.who_has + assert all(dts.who_has for dts in ts.dependencies), ( + "task queued without all deps", + str(ts), + str(ts.dependencies), + ) + if ts.state == "processing": assert all(dts.who_has for dts in ts.dependencies), ( "task processing without all deps", @@ -7752,6 +8059,7 @@ def validate_task_state(ts: TaskState) -> None: if ts.state == "processing": assert ts.processing_on assert ts in ts.processing_on.actors + assert ts.state != "queued" def validate_worker_state(ws: WorkerState) -> None: @@ -7806,6 +8114,21 @@ def heartbeat_interval(n: int) -> float: return n / 200 + 1 +def _task_slots_available(ws: WorkerState, saturation_factor: float) -> int: + "Number of tasks that can be sent to this worker without oversaturating it" + assert not math.isinf(saturation_factor) + nthreads = ws.nthreads + return max(int(saturation_factor * nthreads), 1) - ( + len(ws.processing) - len(ws.long_running) + ) + + +def _worker_full(ws: WorkerState, saturation_factor: float) -> bool: + if math.isinf(saturation_factor): + return False + return _task_slots_available(ws, saturation_factor) <= 0 + + class KilledWorker(Exception): def __init__(self, task: str, last_worker: WorkerState): super().__init__(task, last_worker) diff --git a/distributed/tests/test_active_memory_manager.py b/distributed/tests/test_active_memory_manager.py index e99297a2a0..4c3c513adc 100644 --- a/distributed/tests/test_active_memory_manager.py +++ b/distributed/tests/test_active_memory_manager.py @@ -782,6 +782,8 @@ async def test_RetireWorker_no_remove(c, s, a, b): while s.tasks["x"].who_has != {s.workers[b.address]}: await asyncio.sleep(0.01) assert a.address in s.workers + assert a.status == Status.closing_gracefully + assert s.workers[a.address].status == Status.closing_gracefully # Policy has been removed without waiting for worker to disappear from # Scheduler.workers assert not s.extensions["amm"].policies diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 05850924f8..f83e5f45bb 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -3,6 +3,7 @@ import asyncio import json import logging +import math import operator import pickle import re @@ -10,7 +11,7 @@ from itertools import product from textwrap import dedent from time import sleep -from typing import Collection +from typing import ClassVar, Collection from unittest import mock import cloudpickle @@ -43,6 +44,7 @@ from distributed.utils import TimeoutError from distributed.utils_test import ( BrokenComm, + async_wait_for, captured_logger, cluster, dec, @@ -54,6 +56,7 @@ raises_with_cause, slowadd, slowdec, + slowidentity, slowinc, tls_only_security, varying, @@ -145,7 +148,9 @@ def test_decide_worker_coschedule_order_neighbors(ndeps, nthreads): @gen_cluster( client=True, nthreads=nthreads, - config={"distributed.scheduler.work-stealing": False}, + config={ + "distributed.scheduler.work-stealing": False, + }, ) async def test_decide_worker_coschedule_order_neighbors_(c, s, *workers): r""" @@ -245,6 +250,223 @@ def random(**kwargs): test_decide_worker_coschedule_order_neighbors_() +@pytest.mark.parametrize("ngroups", [1, 2, 3, 5]) +@gen_cluster( + client=True, + nthreads=[("", 1), ("", 1)], +) +async def test_decide_worker_coschedule_order_binary_op(c, s, a, b, ngroups): + roots = [[delayed(i, name=f"x-{n}-{i}") for i in range(8)] for n in range(ngroups)] + zs = [sum(rs) for rs in zip(*roots)] + + await c.gather(c.compute(zs)) + + assert not a.transfer_incoming_log, [l["keys"] for l in a.transfer_incoming_log] + assert not b.transfer_incoming_log, [l["keys"] for l in b.transfer_incoming_log] + + +@pytest.mark.slow +@gen_cluster( + nthreads=[("", 2)] * 4, + client=True, + config={"distributed.scheduler.worker-saturation": 1.0}, +) +async def test_graph_execution_width(c, s, *workers): + """ + Test that we don't execute the graph more breadth-first than necessary. + + We shouldn't start loading extra data if we're not going to use it immediately. + The number of parallel work streams match the number of threads. + """ + + class Refcount: + "Track how many instances of this class exist; logs the count at creation and deletion" + + count: ClassVar[int] = 0 + lock: ClassVar[dask.utils.SerializableLock] = dask.utils.SerializableLock() + log: ClassVar[list[int]] = [] + + def __init__(self) -> None: + with self.lock: + type(self).count += 1 + self.log.append(self.count) + + def __del__(self): + with self.lock: + self.log.append(self.count) + type(self).count -= 1 + + roots = [delayed(Refcount)() for _ in range(32)] + passthrough1 = [delayed(slowidentity)(r, delay=0) for r in roots] + passthrough2 = [delayed(slowidentity)(r, delay=0) for r in passthrough1] + done = [delayed(lambda r: None)(r) for r in passthrough2] + + fs = c.compute(done) + await wait(fs) + # NOTE: the max should normally equal `total_nthreads`. But some macOS CI machines + # are slow enough that they aren't able to reach the full parallelism of 8 threads. + assert max(Refcount.log) <= s.total_nthreads + + +@pytest.mark.parametrize("queue", [True, False]) +@gen_cluster( + client=True, + nthreads=[("", 2)] * 2, + config={ + "distributed.worker.memory.pause": False, + "distributed.worker.memory.target": False, + "distributed.worker.memory.spill": False, + "distributed.scheduler.work-stealing": False, + }, +) +async def test_queued_paused_new_worker(c, s, a, b, queue): + if queue: + s.WORKER_SATURATION = 1.0 + else: + s.WORKER_SATURATION = float("inf") + + f1s = c.map(slowinc, range(16)) + f2s = c.map(slowinc, f1s) + final = c.submit(sum, *f2s) + del f1s, f2s + + while not a.data or not b.data: + await asyncio.sleep(0.01) + + # manually pause the workers + a.status = Status.paused + b.status = Status.paused + + while s.running: + # wait for workers pausing to hit the scheduler + await asyncio.sleep(0.01) + + assert not s.idle + assert not s.running + + async with Worker(s.address, nthreads=2) as w: + # Tasks are successfully scheduled onto a new worker + while not w.state.data: + await asyncio.sleep(0.01) + + del final + while s.tasks: + await asyncio.sleep(0.01) + assert not s.queued + + +@pytest.mark.parametrize("queue", [True, False]) +@gen_cluster( + client=True, + nthreads=[("", 2)] * 2, + config={ + "distributed.worker.memory.pause": False, + "distributed.worker.memory.target": False, + "distributed.worker.memory.spill": False, + "distributed.scheduler.work-stealing": False, + }, +) +async def test_queued_paused_unpaused(c, s, a, b, queue): + if queue: + s.WORKER_SATURATION = 1.0 + else: + s.WORKER_SATURATION = float("inf") + + f1s = c.map(slowinc, range(16)) + f2s = c.map(slowinc, f1s) + final = c.submit(sum, *f2s) + del f1s, f2s + + while not a.data or not b.data: + await asyncio.sleep(0.01) + + # manually pause the workers + a.status = Status.paused + b.status = Status.paused + + while s.running: + # wait for workers pausing to hit the scheduler + await asyncio.sleep(0.01) + + assert not s.running + assert not s.idle + + # un-pause + a.status = Status.running + b.status = Status.running + while not s.running: + await asyncio.sleep(0.01) + + assert not s.idle # workers should have been (or already were) filled + + await wait(final) + + +@gen_cluster( + client=True, + nthreads=[("", 2)] * 2, + config={"distributed.scheduler.worker-saturation": 1.0}, +) +async def test_queued_remove_add_worker(c, s, a, b): + event = Event() + fs = c.map(lambda i: event.wait(), range(10)) + + await async_wait_for(lambda: len(s.queued) == 6, timeout=5) + await s.remove_worker(a.address, stimulus_id="fake") + assert len(s.queued) == 8 + + # Add a new worker + async with Worker(s.address, nthreads=2) as w: + await async_wait_for(lambda: len(s.queued) == 6, timeout=5) + + await event.set() + await wait(fs) + + +@pytest.mark.parametrize( + "saturation, expected_task_counts", + [ + (2.5, (5, 2)), + (2.0, (4, 2)), + (1.0, (2, 1)), + (-1.0, (1, 1)), + (float("inf"), (6, 4)) + # ^ depends on root task assignment logic; ok if changes, just needs to add up to 10 + ], +) +def test_saturation_factor( + saturation: int | float, expected_task_counts: tuple[int, int] +) -> None: + @gen_cluster( + client=True, + nthreads=[("", 2), ("", 1)], + config={ + "distributed.scheduler.worker-saturation": saturation, + }, + ) + async def _test_saturation_factor(c, s, a, b): + event = Event() + fs = c.map( + lambda _: event.wait(), range(10), key=[f"wait-{i}" for i in range(10)] + ) + while a.state.executing_count < min( + a.state.nthreads, expected_task_counts[0] + ) or b.state.executing_count < min(b.state.nthreads, expected_task_counts[1]): + await asyncio.sleep(0.01) + + if math.isfinite(saturation): + assert len(a.state.tasks) == expected_task_counts[0] + assert len(b.state.tasks) == expected_task_counts[1] + else: + # Assignment is nondeterministic for some reason without queuing + assert len(a.state.tasks) > len(b.state.tasks) + + await event.set() + await c.gather(fs) + + _test_saturation_factor() + + @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) async def test_move_data_over_break_restrictions(client, s, a, b, c): [x] = await client.scatter([1], workers=b.address) @@ -284,17 +506,32 @@ async def test_no_valid_workers_loose_restrictions(client, s, a, b, c): assert result == 2 +@pytest.mark.parametrize("queue", [False, True]) @gen_cluster(client=True, nthreads=[]) -async def test_no_workers(client, s): +async def test_no_workers(client, s, queue): + if queue: + s.WORKER_SATURATION = 1.0 + else: + s.WORKER_SATURATION = float("inf") + x = client.submit(inc, 1) while not s.tasks: await asyncio.sleep(0.01) - assert s.tasks[x.key] in s.unrunnable + ts = s.tasks[x.key] + if queue: + assert ts in s.queued + assert ts.state == "queued" + else: + assert ts in s.unrunnable + assert ts.state == "no-worker" with pytest.raises(TimeoutError): await asyncio.wait_for(x, 0.05) + async with Worker(s.address, nthreads=1): + await wait(x) + @gen_cluster(nthreads=[]) async def test_retire_workers_empty(s): @@ -321,7 +558,7 @@ async def test_remove_worker_from_scheduler(s, a, b): assert a.address in s.stream_comms await s.remove_worker(address=a.address, stimulus_id="test") assert a.address not in s.workers - assert len(s.workers[b.address].processing) == len(dsk) # b owns everything + assert len(s.workers[b.address].processing) + len(s.queued) == len(dsk) @gen_cluster() @@ -605,12 +842,31 @@ async def test_ready_remove_worker(s, a, b): dependencies={"x-%d" % i: [] for i in range(20)}, ) - assert all(len(w.processing) > w.nthreads for w in s.workers.values()) + if s.WORKER_SATURATION == 1: + cmp = operator.eq + elif math.isinf(s.WORKER_SATURATION): + cmp = operator.gt + else: + pytest.fail(f"{s.WORKER_OVERSATURATION=}, must be 1 or inf") + + assert all(cmp(len(w.processing), w.nthreads) for w in s.workers.values()), ( + list(s.workers.values()), + s.WORKER_SATURATION, + ) + assert sum(len(w.processing) for w in s.workers.values()) + len(s.queued) == len( + s.tasks + ) await s.remove_worker(address=a.address, stimulus_id="test") assert set(s.workers) == {b.address} - assert all(len(w.processing) > w.nthreads for w in s.workers.values()) + assert all(cmp(len(w.processing), w.nthreads) for w in s.workers.values()), ( + list(s.workers.values()), + s.WORKER_SATURATION, + ) + assert sum(len(w.processing) for w in s.workers.values()) + len(s.queued) == len( + s.tasks + ) @gen_cluster(client=True, Worker=Nanny, timeout=60) @@ -1180,9 +1436,13 @@ async def test_learn_occupancy(c, s, a, b): while sum(len(ts.who_has) for ts in s.tasks.values()) < 10: await asyncio.sleep(0.01) - assert 100 < s.total_occupancy < 1000 + nproc = sum(ts.state == "processing" for ts in s.tasks.values()) + assert nproc * 0.1 < s.total_occupancy < nproc * 0.4 for w in [a, b]: - assert 50 < s.workers[w.address].occupancy < 700 + ws = s.workers[w.address] + occ = ws.occupancy + proc = len(ws.processing) + assert proc * 0.1 < occ < proc * 0.4 @pytest.mark.slow @@ -1193,7 +1453,8 @@ async def test_learn_occupancy_2(c, s, a, b): while not any(ts.who_has for ts in s.tasks.values()): await asyncio.sleep(0.01) - assert 100 < s.total_occupancy < 1000 + nproc = sum(ts.state == "processing" for ts in s.tasks.values()) + assert nproc * 0.1 < s.total_occupancy < nproc * 0.4 @gen_cluster(client=True) @@ -2011,6 +2272,31 @@ async def test_adaptive_target(c, s, a, b): assert s.adaptive_target(target_duration=".1s") == 0 +@pytest.mark.parametrize("queue", [True, False]) +@gen_cluster( + client=True, + nthreads=[], +) +async def test_adaptive_target_empty_cluster(c, s, queue): + if queue: + s.WORKER_SATURATION = 1.0 + else: + s.WORKER_SATURATION = float("inf") + + assert s.adaptive_target() == 0 + + f = c.submit(inc, -1) + await async_wait_for(lambda: s.tasks, timeout=5) + assert s.adaptive_target() == 1 + del f + + if queue: + # only queuing supports fast scale-up for empty clusters https://github.com/dask/distributed/issues/6962 + fs = c.map(inc, range(100)) + await async_wait_for(lambda: len(s.tasks) == len(fs), timeout=5) + assert s.adaptive_target() > 1 + + @gen_test() async def test_async_context_manager(): async with Scheduler(dashboard_address=":0") as s: diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 0103808cf5..1fe036a371 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -3116,6 +3116,11 @@ async def test_worker_status_sync(s, a): "prev-status": "paused", "status": "running", }, + { + "action": "worker-status-change", + "prev-status": "running", + "status": "closing_gracefully", + }, {"action": "remove-worker", "processing-tasks": {}}, {"action": "retired"}, ] diff --git a/docs/source/images/task-state.dot b/docs/source/images/task-state.dot index fde7d8c62c..f68008c6a2 100644 --- a/docs/source/images/task-state.dot +++ b/docs/source/images/task-state.dot @@ -7,7 +7,10 @@ digraph{ released2 [label=released]; released1 -> waiting; waiting -> processing; - waiting -> "no-worker" [dir=both]; + waiting -> "no-worker"; + waiting -> queued; + "no-worker" -> processing; + queued -> processing; processing -> memory; processing -> error; error -> released2; diff --git a/docs/source/images/task-state.svg b/docs/source/images/task-state.svg index 7b7d5c84ad..b8d23318be 100644 --- a/docs/source/images/task-state.svg +++ b/docs/source/images/task-state.svg @@ -1,109 +1,132 @@ - - + - + released1 - -released + +released waiting - -waiting + +waiting released1->waiting - - + + released2 - -released + +released - + forgotten - -forgotten + +forgotten - + released2->forgotten - - + + processing - -processing + +processing waiting->processing - - + + no-worker - -no-worker + +no-worker waiting->no-worker - - - + + - + +queued + +queued + + + +waiting->queued + + + + + memory - -memory + +memory - + processing->memory - - + + - + error - -error + +error - + processing->error - - + + + + + +no-worker->processing + + + + + +queued->processing + + - + memory->released2 - - + + - + error->released2 - - + + diff --git a/docs/source/scheduling-state.rst b/docs/source/scheduling-state.rst index 9354b8f6b7..3d8e435f20 100644 --- a/docs/source/scheduling-state.rst +++ b/docs/source/scheduling-state.rst @@ -52,7 +52,7 @@ Task State ---------- Internally, the scheduler moves tasks between a fixed set of states, -notably ``released``, ``waiting``, ``no-worker``, ``processing``, +notably ``released``, ``waiting``, ``no-worker``, ``queued``, ``processing``, ``memory``, ``error``. Tasks flow along the following states with the following allowed transitions: @@ -60,6 +60,8 @@ Tasks flow along the following states with the following allowed transitions: .. image:: images/task-state.svg :alt: Dask scheduler task states +Note that tasks may also transition to ``released`` from any state (not shown on diagram). + released Known but not actively computing or in memory waiting @@ -67,6 +69,8 @@ waiting no-worker Ready to be computed, but no appropriate worker exists (for example because of resource restrictions, or because no worker is connected at all). +queued + Ready to be computed, but all workers are already full. processing All dependencies are available and the task is assigned to a worker for compute (the scheduler doesn't know whether it's in a worker queue or actively being computed). @@ -80,12 +84,18 @@ forgotten dereferenced from the scheduler. .. note:: - There's no intermediate state between ``waiting`` / ``no-worker`` and + When the ``distributed.scheduler.worker_saturation`` config value is set to ``inf`` + (default), there's no intermediate state between ``waiting`` / ``no-worker`` and ``processing``: as soon as a task has all of its dependencies in memory somewhere on the cluster, it is immediately assigned to a worker. This can lead to very long task queues on the workers, which are then rebalanced dynamically through :doc:`work-stealing`. + Setting ``distributed.scheduler.worker_saturation`` to ``1.0`` (or any finite value) + will instead queue excess root tasks on the scheduler in the ``queued`` state. These + tasks are only assigned to workers when they have capacity for them, reducing the + length of task queues on the workers. + In addition to the literal state, though, other information needs to be kept and updated about each task. Individual task state is stored in an object named :class:`TaskState`; see full API through the link.