diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py
index 6663ebacd9..d89483472d 100644
--- a/distributed/dashboard/components/scheduler.py
+++ b/distributed/dashboard/components/scheduler.py
@@ -2162,8 +2162,8 @@ def __init__(self, scheduler, **kwargs):
node_colors = factor_cmap(
"state",
- factors=["waiting", "processing", "memory", "released", "erred"],
- palette=["gray", "green", "red", "blue", "black"],
+ factors=["waiting", "queued", "processing", "memory", "released", "erred"],
+ palette=["gray", "yellow", "green", "red", "blue", "black"],
)
self.root = figure(title="Task Graph", **kwargs)
@@ -3051,7 +3051,7 @@ def __init__(self, scheduler, **kwargs):
self.scheduler = scheduler
data = progress_quads(
- dict(all={}, memory={}, erred={}, released={}, processing={})
+ dict(all={}, memory={}, erred={}, released={}, processing={}, queued={})
)
self.source = ColumnDataSource(data=data)
@@ -3123,6 +3123,18 @@ def __init__(self, scheduler, **kwargs):
fill_alpha=0.35,
line_alpha=0,
)
+ self.root.quad(
+ source=self.source,
+ top="top",
+ bottom="bottom",
+ left="processing-loc",
+ right="queued-loc",
+ fill_color="gray",
+ hatch_pattern="/",
+ hatch_color="white",
+ fill_alpha=0.35,
+ line_alpha=0,
+ )
self.root.text(
source=self.source,
text="show-name",
@@ -3158,6 +3170,14 @@ def __init__(self, scheduler, **kwargs):
All:
@all
+
+ Queued:
+ @queued
+
+ Processing:
+ @processing
+
Memory:
@memory
@@ -3166,10 +3186,6 @@ def __init__(self, scheduler, **kwargs):
Erred:
@erred
-
- Ready:
- @processing
-
""",
)
self.root.add_tools(hover)
@@ -3183,6 +3199,7 @@ def update(self):
"released": {},
"processing": {},
"waiting": {},
+ "queued": {},
}
for tp in self.scheduler.task_prefixes.values():
@@ -3193,6 +3210,7 @@ def update(self):
state["released"][tp.name] = active_states["released"]
state["processing"][tp.name] = active_states["processing"]
state["waiting"][tp.name] = active_states["waiting"]
+ state["queued"][tp.name] = active_states["queued"]
state["all"] = {k: sum(v[k] for v in state.values()) for k in state["memory"]}
@@ -3205,7 +3223,7 @@ def update(self):
totals = {
k: sum(state[k].values())
- for k in ["all", "memory", "erred", "released", "waiting"]
+ for k in ["all", "memory", "erred", "released", "waiting", "queued"]
}
totals["processing"] = totals["all"] - sum(
v for k, v in totals.items() if k != "all"
@@ -3213,8 +3231,10 @@ def update(self):
self.root.title.text = (
"Progress -- total: %(all)s, "
- "in-memory: %(memory)s, processing: %(processing)s, "
"waiting: %(waiting)s, "
+ "queued: %(queued)s, "
+ "processing: %(processing)s, "
+ "in-memory: %(memory)s, "
"erred: %(erred)s" % totals
)
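
Aside: the title arithmetic above derives `processing` as a remainder rather than summing it directly, so tasks in any state the bar chart doesn't track explicitly still land in the total. A minimal sketch with made-up counts:

```python
# Made-up counts; "processing" is whatever remains after the tracked states.
totals = {"all": 10, "memory": 3, "erred": 1, "released": 2, "waiting": 1, "queued": 2}
totals["processing"] = totals["all"] - sum(v for k, v in totals.items() if k != "all")
assert totals["processing"] == 1
```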
diff --git a/distributed/diagnostics/progress_stream.py b/distributed/diagnostics/progress_stream.py
index 03aedde8f6..764aef790e 100644
--- a/distributed/diagnostics/progress_stream.py
+++ b/distributed/diagnostics/progress_stream.py
@@ -17,7 +17,7 @@ def counts(scheduler, allprogress):
{"all": valmap(len, allprogress.all), "nbytes": allprogress.nbytes},
{
state: valmap(len, allprogress.state[state])
- for state in ["memory", "erred", "released", "processing"]
+ for state in ["memory", "erred", "released", "processing", "queued"]
},
)
@@ -66,23 +66,29 @@ def progress_quads(msg, nrows=8, ncols=3):
... 'memory': {'inc': 2, 'dec': 0, 'add': 1},
... 'erred': {'inc': 0, 'dec': 1, 'add': 0},
... 'released': {'inc': 1, 'dec': 0, 'add': 1},
- ... 'processing': {'inc': 1, 'dec': 0, 'add': 2}}
+ ... 'processing': {'inc': 1, 'dec': 0, 'add': 2},
+ ... 'queued': {'inc': 1, 'dec': 0, 'add': 2}}
>>> progress_quads(msg, nrows=2) # doctest: +SKIP
- {'name': ['inc', 'add', 'dec'],
- 'left': [0, 0, 1],
- 'right': [0.9, 0.9, 1.9],
- 'top': [0, -1, 0],
- 'bottom': [-.8, -1.8, -.8],
- 'released': [1, 1, 0],
- 'memory': [2, 1, 0],
- 'erred': [0, 0, 1],
- 'processing': [1, 0, 2],
- 'done': ['3 / 5', '2 / 4', '1 / 1'],
- 'released-loc': [.2/.9, .25 / 0.9, 1],
- 'memory-loc': [3 / 5 / .9, .5 / 0.9, 1],
- 'erred-loc': [3 / 5 / .9, .5 / 0.9, 1.9],
- 'processing-loc': [4 / 5, 1 / 1, 1]}}
+ {'all': [5, 4, 1],
+ 'memory': [2, 1, 0],
+ 'erred': [0, 0, 1],
+ 'released': [1, 1, 0],
+ 'processing': [1, 2, 0],
+ 'queued': [1, 2, 0],
+ 'name': ['inc', 'add', 'dec'],
+ 'show-name': ['inc', 'add', 'dec'],
+ 'left': [0, 0, 1],
+ 'right': [0.9, 0.9, 1.9],
+ 'top': [0, -1, 0],
+ 'bottom': [-0.8, -1.8, -0.8],
+ 'color': ['#45BF6F', '#2E6C8E', '#440154'],
+ 'released-loc': [0.18, 0.225, 1.0],
+ 'memory-loc': [0.54, 0.45, 1.0],
+ 'erred-loc': [0.54, 0.45, 1.9],
+ 'processing-loc': [0.72, 0.9, 1.9],
+ 'queued-loc': [0.9, 1.35, 1.9],
+ 'done': ['3 / 5', '2 / 4', '1 / 1']}
"""
width = 0.9
names = sorted(msg["all"], key=msg["all"].get, reverse=True)
@@ -102,19 +108,28 @@ def progress_quads(msg, nrows=8, ncols=3):
d["memory-loc"] = []
d["erred-loc"] = []
d["processing-loc"] = []
+ d["queued-loc"] = []
d["done"] = []
- for r, m, e, p, a, l in zip(
- d["released"], d["memory"], d["erred"], d["processing"], d["all"], d["left"]
+ for r, m, e, p, q, a, l in zip(
+ d["released"],
+ d["memory"],
+ d["erred"],
+ d["processing"],
+ d["queued"],
+ d["all"],
+ d["left"],
):
rl = width * r / a + l
ml = width * (r + m) / a + l
el = width * (r + m + e) / a + l
pl = width * (p + r + m + e) / a + l
+ ql = width * (p + r + m + e + q) / a + l
done = "%d / %d" % (r + m + e, a)
d["released-loc"].append(rl)
d["memory-loc"].append(ml)
d["erred-loc"].append(el)
d["processing-loc"].append(pl)
+ d["queued-loc"].append(ql)
d["done"].append(done)
return d
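
For reference, the `*-loc` columns above are cumulative fractions of each bar: every state's right edge stacks on the states before it. A standalone sketch of that arithmetic (function name is illustrative, not part of the patch):

```python
def segment_bounds(released, memory, erred, processing, queued, total, left, width=0.9):
    """Right edge of each state's segment within one bar of the progress plot."""
    rl = width * released / total + left
    ml = width * (released + memory) / total + left
    el = width * (released + memory + erred) / total + left
    pl = width * (released + memory + erred + processing) / total + left
    ql = width * (released + memory + erred + processing + queued) / total + left
    return [rl, ml, el, pl, ql]

# The 'inc' row from the doctest above: 1 released, 2 memory, 0 erred,
# 1 processing, 1 queued out of 5 tasks, bar starting at left=0.
assert [round(x, 2) for x in segment_bounds(1, 2, 0, 1, 1, 5, 0)] == [
    0.18, 0.54, 0.54, 0.72, 0.9,
]
```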
diff --git a/distributed/diagnostics/tests/test_progress_stream.py b/distributed/diagnostics/tests/test_progress_stream.py
index 73a5be81ab..49c93b212e 100644
--- a/distributed/diagnostics/tests/test_progress_stream.py
+++ b/distributed/diagnostics/tests/test_progress_stream.py
@@ -18,6 +18,7 @@ def test_progress_quads():
"erred": {"inc": 0, "dec": 1, "add": 0},
"released": {"inc": 1, "dec": 0, "add": 1},
"processing": {"inc": 1, "dec": 0, "add": 2},
+ "queued": {"inc": 1, "dec": 0, "add": 2},
}
d = progress_quads(msg, nrows=2)
@@ -35,11 +36,13 @@ def test_progress_quads():
"memory": [2, 1, 0],
"erred": [0, 0, 1],
"processing": [1, 2, 0],
+ "queued": [1, 2, 0],
"done": ["3 / 5", "2 / 4", "1 / 1"],
"released-loc": [0.9 * 1 / 5, 0.25 * 0.9, 1.0],
"memory-loc": [0.9 * 3 / 5, 0.5 * 0.9, 1.0],
"erred-loc": [0.9 * 3 / 5, 0.5 * 0.9, 1.9],
"processing-loc": [0.9 * 4 / 5, 1 * 0.9, 1 * 0.9 + 1],
+ "queued-loc": [1 * 0.9, 1.5 * 0.9, 1 * 0.9 + 1],
}
assert d == expected
@@ -52,6 +55,7 @@ def test_progress_quads_too_many():
"erred": {k: 0 for k in keys},
"released": {k: 0 for k in keys},
"processing": {k: 0 for k in keys},
+ "queued": {k: 0 for k in keys},
}
d = progress_quads(msg, nrows=6, ncols=3)
@@ -78,6 +82,7 @@ async def test_progress_stream(c, s, a, b):
"memory": {"div": 9, "inc": 1},
"released": {"inc": 4},
"processing": {},
+ "queued": {},
}
assert set(nbytes) == set(msg["all"])
assert all(v > 0 for v in nbytes.values())
@@ -95,6 +100,7 @@ def test_progress_quads_many_functions():
"erred": {fn: 0 for fn in funcnames},
"released": {fn: 0 for fn in funcnames},
"processing": {fn: 0 for fn in funcnames},
+ "queued": {fn: 0 for fn in funcnames},
}
d = progress_quads(msg, nrows=2)
diff --git a/distributed/distributed-schema.yaml b/distributed/distributed-schema.yaml
index 7b5a7ed53e..a110aaa807 100644
--- a/distributed/distributed-schema.yaml
+++ b/distributed/distributed-schema.yaml
@@ -117,6 +117,28 @@ properties:
description: |
How frequently to balance worker loads
+ worker-saturation:
+ type: number
+ exclusiveMinimum: 0
+ description: |
+ Controls how many root tasks are sent to workers (like a `readahead`).
+
+ Up to worker-saturation * nthreads root tasks are sent to a
+ worker at a time. If `.inf`, all runnable tasks are immediately sent to workers.
+
+ Allowing oversaturation (> 1.0) means a worker may start running a new root task as
+ soon as it completes the previous, even if there is a higher-priority downstream task
+        to run. This reduces worker idleness by letting workers do something while
+        waiting for further instructions from the scheduler, even if it's not the
+        most efficient thing to do.
+
+ This generally comes at the expense of increased memory usage. It leads to "wider"
+ (more breadth-first) execution of the graph.
+
+ Compute-bound workloads may benefit from oversaturation. Memory-bound workloads should
+ generally leave `worker-saturation` at 1.0, though 1.25-1.5 could slightly improve
+ performance if ample memory is available.
+
worker-ttl:
type:
- string
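
To make the setting concrete: with `worker-saturation: 1.0` and 4 threads per worker, at most 4 root tasks are in flight per worker, while `1.5` would allow `int(1.5 * 4) == 6`. A hedged usage sketch (cluster sizes here are arbitrary):

```python
import dask
from distributed import Client, LocalCluster

# Enable scheduler-side queuing; the scheduler reads this at startup into
# SchedulerState.WORKER_SATURATION (see the scheduler.py changes below).
with dask.config.set({"distributed.scheduler.worker-saturation": 1.0}):
    cluster = LocalCluster(n_workers=2, threads_per_worker=4)
    client = Client(cluster)
```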
diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml
index 400058148f..ca8292146f 100644
--- a/distributed/distributed.yaml
+++ b/distributed/distributed.yaml
@@ -22,6 +22,7 @@ distributed:
events-log-length: 100000
work-stealing: True # workers should steal tasks from each other
work-stealing-interval: 100ms # Callback time for work stealing
+ worker-saturation: .inf # Send this fraction of nthreads root tasks to workers
worker-ttl: "5 minutes" # like '60s'. Time to live for workers. They must heartbeat faster than this
pickle: True # Is the scheduler allowed to deserialize arbitrary bytestrings
preload: [] # Run custom modules with Scheduler
diff --git a/distributed/http/scheduler/tests/test_scheduler_http.py b/distributed/http/scheduler/tests/test_scheduler_http.py
index c2d75d8206..158384b758 100644
--- a/distributed/http/scheduler/tests/test_scheduler_http.py
+++ b/distributed/http/scheduler/tests/test_scheduler_http.py
@@ -137,7 +137,15 @@ async def fetch_metrics():
]
return active_metrics, forgotten_tasks
- expected = {"memory", "released", "processing", "waiting", "no-worker", "erred"}
+ expected = {
+ "memory",
+ "released",
+ "queued",
+ "processing",
+ "waiting",
+ "no-worker",
+ "erred",
+ }
# Ensure that we get full zero metrics for all states even though the
# scheduler did nothing, yet
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 10ca56e446..8ffb215025 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -58,6 +58,7 @@
from distributed._stories import scheduler_story
from distributed.active_memory_manager import ActiveMemoryManagerExtension, RetireWorker
from distributed.batched import BatchedSend
+from distributed.collections import HeapSet
from distributed.comm import (
Comm,
CommClosedError,
@@ -116,6 +117,7 @@
"released",
"waiting",
"no-worker",
+ "queued",
"processing",
"memory",
"erred",
@@ -133,6 +135,7 @@
"released",
"waiting",
"no-worker",
+ "queued",
"processing",
"memory",
"erred",
@@ -844,14 +847,6 @@ class TaskGroup:
#: The result types of this TaskGroup
types: set[str]
- #: The worker most recently assigned a task from this group, or None when the group
- #: is not identified to be root-like by `SchedulerState.decide_worker`.
- last_worker: WorkerState | None
-
- #: If `last_worker` is not None, the number of times that worker should be assigned
- #: subsequent tasks until a new worker is chosen.
- last_worker_tasks_left: int
-
prefix: TaskPrefix | None
start: float
stop: float
@@ -870,8 +865,6 @@ def __init__(self, name: str):
self.start = 0.0
self.stop = 0.0
self.all_durations = defaultdict(float)
- self.last_worker = None
- self.last_worker_tasks_left = 0
def add_duration(self, action: str, start: float, stop: float) -> None:
duration = stop - start
@@ -1281,13 +1274,15 @@ class SchedulerState:
Tasks currently known to the scheduler
* **unrunnable:** ``{TaskState}``
Tasks in the "no-worker" state
+ * **queued:** ``HeapSet[TaskState]``
+ Tasks in the "queued" state, ordered by priority
* **workers:** ``{worker key: WorkerState}``
Workers currently connected to the scheduler
* **idle:** ``{WorkerState}``:
- Set of workers that are not fully utilized
+ Set of workers that are currently in running state and not fully utilized
* **saturated:** ``{WorkerState}``:
- Set of workers that are not over-utilized
+ Set of workers that are fully utilized. May include non-running workers.
* **running:** ``{WorkerState}``:
Set of workers that are currently in running state
@@ -1307,7 +1302,10 @@ class SchedulerState:
"extensions",
"host_info",
"idle",
+ "last_root_worker",
+ "last_root_worker_tasks_left",
"n_tasks",
+ "queued",
"resources",
"saturated",
"running",
@@ -1332,6 +1330,7 @@ class SchedulerState:
"MEMORY_REBALANCE_SENDER_MIN",
"MEMORY_REBALANCE_RECIPIENT_MAX",
"MEMORY_REBALANCE_HALF_GAP",
+ "WORKER_SATURATION",
}
def __init__(
@@ -1343,6 +1342,7 @@ def __init__(
resources: dict,
tasks: dict,
unrunnable: set,
+ queued: HeapSet[TaskState],
validate: bool,
plugins: Iterable[SchedulerPlugin] = (),
transition_counter_max: int | Literal[False] = False,
@@ -1375,10 +1375,13 @@ def __init__(
self.total_nthreads = 0
self.total_occupancy = 0.0
self.unknown_durations: dict[str, set[TaskState]] = {}
+ self.last_root_worker: WorkerState | None = None
+ self.last_root_worker_tasks_left: int = 0
+ self.queued = queued
self.unrunnable = unrunnable
self.validate = validate
self.workers = workers
- self.running = {
+ self.running: set[WorkerState] = {
ws for ws in self.workers.values() if ws.status == Status.running
}
self.plugins = {} if not plugins else {_get_plugin_name(p): p for p in plugins}
@@ -1403,6 +1406,9 @@ def __init__(
dask.config.get("distributed.worker.memory.rebalance.sender-recipient-gap")
/ 2.0
)
+ self.WORKER_SATURATION = dask.config.get(
+ "distributed.scheduler.worker-saturation"
+ )
self.transition_counter = 0
self._idle_transition_counter = 0
self.transition_counter_max = transition_counter_max
@@ -1418,6 +1424,7 @@ def __pdict__(self):
"resources": self.resources,
"saturated": self.saturated,
"unrunnable": self.unrunnable,
+ "queued": self.queued,
"n_tasks": self.n_tasks,
"unknown_durations": self.unknown_durations,
"validate": self.validate,
@@ -1576,11 +1583,11 @@ def _transition(
if not stimulus_id:
stimulus_id = STIMULUS_ID_UNSET
- finish2 = ts._state
+ actual_finish = ts._state
# FIXME downcast antipattern
scheduler = cast(Scheduler, self)
scheduler.transition_log.append(
- (key, start, finish2, recommendations, stimulus_id, time())
+ (key, start, actual_finish, recommendations, stimulus_id, time())
)
if self.validate:
if stimulus_id == STIMULUS_ID_UNSET:
@@ -1591,8 +1598,8 @@ def _transition(
"Transitioned %r %s->%s (actual: %s). Consequence: %s",
key,
start,
- finish2,
- ts.state,
+ finish,
+ actual_finish,
dict(recommendations),
)
if self.plugins:
@@ -1603,7 +1610,7 @@ def _transition(
self.tasks[ts.key] = ts
for plugin in list(self.plugins.values()):
try:
- plugin.transition(key, start, finish2, *args, **kwargs)
+ plugin.transition(key, start, actual_finish, *args, **kwargs)
except Exception:
logger.info("Plugin failed with exception", exc_info=True)
if ts.state == "forgotten":
@@ -1707,11 +1714,8 @@ def transition_released_waiting(self, key, stimulus_id):
ts.waiters = {dts for dts in ts.dependents if dts.state == "waiting"}
if not ts.waiting_on:
- if self.workers:
- recommendations[key] = "processing"
- else:
- self.unrunnable.add(ts)
- ts.state = "no-worker"
+ # NOTE: waiting->processing will send tasks to queued or no-worker as necessary
+ recommendations[key] = "processing"
return recommendations, client_msgs, worker_msgs
except Exception as e:
@@ -1722,43 +1726,21 @@ def transition_released_waiting(self, key, stimulus_id):
pdb.set_trace()
raise
- def transition_no_worker_waiting(self, key, stimulus_id):
+ def transition_no_worker_processing(self, key, stimulus_id):
try:
ts: TaskState = self.tasks[key]
- dts: TaskState
- recommendations: dict = {}
+ recommendations: Recs = {}
client_msgs: dict = {}
worker_msgs: dict = {}
if self.validate:
+ assert not ts.actor, f"Actors can't be in `no-worker`: {ts}"
assert ts in self.unrunnable
- assert not ts.waiting_on
- assert not ts.who_has
- assert not ts.processing_on
-
- self.unrunnable.remove(ts)
-
- if ts.has_lost_dependencies:
- recommendations[key] = "forgotten"
- return recommendations, client_msgs, worker_msgs
- for dts in ts.dependencies:
- dep = dts.key
- if not dts.who_has:
- ts.waiting_on.add(dts)
- if dts.state == "released":
- recommendations[dep] = "waiting"
- else:
- dts.waiters.add(ts)
-
- ts.state = "waiting"
-
- if not ts.waiting_on:
- if self.workers:
- recommendations[key] = "processing"
- else:
- self.unrunnable.add(ts)
- ts.state = "no-worker"
+ if ws := self.decide_worker_non_rootish(ts):
+ self.unrunnable.discard(ts)
+ worker_msgs = _add_to_processing(self, ts, ws)
+ # If no worker, task just stays in `no-worker`
return recommendations, client_msgs, worker_msgs
except Exception as e:
@@ -1811,70 +1793,153 @@ def transition_no_worker_memory(
pdb.set_trace()
raise
- def decide_worker(self, ts: TaskState) -> WorkerState | None:
- """
- Decide on a worker for task *ts*. Return a WorkerState.
+ def decide_worker_rootish_queuing_disabled(
+ self, ts: TaskState
+ ) -> WorkerState | None:
+ """Pick a worker for a runnable root-ish task, without queuing.
- If it's a root or root-like task, we place it with its relatives to
- reduce future data tansfer.
+ This attempts to schedule sibling tasks on the same worker, reducing future data
+ transfer. It does not consider the location of dependencies, since they'll end
+ up on every worker anyway.
- If it has dependencies or restrictions, we use
- `decide_worker_from_deps_and_restrictions`.
+ It assumes it's being called on a batch of tasks in priority order, and
+ maintains state in `SchedulerState.last_root_worker` and
+ `SchedulerState.last_root_worker_tasks_left` to achieve this.
- Otherwise, we pick the least occupied worker, or pick from all workers
- in a round-robin fashion.
- """
- if not self.workers:
- return None
+ This will send every runnable task to a worker, often causing root task
+ overproduction.
- tg = ts.group
- valid_workers = self.valid_workers(ts)
+ Returns
+ -------
+ ws: WorkerState | None
+ The worker to assign the task to. If there are no workers in the cluster,
+ returns None, in which case the task should be transitioned to
+ ``no-worker``.
+ """
+ if self.validate:
+ # See root-ish-ness note below in `decide_worker_rootish_queuing_enabled`
+ assert math.isinf(self.WORKER_SATURATION)
- if (
- valid_workers is not None
- and not valid_workers
- and not ts.loose_restrictions
- ):
- self.unrunnable.add(ts)
- ts.state = "no-worker"
+ pool = self.idle.values() if self.idle else self.running
+ if not pool:
return None
- # Group is larger than cluster with few dependencies?
- # Minimize future data transfers.
- if (
- valid_workers is None
- and len(tg) > self.total_nthreads * 2
- and len(tg.dependencies) < 5
- and sum(map(len, tg.dependencies)) < 5
+ lws = self.last_root_worker
+ if not (
+ lws
+ and self.last_root_worker_tasks_left
+ and self.workers.get(lws.address) is lws
):
- ws = tg.last_worker
+ # Last-used worker is full or unknown; pick a new worker for the next few tasks
+ ws = self.last_root_worker = min(
+ pool, key=lambda ws: len(ws.processing) / ws.nthreads
+ )
+ # TODO better batching metric (`len(tg)` is not necessarily the total number of root tasks!)
+ self.last_root_worker_tasks_left = math.floor(
+ (len(ts.group) / self.total_nthreads) * ws.nthreads
+ )
+ else:
+ ws = lws
- if not (ws and tg.last_worker_tasks_left and ws.address in self.workers):
- # Last-used worker is full or unknown; pick a new worker for the next few tasks
- ws = min(
- (self.idle or self.workers).values(),
- key=partial(self.worker_objective, ts),
- )
- assert ws
- tg.last_worker_tasks_left = math.floor(
- (len(tg) / self.total_nthreads) * ws.nthreads
- )
+ self.last_root_worker_tasks_left -= 1
+
+ if self.validate and ws is not None:
+ assert self.workers.get(ws.address) is ws
+ assert ws in self.running, (ws, self.running)
+
+ return ws
+
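
The batch size computed above gives each chosen worker a thread-proportional share of the task group. A standalone sketch of that arithmetic (names are illustrative):

```python
import math

def tasks_for_worker(group_size: int, total_nthreads: int, worker_nthreads: int) -> int:
    # Thread-proportional share of the root task group for one worker.
    return math.floor((group_size / total_nthreads) * worker_nthreads)

# 100 root tasks on a 16-thread cluster: a 4-thread worker receives a batch
# of 25 tasks before a new least-busy worker is picked.
assert tasks_for_worker(100, 16, 4) == 25
```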
+ def decide_worker_rootish_queuing_enabled(self) -> WorkerState | None:
+ """Pick a worker for a runnable root-ish task, if not all are busy.
+
+        Picks the least-busy worker out of the ``idle`` workers (those processing
+        fewer tasks than the ``distributed.scheduler.worker-saturation`` threshold allows).
+ It does not consider the location of dependencies, since they'll end up on every
+ worker anyway.
+
+ If all workers are full, returns None, meaning the task should transition to
+ ``queued``. The scheduler will wait to send it to a worker until a thread opens
+ up. This ensures that downstream tasks always run before new root tasks are
+ started.
+
+ This does not try to schedule sibling tasks on the same worker; in fact, it
+ usually does the opposite. Even though this increases subsequent data transfer,
+ it typically reduces overall memory use by eliminating root task overproduction.
+
+ Returns
+ -------
+ ws: WorkerState | None
+ The worker to assign the task to. If there are no idle workers, returns
+ None, in which case the task should be transitioned to ``queued``.
+
+ """
+ if self.validate:
+ # We don't `assert self.is_rootish(ts)` here, because that check is dependent on
+ # cluster size. It's possible a task looked root-ish when it was queued, but the
+ # cluster has since scaled up and it no longer does when coming out of the queue.
+ # If `is_rootish` changes to a static definition, then add that assertion here
+ # (and actually pass in the task).
+ assert not math.isinf(self.WORKER_SATURATION)
+
+ if not self.idle:
+ # All workers busy? Task gets/stays queued.
+ return None
- # Record `last_worker`, or clear it on the final task
- tg.last_worker = (
- ws if tg.states["released"] + tg.states["waiting"] > 1 else None
+ # Just pick the least busy worker.
+ # NOTE: this will lead to worst-case scheduling with regards to co-assignment.
+ ws = min(self.idle.values(), key=lambda ws: len(ws.processing) / ws.nthreads)
+ if self.validate:
+ assert not _worker_full(ws, self.WORKER_SATURATION), (
+ ws,
+ _task_slots_available(ws, self.WORKER_SATURATION),
)
- tg.last_worker_tasks_left -= 1
- return ws
+ assert ws in self.running, (ws, self.running)
+
+ if self.validate and ws is not None:
+ assert self.workers.get(ws.address) is ws
+ assert ws in self.running, (ws, self.running)
+
+ return ws
+
+ def decide_worker_non_rootish(self, ts: TaskState) -> WorkerState | None:
+ """Pick a worker for a runnable non-root task, considering dependencies and restrictions.
+
+ Out of eligible workers holding dependencies of ``ts``, selects the worker
+        where, considering worker backlog and data-transfer costs, the task is
+ estimated to start running the soonest.
+
+ Returns
+ -------
+ ws: WorkerState | None
+ The worker to assign the task to. If no workers satisfy the restrictions of
+ ``ts`` or there are no running workers, returns None, in which case the task
+ should be transitioned to ``no-worker``.
+ """
+ if not self.running:
+ return None
+
+ valid_workers = self.valid_workers(ts)
+ if valid_workers is None and len(self.running) < len(self.workers):
+ if not self.running:
+ return None
+
+ # If there were no restrictions, `valid_workers()` didn't subset by `running`.
+ valid_workers = self.running
if ts.dependencies or valid_workers is not None:
ws = decide_worker(
ts,
- self.workers.values(),
+ self.running,
valid_workers,
partial(self.worker_objective, ts),
)
else:
+ # TODO if `is_rootish` would always return True for tasks without dependencies,
+ # we could remove all this logic. The rootish assignment logic would behave
+        # more or less the same as this, maybe without guaranteed round-robin though?
+ # This path is only reachable when `ts` doesn't have dependencies, but its
+ # group is also smaller than the cluster.
+
# Fastpath when there are no related tasks or restrictions
worker_pool = self.idle or self.workers
wp_vals = worker_pool.values()
@@ -1898,46 +1963,36 @@ def decide_worker(self, ts: TaskState) -> WorkerState | None:
ws = wp_vals[self.n_tasks % n_workers]
if self.validate and ws is not None:
- assert ws.address in self.workers
+ assert self.workers.get(ws.address) is ws
+ assert ws in self.running, (ws, self.running)
return ws
def transition_waiting_processing(self, key, stimulus_id):
+ """Possibly schedule a ready task. This is the primary dispatch for ready tasks.
+
+ If there's no appropriate worker for the task (but the task is otherwise runnable),
+ it will be recommended to ``no-worker`` or ``queued``.
+ """
try:
ts: TaskState = self.tasks[key]
- dts: TaskState
- recommendations: dict = {}
- client_msgs: dict = {}
- worker_msgs: dict = {}
- if self.validate:
- assert not ts.waiting_on
- assert not ts.who_has
- assert not ts.exception_blame
- assert not ts.processing_on
- assert not ts.has_lost_dependencies
- assert ts not in self.unrunnable
- assert all(dts.who_has for dts in ts.dependencies)
-
- ws = self.decide_worker(ts)
- if ws is None:
- return recommendations, client_msgs, worker_msgs
- worker = ws.address
-
- self._set_duration_estimate(ts, ws)
- ts.processing_on = ws
- ts.state = "processing"
- self.acquire_resources(ts, ws)
- self.check_idle_saturated(ws)
- self.n_tasks += 1
- if ts.actor:
- ws.actors.add(ts)
-
- # logger.debug("Send job to worker: %s, %s", worker, key)
-
- worker_msgs[worker] = [_task_to_msg(self, ts)]
+ if self.is_rootish(ts):
+ # NOTE: having two root-ish methods is temporary. When the feature flag is removed,
+ # there should only be one, which combines co-assignment and queuing.
+ # Eventually, special-casing root tasks might be removed entirely, with better heuristics.
+ if math.isinf(self.WORKER_SATURATION):
+ if not (ws := self.decide_worker_rootish_queuing_disabled(ts)):
+ return {ts.key: "no-worker"}, {}, {}
+ else:
+ if not (ws := self.decide_worker_rootish_queuing_enabled()):
+ return {ts.key: "queued"}, {}, {}
+ else:
+ if not (ws := self.decide_worker_non_rootish(ts)):
+ return {ts.key: "no-worker"}, {}, {}
- return recommendations, client_msgs, worker_msgs
+ worker_msgs = _add_to_processing(self, ts, ws)
+ return {}, {}, worker_msgs
except Exception as e:
logger.exception(e)
if LOG_PDB:
@@ -2070,7 +2125,10 @@ def transition_processing_memory(
if nbytes is not None:
ts.set_nbytes(nbytes)
- _remove_from_processing(self, ts)
+ # NOTE: recommendations for queued tasks are added first, so they'll be popped last,
+ # allowing higher-priority downstream tasks to be transitioned first.
+ # FIXME: this would be incorrect if queued tasks are user-annotated as higher priority.
+ _exit_processing_common(self, ts, recommendations)
_add_to_memory(
self, ts, ws, recommendations, client_msgs, type=type, typename=typename
@@ -2292,7 +2350,7 @@ def transition_waiting_released(self, key, stimulus_id):
def transition_processing_released(self, key: str, stimulus_id: str):
try:
ts = self.tasks[key]
- recommendations = {}
+ recommendations: Recs = {}
worker_msgs = {}
if self.validate:
@@ -2301,7 +2359,7 @@ def transition_processing_released(self, key: str, stimulus_id: str):
assert not ts.waiting_on
assert ts.state == "processing"
- ws = _remove_from_processing(self, ts)
+ ws = _exit_processing_common(self, ts, recommendations)
if ws:
worker_msgs[ws.address] = [
{
@@ -2311,24 +2369,7 @@ def transition_processing_released(self, key: str, stimulus_id: str):
}
]
- ts.state = "released"
-
- if ts.has_lost_dependencies:
- recommendations[key] = "forgotten"
- elif ts.waiters or ts.who_wants:
- recommendations[key] = "waiting"
-
- if recommendations.get(key) != "waiting":
- for dts in ts.dependencies:
- if dts.state != "released":
- dts.waiters.discard(ts)
- if not dts.waiters and not dts.who_wants:
- recommendations[dts.key] = "released"
- ts.waiters.clear()
-
- if self.validate:
- assert not ts.processing_on
-
+ _propagage_released(self, ts, recommendations)
return recommendations, {}, worker_msgs
except Exception as e:
logger.exception(e)
@@ -2395,7 +2436,7 @@ def transition_processing_erred(
ws = ts.processing_on
ws.actors.remove(ts)
- _remove_from_processing(self, ts)
+ _exit_processing_common(self, ts, recommendations)
ts.erred_on.add(worker)
if exception is not None:
@@ -2494,6 +2535,99 @@ def transition_no_worker_released(self, key, stimulus_id):
pdb.set_trace()
raise
+ def transition_waiting_queued(self, key, stimulus_id):
+ try:
+ ts: TaskState = self.tasks[key]
+ recommendations: Recs = {}
+ client_msgs: dict = {}
+ worker_msgs: dict = {}
+
+ if self.validate:
+ assert not self.idle, (ts, self.idle)
+ _validate_ready(self, ts)
+
+ ts.state = "queued"
+ self.queued.add(ts)
+
+ return recommendations, client_msgs, worker_msgs
+ except Exception as e:
+ logger.exception(e)
+ if LOG_PDB:
+ import pdb
+
+ pdb.set_trace()
+ raise
+
+ def transition_waiting_no_worker(self, key, stimulus_id):
+ try:
+ ts: TaskState = self.tasks[key]
+ recommendations: Recs = {}
+ client_msgs: dict = {}
+ worker_msgs: dict = {}
+
+ if self.validate:
+ _validate_ready(self, ts)
+
+ ts.state = "no-worker"
+ self.unrunnable.add(ts)
+
+ return recommendations, client_msgs, worker_msgs
+ except Exception as e:
+ logger.exception(e)
+ if LOG_PDB:
+ import pdb
+
+ pdb.set_trace()
+ raise
+
+ def transition_queued_released(self, key, stimulus_id):
+ try:
+ ts: TaskState = self.tasks[key]
+ recommendations: Recs = {}
+ client_msgs: dict = {}
+ worker_msgs: dict = {}
+
+ if self.validate:
+ assert ts in self.queued
+ assert not ts.processing_on
+
+ self.queued.remove(ts)
+
+ _propagage_released(self, ts, recommendations)
+ return recommendations, client_msgs, worker_msgs
+ except Exception as e:
+ logger.exception(e)
+ if LOG_PDB:
+ import pdb
+
+ pdb.set_trace()
+ raise
+
+ def transition_queued_processing(self, key, stimulus_id):
+ try:
+ ts: TaskState = self.tasks[key]
+ recommendations: Recs = {}
+ client_msgs: dict = {}
+ worker_msgs: dict = {}
+
+ if self.validate:
+ assert not ts.actor, f"Actors can't be queued: {ts}"
+ assert ts in self.queued
+
+ if ws := self.decide_worker_rootish_queuing_enabled():
+ self.queued.discard(ts)
+ worker_msgs = _add_to_processing(self, ts, ws)
+ # If no worker, task just stays `queued`
+
+ return recommendations, client_msgs, worker_msgs
+ except Exception as e:
+ logger.exception(e)
+ if LOG_PDB:
+ import pdb
+
+ pdb.set_trace()
+ raise
+
def _remove_key(self, key):
ts: TaskState = self.tasks.pop(key)
assert ts.state == "forgotten"
@@ -2559,6 +2693,7 @@ def transition_released_forgotten(self, key, stimulus_id):
assert ts.state in ("released", "erred")
assert not ts.who_has
assert not ts.processing_on
+ assert ts not in self.queued
assert not ts.waiting_on, (ts, ts.waiting_on)
if not ts.run_spec:
# It's ok to forget a pure data task
@@ -2601,12 +2736,16 @@ def transition_released_forgotten(self, key, stimulus_id):
("released", "waiting"): transition_released_waiting,
("waiting", "released"): transition_waiting_released,
("waiting", "processing"): transition_waiting_processing,
+ ("waiting", "no-worker"): transition_waiting_no_worker,
+ ("waiting", "queued"): transition_waiting_queued,
("waiting", "memory"): transition_waiting_memory,
+ ("queued", "released"): transition_queued_released,
+ ("queued", "processing"): transition_queued_processing,
("processing", "released"): transition_processing_released,
("processing", "memory"): transition_processing_memory,
("processing", "erred"): transition_processing_erred,
("no-worker", "released"): transition_no_worker_released,
- ("no-worker", "waiting"): transition_no_worker_waiting,
+ ("no-worker", "processing"): transition_no_worker_processing,
("no-worker", "memory"): transition_no_worker_memory,
("released", "forgotten"): transition_released_forgotten,
("memory", "forgotten"): transition_memory_forgotten,
@@ -2619,6 +2758,23 @@ def transition_released_forgotten(self, key, stimulus_id):
# Assigning Tasks to Workers #
##############################
+ def is_rootish(self, ts: TaskState) -> bool:
+ """
+ Whether ``ts`` is a root or root-like task.
+
+ Root-ish tasks are part of a group that's much larger than the cluster,
+ and have few or no dependencies.
+ """
+ if ts.resource_restrictions or ts.worker_restrictions or ts.host_restrictions:
+ return False
+ tg = ts.group
+ # TODO short-circuit to True if `not ts.dependencies`?
+ return (
+ len(tg) > self.total_nthreads * 2
+ and len(tg.dependencies) < 5
+ and sum(map(len, tg.dependencies)) < 5
+ )
+
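
Concretely, on a hypothetical 8-thread cluster, a 100-task group whose single dependency group holds 2 tasks passes the heuristic above (numbers are made up for illustration):

```python
# Hypothetical numbers plugged into the heuristic above.
total_nthreads = 8
len_tg, n_dep_groups, total_deps = 100, 1, 2
assert (
    len_tg > total_nthreads * 2      # group much larger than the cluster
    and n_dep_groups < 5             # few dependency groups...
    and total_deps < 5               # ...with few tasks between them
)
```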
def _set_duration_estimate(self, ts: TaskState, ws: WorkerState) -> None:
"""Estimate task duration using worker state and task state.
@@ -2659,6 +2815,15 @@ def check_idle_saturated(self, ws: WorkerState, occ: float = -1.0):
all of their threads, and if the expected runtime of those tasks is
large enough.
+ If ``distributed.scheduler.worker-saturation`` is not ``inf``
+ (scheduler-side queuing is enabled), they are considered idle
+ if they have fewer tasks processing than the ``worker-saturation``
+ threshold dictates.
+
+ Otherwise, they are considered idle if they have fewer tasks processing
+ than threads, or if their tasks' total expected runtime is less than half
+ the expected runtime of the same number of average tasks.
+
This is useful for load balancing and adaptivity.
"""
if self.total_nthreads == 0 or ws.status == Status.closed:
@@ -2672,8 +2837,13 @@ def check_idle_saturated(self, ws: WorkerState, occ: float = -1.0):
idle = self.idle
saturated = self.saturated
- if p < nc or occ < nc * avg / 2:
- idle[ws.address] = ws
+ if (
+ (p < nc or occ < nc * avg / 2)
+ if math.isinf(self.WORKER_SATURATION)
+ else not _worker_full(ws, self.WORKER_SATURATION)
+ ):
+ if ws.status == Status.running:
+ idle[ws.address] = ws
saturated.discard(ws)
else:
idle.pop(ws.address, None)
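
A hedged sketch of the idleness rule above (names assumed; long-running tasks are ignored here for brevity). With queuing enabled, idleness is a pure head-count against the saturation threshold; with it disabled, occupancy (expected runtime) also matters:

```python
import math

def is_idle(p: int, nc: int, occ: float, avg: float, saturation: float) -> bool:
    """p: tasks processing, nc: threads, occ: occupancy, avg: mean task runtime."""
    if math.isinf(saturation):                   # queuing disabled: original rule
        return p < nc or occ < nc * avg / 2
    return max(int(saturation * nc), 1) > p      # queuing enabled: slots left

assert is_idle(p=3, nc=2, occ=0.1, avg=0.5, saturation=math.inf)  # cheap tasks
assert not is_idle(p=3, nc=2, occ=0.1, avg=0.5, saturation=1.0)   # no free slot
```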
@@ -2730,7 +2900,10 @@ def get_task_duration(self, ts: TaskState) -> float:
def valid_workers(self, ts: TaskState) -> set[WorkerState] | None:
"""Return set of currently valid workers for key
- If all workers are valid then this returns ``None``.
+ If all workers are valid then this returns ``None``, in which case
+ any *running* worker can be used.
+ Otherwise, the subset of running workers valid for this task
+ is returned.
        This check tracks the following state:
* worker_restrictions
@@ -2779,10 +2952,7 @@ def valid_workers(self, ts: TaskState) -> set[WorkerState] | None:
else:
s &= ww
- if s is None:
- if len(self.running) < len(self.workers):
- return self.running.copy()
- else:
+ if s:
s = {self.workers[addr] for addr in s}
if len(self.running) < len(self.workers):
s &= self.running
@@ -2875,21 +3045,36 @@ def _reevaluate_occupancy_worker(self, ws: WorkerState):
for ts in ws.processing:
steal.recalculate_cost(ts)
- def bulk_schedule_after_adding_worker(self, ws: WorkerState):
- """Send tasks with ts.state=='no-worker' in bulk to a worker that just joined.
- Return recommendations. As the worker will start executing the new tasks
- immediately, without waiting for the batch to end, we can't rely on worker-side
- ordering, so the recommendations are sorted by priority order here.
+ def bulk_schedule_after_adding_worker(self, ws: WorkerState) -> Recs:
+ """Send ``queued`` or ``no-worker`` tasks to ``processing`` that this worker can handle.
+
+ Returns priority-ordered recommendations.
"""
- tasks = []
+ maybe_runnable: list[TaskState] = []
+ # Schedule any queued tasks onto the new worker
+ if not math.isinf(self.WORKER_SATURATION) and self.queued:
+ for qts in reversed(
+ list(
+ self.queued.peekn(_task_slots_available(ws, self.WORKER_SATURATION))
+ )
+ ):
+ if self.validate:
+ assert qts.state == "queued"
+ assert not qts.processing_on
+ assert not qts.waiting_on
+
+ maybe_runnable.append(qts)
+
+ # Schedule any restricted tasks onto the new worker, if the worker can run them
for ts in self.unrunnable:
valid = self.valid_workers(ts)
if valid is None or ws in valid:
- tasks.append(ts)
- # These recommendations will generate {"op": "compute-task"} messages
- # to the worker in reversed order
- tasks.sort(key=operator.attrgetter("priority"), reverse=True)
- return {ts.key: "waiting" for ts in tasks}
+ maybe_runnable.append(ts)
+
+ # Recommendations are processed LIFO, hence the reversed order
+ maybe_runnable.sort(key=operator.attrgetter("priority"), reverse=True)
+        # Note: not all of these will necessarily run; the transition to processing decides
+ return {ts.key: "processing" for ts in maybe_runnable}
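
A small sketch of why the sort is reversed: recommendations are applied LIFO off the dict, so sorting in descending priority order makes the highest-priority task (smallest priority tuple) transition, and be sent to the worker, first:

```python
# Toy recommendations dict; a smaller priority tuple means higher priority.
tasks = [("low", (2,)), ("high", (1,))]
recs = {
    key: "processing"
    for key, _ in sorted(tasks, key=lambda kv: kv[1], reverse=True)
}
# Insertion order is ["low", "high"]; popitem() (LIFO) yields "high" first.
assert list(recs) == ["low", "high"]
```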
class Scheduler(SchedulerState, ServerNode):
@@ -3129,6 +3314,7 @@ def __init__(
self._last_client = None
self._last_time = 0
unrunnable = set()
+ queued: HeapSet[TaskState] = HeapSet(key=operator.attrgetter("priority"))
self.datasets = {}
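
`queued` is a `HeapSet`: set membership combined with heap ordering on the key. A rough sketch of the semantics the scheduler relies on; the methods shown are exactly the ones this patch uses, and the ascending order of `peekn` is inferred from the `reversed(...)` call in `bulk_schedule_after_adding_worker` above:

```python
from operator import attrgetter
from distributed.collections import HeapSet

class Item:
    def __init__(self, name, priority):
        self.name, self.priority = name, priority

q = HeapSet(key=attrgetter("priority"))
q.add(Item("b", (1, 1)))
q.add(Item("a", (0, 0)))            # smaller tuple sorts first (higher priority)
assert q.peek().name == "a"         # inspect without removing
assert [i.name for i in q.peekn(2)] == ["a", "b"]
q.discard(q.peek())                 # set-like removal
assert len(q) == 1
```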
@@ -3263,6 +3449,7 @@ def __init__(
resources=resources,
tasks=tasks,
unrunnable=unrunnable,
+ queued=queued,
validate=validate,
plugins=plugins,
transition_counter_max=transition_counter_max,
@@ -3806,9 +3993,6 @@ async def add_worker(
self.stream_comms[address] = BatchedSend(interval="5ms", loop=self.loop)
- if ws.nthreads > len(ws.processing):
- self.idle[ws.address] = ws
-
for plugin in list(self.plugins.values()):
try:
result = plugin.add_worker(scheduler=self, worker=address)
@@ -4536,6 +4720,7 @@ def validate_released(self, key):
assert not ts.processing_on
assert not any([ts in dts.waiters for dts in ts.dependencies])
assert ts not in self.unrunnable
+ assert ts not in self.queued
def validate_waiting(self, key):
ts: TaskState = self.tasks[key]
@@ -4543,19 +4728,35 @@ def validate_waiting(self, key):
assert not ts.who_has
assert not ts.processing_on
assert ts not in self.unrunnable
+ assert ts not in self.queued
for dts in ts.dependencies:
# We are waiting on a dependency iff it's not stored
assert bool(dts.who_has) != (dts in ts.waiting_on)
assert ts in dts.waiters # XXX even if dts._who_has?
+ def validate_queued(self, key):
+ ts: TaskState = self.tasks[key]
+ dts: TaskState
+ assert ts in self.queued
+ assert not ts.waiting_on
+ assert not ts.who_has
+ assert not ts.processing_on
+ assert not (
+ ts.worker_restrictions or ts.host_restrictions or ts.resource_restrictions
+ )
+ for dts in ts.dependencies:
+ assert dts.who_has
+ assert ts in dts.waiters
+
def validate_processing(self, key):
ts: TaskState = self.tasks[key]
dts: TaskState
assert not ts.waiting_on
- ws: WorkerState = ts.processing_on
+ ws = ts.processing_on
assert ws
assert ts in ws.processing
assert not ts.who_has
+ assert ts not in self.queued
for dts in ts.dependencies:
assert dts.who_has
assert ts in dts.waiters
@@ -4568,9 +4769,10 @@ def validate_memory(self, key):
assert not ts.processing_on
assert not ts.waiting_on
assert ts not in self.unrunnable
+ assert ts not in self.queued
for dts in ts.dependents:
assert (dts in ts.waiters) == (
- dts.state in ("waiting", "processing", "no-worker")
+ dts.state in ("waiting", "queued", "processing", "no-worker")
)
assert ts not in dts.waiting_on
@@ -4581,6 +4783,7 @@ def validate_no_worker(self, key):
assert ts in self.unrunnable
assert not ts.processing_on
assert not ts.who_has
+ assert ts not in self.queued
for dts in ts.dependencies:
assert dts.who_has
@@ -4588,6 +4791,7 @@ def validate_erred(self, key):
ts: TaskState = self.tasks[key]
assert ts.exception_blame
assert not ts.who_has
+ assert ts not in self.queued
def validate_key(self, key, ts: TaskState | None = None):
try:
@@ -4619,13 +4823,21 @@ def validate_state(self, allow_overlap: bool = False) -> None:
if not (set(self.workers) == set(self.stream_comms)):
raise ValueError("Workers not the same in all collections")
+ assert self.running.issuperset(self.idle.values()), (
+ self.running,
+ list(self.idle.values()),
+ )
for w, ws in self.workers.items():
assert isinstance(w, str), (type(w), w)
assert isinstance(ws, WorkerState), (type(ws), ws)
assert ws.address == w
+ if ws.status != Status.running:
+ assert ws.address not in self.idle
+ assert ws.long_running.issubset(ws.processing)
if not ws.processing:
assert not ws.occupancy
- assert ws.address in self.idle
+ if ws.status == Status.running:
+ assert ws.address in self.idle
assert (ws.status == Status.running) == (ws in self.running)
for ws in self.running:
@@ -4892,13 +5104,13 @@ def handle_long_running(
self.check_idle_saturated(ws)
def handle_worker_status_change(
- self, status: str, worker: str, stimulus_id: str
+ self, status: str | Status, worker: str | WorkerState, stimulus_id: str
) -> None:
- ws = self.workers.get(worker)
+ ws = self.workers.get(worker) if isinstance(worker, str) else worker
if not ws:
return
prev_status = ws.status
- ws.status = Status.lookup[status] # type: ignore
+ ws.status = Status[status] if isinstance(status, str) else status
if ws.status == prev_status:
return
@@ -4907,12 +5119,14 @@ def handle_worker_status_change(
{
"action": "worker-status-change",
"prev-status": prev_status.name,
- "status": status,
+ "status": ws.status.name,
},
)
+ logger.debug(f"Worker status {prev_status.name} -> {status} - {ws}")
if ws.status == Status.running:
self.running.add(ws)
+ self.check_idle_saturated(ws)
recs = self.bulk_schedule_after_adding_worker(ws)
if recs:
client_msgs: dict = {}
@@ -4921,6 +5135,7 @@ def handle_worker_status_change(
self.send_all(client_msgs, worker_msgs)
else:
self.running.discard(ws)
+ self.idle.pop(ws.address, None)
async def handle_request_refresh_who_has(
self, keys: Iterable[str], worker: str, stimulus_id: str
@@ -6235,8 +6450,9 @@ async def retire_workers(
# Change Worker.status to closing_gracefully. Immediately set
# the same on the scheduler to prevent race conditions.
prev_status = ws.status
- ws.status = Status.closing_gracefully
- self.running.discard(ws)
+ self.handle_worker_status_change(
+ Status.closing_gracefully, ws, stimulus_id
+ )
# FIXME: We should send a message to the nanny first;
# eventually workers won't be able to close their own nannies.
self.stream_comms[ws.address].send(
@@ -7272,7 +7488,11 @@ def check_idle(self):
self.idle_since = None
return
- if any([ws.processing for ws in self.workers.values()]) or self.unrunnable:
+ if (
+ self.queued
+ or self.unrunnable
+ or any([ws.processing for ws in self.workers.values()])
+ ):
self.idle_since = None
return
@@ -7308,21 +7528,24 @@ def adaptive_target(self, target_duration=None):
target_duration = parse_timedelta(target_duration)
# CPU
+
+ # TODO consider any user-specified default task durations for queued tasks
+ queued_occupancy = len(self.queued) * self.UNKNOWN_TASK_DURATION
cpu = math.ceil(
- self.total_occupancy / target_duration
+ (self.total_occupancy + queued_occupancy) / target_duration
) # TODO: threads per worker
# Avoid a few long tasks from asking for many cores
- tasks_processing = 0
+ tasks_ready = len(self.queued)
for ws in self.workers.values():
- tasks_processing += len(ws.processing)
+ tasks_ready += len(ws.processing)
- if tasks_processing > cpu:
+ if tasks_ready > cpu:
break
else:
- cpu = min(tasks_processing, cpu)
+ cpu = min(tasks_ready, cpu)
- if self.unrunnable and not self.workers:
+ if (self.unrunnable or self.queued) and not self.workers:
cpu = max(1, cpu)
# add more workers if more than 60% of memory is used
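
A standalone sketch of the CPU-target arithmetic above, with queued tasks contributing demand at the unknown-task duration estimate (500ms by default; names are illustrative):

```python
import math

def cpu_target(total_occupancy: float, n_queued: int,
               unknown_task_duration: float, target_duration: float) -> int:
    queued_occupancy = n_queued * unknown_task_duration
    return math.ceil((total_occupancy + queued_occupancy) / target_duration)

# 30s of measured work plus 100 queued tasks at the 0.5s default estimate,
# aiming to finish in 5s => ask for 16 cores.
assert cpu_target(30.0, 100, 0.5, 5.0) == 16
```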
@@ -7397,7 +7620,44 @@ def request_remove_replicas(
)
-def _remove_from_processing(state: SchedulerState, ts: TaskState) -> WorkerState | None:
+def _validate_ready(state: SchedulerState, ts: TaskState):
+ "Validation for ready states (processing, queued, no-worker)"
+ assert not ts.waiting_on
+ assert not ts.who_has
+ assert not ts.exception_blame
+ assert not ts.processing_on
+ assert not ts.has_lost_dependencies
+ assert ts not in state.unrunnable
+ assert ts not in state.queued
+ assert all(dts.who_has for dts in ts.dependencies)
+
+
+def _add_to_processing(
+ state: SchedulerState, ts: TaskState, ws: WorkerState
+) -> dict[str, list]:
+ "Set a task as processing on a worker, and return the worker messages to send."
+ if state.validate:
+ _validate_ready(state, ts)
+ assert ts not in ws.processing
+ assert ws in state.running, state.running
+ assert (o := state.workers.get(ws.address)) is ws, (ws, o)
+
+ state._set_duration_estimate(ts, ws)
+ ts.processing_on = ws
+ ts.state = "processing"
+ state.acquire_resources(ts, ws)
+ state.check_idle_saturated(ws)
+ state.n_tasks += 1
+
+ if ts.actor:
+ ws.actors.add(ts)
+
+ return {ws.address: [_task_to_msg(state, ts)]}
+
+
+def _exit_processing_common(
+ state: SchedulerState, ts: TaskState, recommendations: Recs
+) -> WorkerState | None:
"""Remove *ts* from the set of processing tasks.
Returns
@@ -7428,6 +7688,18 @@ def _remove_from_processing(state: SchedulerState, ts: TaskState) -> WorkerState
state.check_idle_saturated(ws)
state.release_resources(ts, ws)
+ # If a slot has opened up for a queued task, schedule it.
+ if state.queued and not _worker_full(ws, state.WORKER_SATURATION):
+ qts = state.queued.peek()
+ if state.validate:
+ assert qts.state == "queued", qts.state
+ assert qts.key not in recommendations, recommendations[qts.key]
+
+ # NOTE: we don't need to schedule more than one task at once here. Since this is
+ # called each time 1 task completes, multiple tasks must complete for multiple
+ # slots to open up.
+ recommendations[qts.key] = "processing"
+
return ws
@@ -7489,6 +7761,32 @@ def _add_to_memory(
)
+def _propagage_released(
+ state: SchedulerState,
+ ts: TaskState,
+ recommendations: Recs,
+) -> None:
+ ts.state = "released"
+ key = ts.key
+
+ if ts.has_lost_dependencies:
+ recommendations[key] = "forgotten"
+ elif ts.waiters or ts.who_wants:
+ recommendations[key] = "waiting"
+
+ if recommendations.get(key) != "waiting":
+ for dts in ts.dependencies:
+ if dts.state != "released":
+ dts.waiters.discard(ts)
+ if not dts.waiters and not dts.who_wants:
+ recommendations[dts.key] = "released"
+ ts.waiters.clear()
+
+ if state.validate:
+ assert not ts.processing_on
+ assert ts not in state.queued
+
+
def _propagate_forgotten(
state: SchedulerState,
ts: TaskState,
@@ -7686,7 +7984,7 @@ def validate_task_state(ts: TaskState) -> None:
str(dts),
str(dts.dependents),
)
- if ts.state in ("waiting", "processing", "no-worker"):
+ if ts.state in ("waiting", "queued", "processing", "no-worker"):
assert dts in ts.waiting_on or dts.who_has, (
"dep missing",
str(ts),
@@ -7695,7 +7993,7 @@ def validate_task_state(ts: TaskState) -> None:
assert dts.state != "forgotten"
for dts in ts.waiters:
- assert dts.state in ("waiting", "processing", "no-worker"), (
+ assert dts.state in ("waiting", "queued", "processing", "no-worker"), (
"waiter not in play",
str(ts),
str(dts),
@@ -7712,6 +8010,15 @@ def validate_task_state(ts: TaskState) -> None:
assert (ts.processing_on is not None) == (ts.state == "processing")
assert bool(ts.who_has) == (ts.state == "memory"), (ts, ts.who_has, ts.state)
+ if ts.state == "queued":
+ assert not ts.processing_on
+ assert not ts.who_has
+ assert all(dts.who_has for dts in ts.dependencies), (
+ "task queued without all deps",
+ str(ts),
+ str(ts.dependencies),
+ )
+
if ts.state == "processing":
assert all(dts.who_has for dts in ts.dependencies), (
"task processing without all deps",
@@ -7752,6 +8059,7 @@ def validate_task_state(ts: TaskState) -> None:
if ts.state == "processing":
assert ts.processing_on
assert ts in ts.processing_on.actors
+ assert ts.state != "queued"
def validate_worker_state(ws: WorkerState) -> None:
@@ -7806,6 +8114,21 @@ def heartbeat_interval(n: int) -> float:
return n / 200 + 1
+def _task_slots_available(ws: WorkerState, saturation_factor: float) -> int:
+ "Number of tasks that can be sent to this worker without oversaturating it"
+ assert not math.isinf(saturation_factor)
+ nthreads = ws.nthreads
+ return max(int(saturation_factor * nthreads), 1) - (
+ len(ws.processing) - len(ws.long_running)
+ )
+
+
+def _worker_full(ws: WorkerState, saturation_factor: float) -> bool:
+ if math.isinf(saturation_factor):
+ return False
+ return _task_slots_available(ws, saturation_factor) <= 0
+
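
A usage sketch for these helpers: long-running (seceded) tasks don't occupy a slot, so a worker can look unsaturated even with `len(processing) >= nthreads`. The fake worker below is illustrative only:

```python
class FakeWorkerState:
    nthreads = 2
    processing = {"task-a", "task-b"}   # two tasks assigned...
    long_running = {"task-b"}           # ...but one has seceded

ws = FakeWorkerState()
assert _task_slots_available(ws, 1.0) == 1   # 2 slots - (2 - 1) occupied
assert not _worker_full(ws, 1.0)
assert not _worker_full(ws, float("inf"))    # inf disables the limit entirely
```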
+
class KilledWorker(Exception):
def __init__(self, task: str, last_worker: WorkerState):
super().__init__(task, last_worker)
diff --git a/distributed/tests/test_active_memory_manager.py b/distributed/tests/test_active_memory_manager.py
index e99297a2a0..4c3c513adc 100644
--- a/distributed/tests/test_active_memory_manager.py
+++ b/distributed/tests/test_active_memory_manager.py
@@ -782,6 +782,8 @@ async def test_RetireWorker_no_remove(c, s, a, b):
while s.tasks["x"].who_has != {s.workers[b.address]}:
await asyncio.sleep(0.01)
assert a.address in s.workers
+ assert a.status == Status.closing_gracefully
+ assert s.workers[a.address].status == Status.closing_gracefully
# Policy has been removed without waiting for worker to disappear from
# Scheduler.workers
assert not s.extensions["amm"].policies
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index 05850924f8..f83e5f45bb 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -3,6 +3,7 @@
import asyncio
import json
import logging
+import math
import operator
import pickle
import re
@@ -10,7 +11,7 @@
from itertools import product
from textwrap import dedent
from time import sleep
-from typing import Collection
+from typing import ClassVar, Collection
from unittest import mock
import cloudpickle
@@ -43,6 +44,7 @@
from distributed.utils import TimeoutError
from distributed.utils_test import (
BrokenComm,
+ async_wait_for,
captured_logger,
cluster,
dec,
@@ -54,6 +56,7 @@
raises_with_cause,
slowadd,
slowdec,
+ slowidentity,
slowinc,
tls_only_security,
varying,
@@ -145,7 +148,9 @@ def test_decide_worker_coschedule_order_neighbors(ndeps, nthreads):
@gen_cluster(
client=True,
nthreads=nthreads,
- config={"distributed.scheduler.work-stealing": False},
+ config={
+ "distributed.scheduler.work-stealing": False,
+ },
)
async def test_decide_worker_coschedule_order_neighbors_(c, s, *workers):
r"""
@@ -245,6 +250,223 @@ def random(**kwargs):
test_decide_worker_coschedule_order_neighbors_()
+@pytest.mark.parametrize("ngroups", [1, 2, 3, 5])
+@gen_cluster(
+ client=True,
+ nthreads=[("", 1), ("", 1)],
+)
+async def test_decide_worker_coschedule_order_binary_op(c, s, a, b, ngroups):
+ roots = [[delayed(i, name=f"x-{n}-{i}") for i in range(8)] for n in range(ngroups)]
+ zs = [sum(rs) for rs in zip(*roots)]
+
+ await c.gather(c.compute(zs))
+
+ assert not a.transfer_incoming_log, [l["keys"] for l in a.transfer_incoming_log]
+ assert not b.transfer_incoming_log, [l["keys"] for l in b.transfer_incoming_log]
+
+
+@pytest.mark.slow
+@gen_cluster(
+ nthreads=[("", 2)] * 4,
+ client=True,
+ config={"distributed.scheduler.worker-saturation": 1.0},
+)
+async def test_graph_execution_width(c, s, *workers):
+ """
+ Test that we don't execute the graph more breadth-first than necessary.
+
+ We shouldn't start loading extra data if we're not going to use it immediately.
+ The number of parallel work streams match the number of threads.
+ """
+
+ class Refcount:
+ "Track how many instances of this class exist; logs the count at creation and deletion"
+
+ count: ClassVar[int] = 0
+ lock: ClassVar[dask.utils.SerializableLock] = dask.utils.SerializableLock()
+ log: ClassVar[list[int]] = []
+
+ def __init__(self) -> None:
+ with self.lock:
+ type(self).count += 1
+ self.log.append(self.count)
+
+ def __del__(self):
+ with self.lock:
+ self.log.append(self.count)
+ type(self).count -= 1
+
+ roots = [delayed(Refcount)() for _ in range(32)]
+ passthrough1 = [delayed(slowidentity)(r, delay=0) for r in roots]
+ passthrough2 = [delayed(slowidentity)(r, delay=0) for r in passthrough1]
+ done = [delayed(lambda r: None)(r) for r in passthrough2]
+
+ fs = c.compute(done)
+ await wait(fs)
+ # NOTE: the max should normally equal `total_nthreads`. But some macOS CI machines
+ # are slow enough that they aren't able to reach the full parallelism of 8 threads.
+ assert max(Refcount.log) <= s.total_nthreads
+
+
+@pytest.mark.parametrize("queue", [True, False])
+@gen_cluster(
+ client=True,
+ nthreads=[("", 2)] * 2,
+ config={
+ "distributed.worker.memory.pause": False,
+ "distributed.worker.memory.target": False,
+ "distributed.worker.memory.spill": False,
+ "distributed.scheduler.work-stealing": False,
+ },
+)
+async def test_queued_paused_new_worker(c, s, a, b, queue):
+ if queue:
+ s.WORKER_SATURATION = 1.0
+ else:
+ s.WORKER_SATURATION = float("inf")
+
+ f1s = c.map(slowinc, range(16))
+ f2s = c.map(slowinc, f1s)
+ final = c.submit(sum, *f2s)
+ del f1s, f2s
+
+ while not a.data or not b.data:
+ await asyncio.sleep(0.01)
+
+ # manually pause the workers
+ a.status = Status.paused
+ b.status = Status.paused
+
+ while s.running:
+ # wait for workers pausing to hit the scheduler
+ await asyncio.sleep(0.01)
+
+ assert not s.idle
+ assert not s.running
+
+ async with Worker(s.address, nthreads=2) as w:
+ # Tasks are successfully scheduled onto a new worker
+ while not w.state.data:
+ await asyncio.sleep(0.01)
+
+ del final
+ while s.tasks:
+ await asyncio.sleep(0.01)
+ assert not s.queued
+
+
+@pytest.mark.parametrize("queue", [True, False])
+@gen_cluster(
+ client=True,
+ nthreads=[("", 2)] * 2,
+ config={
+ "distributed.worker.memory.pause": False,
+ "distributed.worker.memory.target": False,
+ "distributed.worker.memory.spill": False,
+ "distributed.scheduler.work-stealing": False,
+ },
+)
+async def test_queued_paused_unpaused(c, s, a, b, queue):
+ if queue:
+ s.WORKER_SATURATION = 1.0
+ else:
+ s.WORKER_SATURATION = float("inf")
+
+ f1s = c.map(slowinc, range(16))
+ f2s = c.map(slowinc, f1s)
+ final = c.submit(sum, *f2s)
+ del f1s, f2s
+
+ while not a.data or not b.data:
+ await asyncio.sleep(0.01)
+
+ # manually pause the workers
+ a.status = Status.paused
+ b.status = Status.paused
+
+ while s.running:
+ # wait for workers pausing to hit the scheduler
+ await asyncio.sleep(0.01)
+
+ assert not s.running
+ assert not s.idle
+
+ # un-pause
+ a.status = Status.running
+ b.status = Status.running
+ while not s.running:
+ await asyncio.sleep(0.01)
+
+ assert not s.idle # workers should have been (or already were) filled
+
+ await wait(final)
+
+
+@gen_cluster(
+ client=True,
+ nthreads=[("", 2)] * 2,
+ config={"distributed.scheduler.worker-saturation": 1.0},
+)
+async def test_queued_remove_add_worker(c, s, a, b):
+ event = Event()
+ fs = c.map(lambda i: event.wait(), range(10))
+
+ await async_wait_for(lambda: len(s.queued) == 6, timeout=5)
+ await s.remove_worker(a.address, stimulus_id="fake")
+ assert len(s.queued) == 8
+
+ # Add a new worker
+ async with Worker(s.address, nthreads=2) as w:
+ await async_wait_for(lambda: len(s.queued) == 6, timeout=5)
+
+ await event.set()
+ await wait(fs)
+
+
+@pytest.mark.parametrize(
+ "saturation, expected_task_counts",
+ [
+ (2.5, (5, 2)),
+ (2.0, (4, 2)),
+ (1.0, (2, 1)),
+ (-1.0, (1, 1)),
+ (float("inf"), (6, 4))
+        # ^ depends on root task assignment logic; ok if it changes, the counts just need to add up to 10
+ ],
+)
+def test_saturation_factor(
+ saturation: int | float, expected_task_counts: tuple[int, int]
+) -> None:
+ @gen_cluster(
+ client=True,
+ nthreads=[("", 2), ("", 1)],
+ config={
+ "distributed.scheduler.worker-saturation": saturation,
+ },
+ )
+ async def _test_saturation_factor(c, s, a, b):
+ event = Event()
+ fs = c.map(
+ lambda _: event.wait(), range(10), key=[f"wait-{i}" for i in range(10)]
+ )
+ while a.state.executing_count < min(
+ a.state.nthreads, expected_task_counts[0]
+ ) or b.state.executing_count < min(b.state.nthreads, expected_task_counts[1]):
+ await asyncio.sleep(0.01)
+
+ if math.isfinite(saturation):
+ assert len(a.state.tasks) == expected_task_counts[0]
+ assert len(b.state.tasks) == expected_task_counts[1]
+ else:
+ # Assignment is nondeterministic for some reason without queuing
+ assert len(a.state.tasks) > len(b.state.tasks)
+
+ await event.set()
+ await c.gather(fs)
+
+ _test_saturation_factor()
+
+
@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3)
async def test_move_data_over_break_restrictions(client, s, a, b, c):
[x] = await client.scatter([1], workers=b.address)
@@ -284,17 +506,32 @@ async def test_no_valid_workers_loose_restrictions(client, s, a, b, c):
assert result == 2
+@pytest.mark.parametrize("queue", [False, True])
@gen_cluster(client=True, nthreads=[])
-async def test_no_workers(client, s):
+async def test_no_workers(client, s, queue):
+ if queue:
+ s.WORKER_SATURATION = 1.0
+ else:
+ s.WORKER_SATURATION = float("inf")
+
x = client.submit(inc, 1)
while not s.tasks:
await asyncio.sleep(0.01)
- assert s.tasks[x.key] in s.unrunnable
+ ts = s.tasks[x.key]
+ if queue:
+ assert ts in s.queued
+ assert ts.state == "queued"
+ else:
+ assert ts in s.unrunnable
+ assert ts.state == "no-worker"
with pytest.raises(TimeoutError):
await asyncio.wait_for(x, 0.05)
+ async with Worker(s.address, nthreads=1):
+ await wait(x)
+
@gen_cluster(nthreads=[])
async def test_retire_workers_empty(s):
@@ -321,7 +558,7 @@ async def test_remove_worker_from_scheduler(s, a, b):
assert a.address in s.stream_comms
await s.remove_worker(address=a.address, stimulus_id="test")
assert a.address not in s.workers
- assert len(s.workers[b.address].processing) == len(dsk) # b owns everything
+ assert len(s.workers[b.address].processing) + len(s.queued) == len(dsk)
@gen_cluster()
@@ -605,12 +842,31 @@ async def test_ready_remove_worker(s, a, b):
dependencies={"x-%d" % i: [] for i in range(20)},
)
- assert all(len(w.processing) > w.nthreads for w in s.workers.values())
+ if s.WORKER_SATURATION == 1:
+ cmp = operator.eq
+ elif math.isinf(s.WORKER_SATURATION):
+ cmp = operator.gt
+ else:
+ pytest.fail(f"{s.WORKER_OVERSATURATION=}, must be 1 or inf")
+
+ assert all(cmp(len(w.processing), w.nthreads) for w in s.workers.values()), (
+ list(s.workers.values()),
+ s.WORKER_SATURATION,
+ )
+ assert sum(len(w.processing) for w in s.workers.values()) + len(s.queued) == len(
+ s.tasks
+ )
await s.remove_worker(address=a.address, stimulus_id="test")
assert set(s.workers) == {b.address}
- assert all(len(w.processing) > w.nthreads for w in s.workers.values())
+ assert all(cmp(len(w.processing), w.nthreads) for w in s.workers.values()), (
+ list(s.workers.values()),
+ s.WORKER_SATURATION,
+ )
+ assert sum(len(w.processing) for w in s.workers.values()) + len(s.queued) == len(
+ s.tasks
+ )
@gen_cluster(client=True, Worker=Nanny, timeout=60)
@@ -1180,9 +1436,13 @@ async def test_learn_occupancy(c, s, a, b):
while sum(len(ts.who_has) for ts in s.tasks.values()) < 10:
await asyncio.sleep(0.01)
- assert 100 < s.total_occupancy < 1000
+ nproc = sum(ts.state == "processing" for ts in s.tasks.values())
+ assert nproc * 0.1 < s.total_occupancy < nproc * 0.4
for w in [a, b]:
- assert 50 < s.workers[w.address].occupancy < 700
+ ws = s.workers[w.address]
+ occ = ws.occupancy
+ proc = len(ws.processing)
+ assert proc * 0.1 < occ < proc * 0.4
@pytest.mark.slow
@@ -1193,7 +1453,8 @@ async def test_learn_occupancy_2(c, s, a, b):
while not any(ts.who_has for ts in s.tasks.values()):
await asyncio.sleep(0.01)
- assert 100 < s.total_occupancy < 1000
+ nproc = sum(ts.state == "processing" for ts in s.tasks.values())
+ assert nproc * 0.1 < s.total_occupancy < nproc * 0.4
@gen_cluster(client=True)
@@ -2011,6 +2272,31 @@ async def test_adaptive_target(c, s, a, b):
assert s.adaptive_target(target_duration=".1s") == 0
+@pytest.mark.parametrize("queue", [True, False])
+@gen_cluster(
+ client=True,
+ nthreads=[],
+)
+async def test_adaptive_target_empty_cluster(c, s, queue):
+ if queue:
+ s.WORKER_SATURATION = 1.0
+ else:
+ s.WORKER_SATURATION = float("inf")
+
+ assert s.adaptive_target() == 0
+
+ f = c.submit(inc, -1)
+ await async_wait_for(lambda: s.tasks, timeout=5)
+ assert s.adaptive_target() == 1
+ del f
+
+ if queue:
+ # only queuing supports fast scale-up for empty clusters https://github.com/dask/distributed/issues/6962
+ fs = c.map(inc, range(100))
+ await async_wait_for(lambda: len(s.tasks) == len(fs), timeout=5)
+ assert s.adaptive_target() > 1
+
+
@gen_test()
async def test_async_context_manager():
async with Scheduler(dashboard_address=":0") as s:
diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index 0103808cf5..1fe036a371 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -3116,6 +3116,11 @@ async def test_worker_status_sync(s, a):
"prev-status": "paused",
"status": "running",
},
+ {
+ "action": "worker-status-change",
+ "prev-status": "running",
+ "status": "closing_gracefully",
+ },
{"action": "remove-worker", "processing-tasks": {}},
{"action": "retired"},
]
diff --git a/docs/source/images/task-state.dot b/docs/source/images/task-state.dot
index fde7d8c62c..f68008c6a2 100644
--- a/docs/source/images/task-state.dot
+++ b/docs/source/images/task-state.dot
@@ -7,7 +7,10 @@ digraph{
released2 [label=released];
released1 -> waiting;
waiting -> processing;
- waiting -> "no-worker" [dir=both];
+ waiting -> "no-worker";
+ waiting -> queued;
+ "no-worker" -> processing;
+ queued -> processing;
processing -> memory;
processing -> error;
error -> released2;
diff --git a/docs/source/images/task-state.svg b/docs/source/images/task-state.svg
index 7b7d5c84ad..b8d23318be 100644
--- a/docs/source/images/task-state.svg
+++ b/docs/source/images/task-state.svg
@@ -1,109 +1,132 @@
-
-