diff --git a/distributed/active_memory_manager.py b/distributed/active_memory_manager.py new file mode 100644 index 0000000000..4ed2daf411 --- /dev/null +++ b/distributed/active_memory_manager.py @@ -0,0 +1,308 @@ +from __future__ import annotations + +import asyncio +from collections import defaultdict +from collections.abc import Generator +from typing import TYPE_CHECKING, Optional + +from tornado.ioloop import PeriodicCallback + +import dask +from dask.utils import parse_timedelta + +from .utils import import_term + +if TYPE_CHECKING: + from .scheduler import SchedulerState, TaskState, WorkerState + + +class ActiveMemoryManagerExtension: + """Scheduler extension that optimizes memory usage across the cluster. + It can be either triggered by hand or automatically every few seconds; at every + iteration it performs one or both of the following: + + - create new replicas of in-memory tasks + - destroy replicas of in-memory tasks; this never destroys the last available copy. + + There are no 'move' operations. A move is performed in two passes: first you create + a copy and, in the next iteration, you delete the original (if the copy succeeded). + + This extension is configured by the dask config section + ``distributed.scheduler.active-memory-manager``. + """ + + scheduler: SchedulerState + policies: set[ActiveMemoryManagerPolicy] + interval: float + + # These attributes only exist within the scope of self.run() + # Current memory (in bytes) allocated on each worker, plus/minus pending actions + workers_memory: dict[WorkerState, int] + # Pending replications and deletions for each task + pending: defaultdict[TaskState, tuple[set[WorkerState], set[WorkerState]]] + + def __init__( + self, + scheduler: SchedulerState, + # The following parameters are exposed so that one may create, run, and throw + # away on the fly a specialized manager, separate from the main one. 
+ policies: Optional[set[ActiveMemoryManagerPolicy]] = None, + register: bool = True, + start: Optional[bool] = None, + interval: Optional[float] = None, + ): + self.scheduler = scheduler + + if policies is None: + policies = set() + for kwargs in dask.config.get( + "distributed.scheduler.active-memory-manager.policies" + ): + kwargs = kwargs.copy() + cls = import_term(kwargs.pop("class")) + if not issubclass(cls, ActiveMemoryManagerPolicy): + raise TypeError( + f"{cls}: Expected ActiveMemoryManagerPolicy; got {type(cls)}" + ) + policies.add(cls(**kwargs)) + + for policy in policies: + policy.manager = self + self.policies = policies + + if register: + scheduler.extensions["amm"] = self + scheduler.handlers.update( + { + "amm_run_once": self.run_once, + "amm_start": self.start, + "amm_stop": self.stop, + } + ) + + if interval is None: + interval = parse_timedelta( + dask.config.get("distributed.scheduler.active-memory-manager.interval") + ) + self.interval = interval + if start is None: + start = dask.config.get("distributed.scheduler.active-memory-manager.start") + if start: + self.start() + + def start(self, comm=None) -> None: + """Start executing every ``self.interval`` seconds until scheduler shutdown""" + pc = PeriodicCallback(self.run_once, self.interval * 1000.0) + self.scheduler.periodic_callbacks["amm"] = pc + pc.start() + + def stop(self, comm=None) -> None: + """Stop periodic execution""" + pc = self.scheduler.periodic_callbacks.pop("amm", None) + if pc: + pc.stop() + + def run_once(self, comm=None) -> None: + """Run all policies once and asynchronously (fire and forget) enact their + recommendations to replicate/drop keys + """ + # This should never fail since this is a synchronous method + assert not hasattr(self, "pending") + + self.pending = defaultdict(lambda: (set(), set())) + self.workers_memory = { + w: w.memory.optimistic for w in self.scheduler.workers.values() + } + try: + # populate self.pending + self._run_policies() + + drop_by_worker = defaultdict(set) + repl_by_worker = defaultdict(dict) + for ts, (pending_repl, pending_drop) in self.pending.items(): + if not ts.who_has: + continue + who_has = [ws_snd.address for ws_snd in ts.who_has - pending_drop] + assert who_has # Never drop the last replica + for ws_rec in pending_repl: + assert ws_rec not in ts.who_has + repl_by_worker[ws_rec.address][ts.key] = who_has + for ws in pending_drop: + assert ws in ts.who_has + drop_by_worker[ws.address].add(ts.key) + + # Fire-and-forget enact recommendations from policies + # This is temporary code, waiting for + # https://github.com/dask/distributed/pull/5046 + for addr, who_has in repl_by_worker.items(): + asyncio.create_task(self.scheduler.gather_on_worker(addr, who_has)) + for addr, keys in drop_by_worker.items(): + asyncio.create_task(self.scheduler.delete_worker_data(addr, keys)) + # End temporary code + + finally: + del self.workers_memory + del self.pending + + def _run_policies(self) -> None: + """Sequentially run ActiveMemoryManagerPolicy.run() for all registered policies, + obtain replicate/drop suggestions, and use them to populate self.pending. 
+        """
+        candidates: Optional[set[WorkerState]]
+        cmd: str
+        ws: Optional[WorkerState]
+        ts: TaskState
+        nreplicas: int
+
+        for policy in list(self.policies):  # a policy may remove itself
+            policy_gen = policy.run()
+            ws = None
+            while True:
+                try:
+                    cmd, ts, candidates = policy_gen.send(ws)
+                except StopIteration:
+                    break  # next policy
+
+                pending_repl, pending_drop = self.pending[ts]
+
+                if cmd == "replicate":
+                    ws = self._find_recipient(ts, candidates, pending_repl)
+                    if ws:
+                        pending_repl.add(ws)
+                        self.workers_memory[ws] += ts.nbytes
+
+                elif cmd == "drop":
+                    ws = self._find_dropper(ts, candidates, pending_drop)
+                    if ws:
+                        pending_drop.add(ws)
+                        self.workers_memory[ws] = max(
+                            0, self.workers_memory[ws] - ts.nbytes
+                        )
+
+                else:
+                    raise ValueError(f"Unknown command: {cmd}")  # pragma: nocover
+
+    def _find_recipient(
+        self,
+        ts: TaskState,
+        candidates: Optional[set[WorkerState]],
+        pending_repl: set[WorkerState],
+    ) -> Optional[WorkerState]:
+        """Choose a worker to acquire a new replica of an in-memory task among a set of
+        candidates. If candidates is None, default to all workers in the cluster that do
+        not hold a replica yet. The worker with the lowest memory usage (downstream of
+        pending replications and drops) will be returned.
+        """
+        if ts.state != "memory":
+            return None
+        if candidates is None:
+            candidates = set(self.scheduler.workers.values())
+        candidates -= ts.who_has
+        candidates -= pending_repl
+        if not candidates:
+            return None
+        return min(candidates, key=self.workers_memory.get)
+
+    def _find_dropper(
+        self,
+        ts: TaskState,
+        candidates: Optional[set[WorkerState]],
+        pending_drop: set[WorkerState],
+    ) -> Optional[WorkerState]:
+        """Choose a worker to drop its replica of an in-memory task among a set of
+        candidates. If candidates is None, default to all workers in the cluster that
+        hold a replica. The worker with the highest memory usage (downstream of pending
+        replications and drops) will be returned.
+        """
+        if len(ts.who_has) - len(pending_drop) < 2:
+            return None
+        if candidates is None:
+            candidates = ts.who_has.copy()
+        else:
+            candidates &= ts.who_has
+        candidates -= pending_drop
+        candidates -= {waiter_ts.processing_on for waiter_ts in ts.waiters}
+        if not candidates:
+            return None
+        return max(candidates, key=self.workers_memory.get)
+
+
+class ActiveMemoryManagerPolicy:
+    """Abstract parent class"""
+
+    manager: ActiveMemoryManagerExtension
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}()"
+
+    def run(
+        self,
+    ) -> Generator[
+        tuple[str, TaskState, Optional[set[WorkerState]]],
+        Optional[WorkerState],
+        None,
+    ]:
+        """This method is invoked by the ActiveMemoryManager every few seconds, or
+        whenever the user invokes scheduler.amm_run_once().
+        It is an iterator that must emit any of the following:
+
+        - "replicate", <TaskState>, None
+        - "replicate", <TaskState>, {subset of potential workers to replicate to}
+        - "drop", <TaskState>, None
+        - "drop", <TaskState>, {subset of potential workers to drop from}
+
+        Each element yielded indicates the desire to create or destroy a single replica
+        of a key. If a subset of workers is not provided, it defaults to all workers on
+        the cluster. Either the ActiveMemoryManager or the Worker may later decide to
+        disregard the request, e.g. because it would delete the last copy of a key or
+        because the key is currently needed on that worker.
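+
+        For example, a hypothetical ``ReplicateSingletons`` policy (name and logic are
+        an illustrative sketch, not part of this module) that asks for one extra
+        replica of every key that currently has a single replica could be written as:
+
+        ```python
+        class ReplicateSingletons(ActiveMemoryManagerPolicy):
+            def run(self):
+                # Sketch only: request one additional replica per singleton key and
+                # let the AMM pick the least loaded recipient worker
+                for ts in self.manager.scheduler.tasks.values():
+                    if len(ts.who_has) == 1:
+                        yield "replicate", ts, None
+        ```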
+
+        You may optionally retrieve which worker it was decided the key will be
+        replicated to or dropped from, as follows:
+
+        ```python
+        choice = yield "replicate", ts, None
+        ```
+
+        ``choice`` is either a WorkerState or None; the latter is returned if the
+        ActiveMemoryManager chose to disregard the request.
+
+        The current pending (accepted) commands can be inspected on
+        ``self.manager.pending``; this includes the commands previously yielded by this
+        same method.
+
+        The current memory usage on each worker, *downstream of all pending commands*,
+        can be inspected on ``self.manager.workers_memory``.
+        """
+        raise NotImplementedError("Virtual method")
+
+
+class ReduceReplicas(ActiveMemoryManagerPolicy):
+    """Make sure that in-memory tasks are not replicated on more workers than desired;
+    drop the excess replicas.
+    """
+
+    def run(self):
+        # TODO this is O(n) to the total number of in-memory tasks on the cluster; it
+        # could be made faster by automatically attaching it to a TaskState when it
+        # goes above one replica and detaching it when it drops below two.
+        for ts in self.manager.scheduler.tasks.values():
+            if len(ts.who_has) < 2:
+                continue
+
+            desired_replicas = 1  # TODO have a marker on TaskState
+
+            # If a dependent task has not been assigned to a worker yet, err on the side
+            # of caution and preserve an additional replica for it.
+            # However, if two dependent tasks have been already assigned to the same
+            # worker, don't double count them.
+            nwaiters = len({waiter.processing_on or waiter for waiter in ts.waiters})
+
+            ndrop = len(ts.who_has) - max(desired_replicas, nwaiters)
+            if ts in self.manager.pending:
+                pending_repl, pending_drop = self.manager.pending[ts]
+                ndrop += len(pending_repl) - len(pending_drop)
+
+            # ndrop could be negative, which for range() is the same as 0.
+            for _ in range(ndrop):
+                yield "drop", ts, None
diff --git a/distributed/distributed-schema.yaml b/distributed/distributed-schema.yaml
index d10022f757..f5b7b73a5e 100644
--- a/distributed/distributed-schema.yaml
+++ b/distributed/distributed-schema.yaml
@@ -242,6 +242,31 @@ properties:
              A list of trusted root modules the schedular is allowed to import
              (incl. submodules). For security reasons, the scheduler does not
              import arbitrary Python modules.
+          active-memory-manager:
+            type: object
+            required: [start, interval, policies]
+            additionalProperties: false
+            properties:
+              start:
+                type: boolean
+                description: set to true to auto-start the AMM on Scheduler init;
+                  false to manually start it with client.scheduler.amm_start()
+              interval:
+                type: string
+                description:
+                  Time expression, e.g. "2s". Run the AMM cycle every <interval>.
+              policies:
+                type: array
+                items:
+                  type: object
+                  required: [class]
+                  properties:
+                    class:
+                      type: string
+                      description: fully qualified name of an ActiveMemoryManagerPolicy
+                        subclass
+                  additionalProperties:
+                    description: keyword arguments to the policy constructor, if any

      worker:
        type: object
diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml
index 929b58676a..6949fdd56e 100644
--- a/distributed/distributed.yaml
+++ b/distributed/distributed.yaml
@@ -60,6 +60,18 @@ distributed:
        - dask
        - distributed

+    active-memory-manager:
+      # Set to true to auto-start the Active Memory Manager on Scheduler start; if false
+      # you'll have to either manually start it with client.scheduler.amm_start() or run
+      # it once with client.scheduler.amm_run_once().
+ start: false + # Once started, run the AMM cycle every + interval: 2s + policies: + # Policies that should be executed at every cycle. Any additional keys in each + # object are passed as keyword arguments to the policy constructor. + - class: distributed.active_memory_manager.ReduceReplicas + worker: blocked-handlers: [] multiprocessing-method: spawn diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 9ee4dc35a2..22a33372b5 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -42,6 +42,7 @@ from . import preloading, profile from . import versions as version_module +from .active_memory_manager import ActiveMemoryManagerExtension from .batched import BatchedSend from .comm import ( get_address_host, @@ -172,6 +173,7 @@ def nogil(func): PubSubSchedulerExtension, SemaphoreExtension, EventExtension, + ActiveMemoryManagerExtension, ] ALL_TASK_STATES = declare( @@ -5859,8 +5861,8 @@ async def proxy(self, comm=None, msg=None, worker=None, serializers=None): ) return d[worker] - async def _gather_on_worker( - self, worker_address: str, who_has: "dict[Hashable, list[str]]" + async def gather_on_worker( + self, worker_address: str, who_has: "dict[str, list[str]]" ) -> set: """Peer-to-peer copy of keys from multiple workers to a single worker @@ -5919,7 +5921,7 @@ async def _gather_on_worker( return keys_failed - async def _delete_worker_data(self, worker_address: str, keys: "list[str]") -> None: + async def delete_worker_data(self, worker_address: str, keys: "list[str]") -> None: """Delete data from a worker and update the corresponding worker/task states Parameters @@ -6290,7 +6292,7 @@ async def _rebalance_move_data( await asyncio.gather( *( # Note: this never raises exceptions - self._gather_on_worker(w, who_has) + self.gather_on_worker(w, who_has) for w, who_has in to_recipients.items() ) ), @@ -6304,7 +6306,7 @@ async def _rebalance_move_data( # Note: this never raises exceptions await asyncio.gather( - *(self._delete_worker_data(r, v) for r, v in to_senders.items()) + *(self.delete_worker_data(r, v) for r, v in to_senders.items()) ) for r, v in to_recipients.items(): @@ -6390,7 +6392,7 @@ async def replicate( # Note: this never raises exceptions await asyncio.gather( *[ - self._delete_worker_data(ws._address, [t.key for t in tasks]) + self.delete_worker_data(ws._address, [t.key for t in tasks]) for ws, tasks in del_worker_tasks.items() ] ) @@ -6420,7 +6422,7 @@ async def replicate( await asyncio.gather( *( # Note: this never raises exceptions - self._gather_on_worker(w, who_has) + self.gather_on_worker(w, who_has) for w, who_has in gathers.items() ) ) diff --git a/distributed/tests/test_active_memory_manager.py b/distributed/tests/test_active_memory_manager.py new file mode 100644 index 0000000000..afe6c11494 --- /dev/null +++ b/distributed/tests/test_active_memory_manager.py @@ -0,0 +1,439 @@ +import asyncio +import random + +import pytest + +from distributed import Nanny +from distributed.active_memory_manager import ( + ActiveMemoryManagerExtension, + ActiveMemoryManagerPolicy, +) +from distributed.utils_test import gen_cluster, inc, slowinc + +NO_AMM_START = {"distributed.scheduler.active-memory-manager.start": False} + + +@gen_cluster( + client=True, + config={ + "distributed.scheduler.active-memory-manager.start": False, + "distributed.scheduler.active-memory-manager.policies": [], + }, +) +async def test_no_policies(c, s, a, b): + await c.scheduler.amm_run_once() + + +class DemoPolicy(ActiveMemoryManagerPolicy): + """Drop or replicate a key n 
times""" + + def __init__(self, action, key, n, candidates): + self.action = action + self.key = key + self.n = n + self.candidates = candidates + + def run(self): + candidates = self.candidates + if candidates is not None: + candidates = { + ws + for i, ws in enumerate(self.manager.scheduler.workers.values()) + if i in candidates + } + for ts in self.manager.scheduler.tasks.values(): + if ts.key == self.key: + for _ in range(self.n): + yield self.action, ts, candidates + + +def demo_config(action, key="x", n=10, candidates=None, start=False, interval=0.1): + """Create a dask config for AMM with DemoPolicy""" + return { + "distributed.scheduler.active-memory-manager.start": start, + "distributed.scheduler.active-memory-manager.interval": interval, + "distributed.scheduler.active-memory-manager.policies": [ + { + "class": "distributed.tests.test_active_memory_manager.DemoPolicy", + "action": action, + "key": key, + "n": n, + "candidates": candidates, + }, + ], + } + + +@gen_cluster(nthreads=[("", 1)] * 4, client=True, config=demo_config("drop")) +async def test_drop(c, s, *workers): + futures = await c.scatter({"x": 123}, broadcast=True) + assert len(s.tasks["x"].who_has) == 4 + # Also test the extension handler + await c.scheduler.amm_run_once() + while len(s.tasks["x"].who_has) > 1: + await asyncio.sleep(0.01) + # The last copy is never dropped even if the policy asks so + await asyncio.sleep(0.2) + assert len(s.tasks["x"].who_has) == 1 + + +@gen_cluster(client=True, config=demo_config("drop")) +async def test_start_stop(c, s, a, b): + x = c.submit(lambda: 123, key="x") + await c.replicate(x, 2) + assert len(s.tasks["x"].who_has) == 2 + await c.scheduler.amm_start() + while len(s.tasks["x"].who_has) > 1: + await asyncio.sleep(0.01) + await c.scheduler.amm_stop() + # AMM is not running anymore + await c.replicate(x, 2) + await asyncio.sleep(0.2) + assert len(s.tasks["x"].who_has) == 2 + + +@gen_cluster(client=True, config=demo_config("drop", start=True, interval=0.1)) +async def test_auto_start(c, s, a, b): + futures = await c.scatter({"x": 123}, broadcast=True) + # The AMM should run within 0.1s of the broadcast. + # Add generous extra padding to prevent flakiness. 
+ await asyncio.sleep(0.5) + assert len(s.tasks["x"].who_has) == 1 + + +@gen_cluster(client=True, config=NO_AMM_START) +async def test_not_registered(c, s, a, b): + futures = await c.scatter({"x": 1}, broadcast=True) + assert len(s.tasks["x"].who_has) == 2 + + class Policy(ActiveMemoryManagerPolicy): + def run(self): + yield "drop", s.tasks["x"], None + + amm = ActiveMemoryManagerExtension(s, {Policy()}, register=False, start=False) + amm.run_once() + assert amm is not s.extensions["amm"] + + while len(s.tasks["x"].who_has) > 1: + await asyncio.sleep(0.01) + + +@gen_cluster(client=True, config=demo_config("drop")) +async def test_drop_not_in_memory(c, s, a, b): + """ts.who_has is empty""" + x = c.submit(slowinc, 1, key="x") + while "x" not in s.tasks: + await asyncio.sleep(0.01) + assert not x.done() + s.extensions["amm"].run_once() + assert await x == 2 + + +@gen_cluster(client=True, config=demo_config("drop")) +async def test_drop_with_waiter(c, s, a, b): + """Tasks with a waiter are never dropped""" + x = (await c.scatter({"x": 1}, broadcast=True))["x"] + y1 = c.submit(slowinc, x, delay=0.4, key="y1", workers=[a.address]) + y2 = c.submit(slowinc, x, delay=0.8, key="y2", workers=[b.address]) + for key in ("y1", "y2"): + while key not in s.tasks or s.tasks[key].state != "processing": + await asyncio.sleep(0.01) + + s.extensions["amm"].run_once() + await asyncio.sleep(0.2) + assert {ws.address for ws in s.tasks["x"].who_has} == {a.address, b.address} + assert await y1 == 2 + # y1 is finished so there's a worker available without a waiter + s.extensions["amm"].run_once() + while {ws.address for ws in s.tasks["x"].who_has} != {b.address}: + await asyncio.sleep(0.01) + assert not y2.done() + + +@pytest.mark.xfail(reason="distributed#5265") +@gen_cluster(client=True, config=NO_AMM_START) +async def test_double_drop(c, s, a, b): + """An AMM drop policy runs once to drop one of the two replicas of a key. + Then it runs again, before the recommendations from the first iteration had the time + to either be enacted or rejected, and chooses a different worker to drop from. + + Test that, in this use case, the last replica of a key is never dropped. 
+ """ + futures = await c.scatter({"x": 1}, broadcast=True) + assert len(s.tasks["x"].who_has) == 2 + ws_iter = iter(s.workers.values()) + + class Policy(ActiveMemoryManagerPolicy): + def run(self): + yield "drop", s.tasks["x"], {next(ws_iter)} + + amm = ActiveMemoryManagerExtension(s, {Policy()}, register=False, start=False) + amm.run_once() + amm.run_once() + while len(s.tasks["x"].who_has) > 1: + await asyncio.sleep(0.01) + await asyncio.sleep(0.2) + assert len(s.tasks["x"].who_has) == 1 + + +@gen_cluster(client=True, config=demo_config("drop")) +async def test_double_drop_stress(c, s, a, b): + """AMM runs many times before the recommendations of the first run are enacted""" + futures = await c.scatter({"x": 1}, broadcast=True) + assert len(s.tasks["x"].who_has) == 2 + for _ in range(10): + s.extensions["amm"].run_once() + while len(s.tasks["x"].who_has) > 1: + await asyncio.sleep(0.01) + await asyncio.sleep(0.2) + assert len(s.tasks["x"].who_has) == 1 + + +@pytest.mark.slow +@gen_cluster( + nthreads=[("", 1)] * 4, + Worker=Nanny, + client=True, + worker_kwargs={"memory_limit": "2 GiB"}, + config=demo_config("drop", n=1), +) +async def test_drop_from_worker_with_least_free_memory(c, s, *nannies): + a1, a2, a3, a4 = s.workers.keys() + ws1, ws2, ws3, ws4 = s.workers.values() + + futures = await c.scatter({"x": 1}, broadcast=True) + assert s.tasks["x"].who_has == {ws1, ws2, ws3, ws4} + # Allocate enough RAM to be safely more than unmanaged memory + clog = c.submit(lambda: "x" * 2 ** 29, workers=[a3]) # 512 MiB + # await wait(clog) is not enough; we need to wait for the heartbeats + while ws3.memory.optimistic < 2 ** 29: + await asyncio.sleep(0.01) + s.extensions["amm"].run_once() + + while s.tasks["x"].who_has != {ws1, ws2, ws4}: + await asyncio.sleep(0.01) + + +@gen_cluster( + nthreads=[("", 1)] * 8, + client=True, + config=demo_config("drop", n=1, candidates={5, 6}), +) +async def test_drop_with_candidates(c, s, *workers): + futures = await c.scatter({"x": 1}, broadcast=True) + s.extensions["amm"].run_once() + wss = list(s.workers.values()) + expect1 = {wss[0], wss[1], wss[2], wss[3], wss[4], wss[6], wss[7]} + expect2 = {wss[0], wss[1], wss[2], wss[3], wss[4], wss[5], wss[7]} + while s.tasks["x"].who_has not in (expect1, expect2): + await asyncio.sleep(0.01) + + +@gen_cluster(client=True, config=demo_config("drop", candidates=set())) +async def test_drop_with_empty_candidates(c, s, a, b): + """Key is not dropped as the plugin proposes an empty set of candidates, + not to be confused with None + """ + futures = await c.scatter({"x": 1}, broadcast=True) + s.extensions["amm"].run_once() + await asyncio.sleep(0.2) + assert len(s.tasks["x"].who_has) == 2 + + +@gen_cluster( + client=True, nthreads=[("", 1)] * 3, config=demo_config("drop", candidates={2}) +) +async def test_drop_from_candidates_without_key(c, s, *workers): + """Key is not dropped as none of the candidates hold a replica""" + ws0, ws1, ws2 = s.workers.values() + x = (await c.scatter({"x": 1}, workers=[ws0.address]))["x"] + y = c.submit(inc, x, key="y", workers=[ws1.address]) + await y + assert s.tasks["x"].who_has == {ws0, ws1} + + s.extensions["amm"].run_once() + await asyncio.sleep(0.2) + assert s.tasks["x"].who_has == {ws0, ws1} + + +@gen_cluster(client=True, config=demo_config("drop", candidates={0})) +async def test_drop_with_bad_candidates(c, s, a, b): + """Key is not dropped as all candidates hold waiter tasks""" + ws0, ws1 = s.workers.values() # Not necessarily a, b; it could be b, a! 
+ x = (await c.scatter({"x": 1}, broadcast=True))["x"] + y = c.submit(slowinc, x, 0.3, key="y", workers=[ws0.address]) + while "y" not in s.tasks: + await asyncio.sleep(0.01) + + s.extensions["amm"].run_once() + await y + assert s.tasks["x"].who_has == {ws0, ws1} + + +class DropEverything(ActiveMemoryManagerPolicy): + """Inanely suggest to drop every single key in the cluster""" + + def run(self): + for ts in self.manager.scheduler.tasks.values(): + # Instead of yielding ("drop", ts, None) for each worker, which would result + # in semi-predictable output about which replica survives, randomly choose a + # different survivor at each AMM run. + candidates = list(ts.who_has) + random.shuffle(candidates) + for ws in candidates: + yield "drop", ts, {ws} + + +@pytest.mark.xfail(reason="distributed#5046, distributed#5265") +@pytest.mark.slow +@gen_cluster( + client=True, + nthreads=[("", 1)] * 8, + Worker=Nanny, + config={ + "distributed.scheduler.active-memory-manager.start": True, + "distributed.scheduler.active-memory-manager.interval": 0.1, + "distributed.scheduler.active-memory-manager.policies": [ + {"class": "distributed.tests.test_active_memory_manager.DropEverything"}, + ], + }, +) +async def test_drop_stress(c, s, *nannies): + """A policy which suggests dropping everything won't break a running computation, + but only slow it down. + """ + import dask.array as da + + rng = da.random.RandomState(0) + a = rng.random((20, 20), chunks=(1, 1)) + b = (a @ a.T).sum().round(3) + assert await c.compute(b) == 2134.398 + + +@gen_cluster(nthreads=[("", 1)] * 4, client=True, config=demo_config("replicate", n=2)) +async def test_replicate(c, s, *workers): + futures = await c.scatter({"x": 123}) + assert len(s.tasks["x"].who_has) == 1 + + s.extensions["amm"].run_once() + while len(s.tasks["x"].who_has) < 3: + await asyncio.sleep(0.01) + await asyncio.sleep(0.2) + assert len(s.tasks["x"].who_has) == 3 + + s.extensions["amm"].run_once() + while len(s.tasks["x"].who_has) < 4: + await asyncio.sleep(0.01) + + for w in workers: + assert w.data["x"] == 123 + + +@gen_cluster(client=True, config=demo_config("replicate")) +async def test_replicate_not_in_memory(c, s, a, b): + """ts.who_has is empty""" + x = c.submit(slowinc, 1, key="x") + while "x" not in s.tasks: + await asyncio.sleep(0.01) + assert not x.done() + s.extensions["amm"].run_once() + assert await x == 2 + assert len(s.tasks["x"].who_has) == 1 + s.extensions["amm"].run_once() + while len(s.tasks["x"].who_has) < 2: + await asyncio.sleep(0.01) + + +@gen_cluster(client=True, config=demo_config("replicate")) +async def test_double_replicate_stress(c, s, a, b): + """AMM runs many times before the recommendations of the first run are enacted""" + futures = await c.scatter({"x": 1}) + assert len(s.tasks["x"].who_has) == 1 + for _ in range(10): + s.extensions["amm"].run_once() + while len(s.tasks["x"].who_has) < 2: + await asyncio.sleep(0.01) + + +@pytest.mark.slow +@gen_cluster( + nthreads=[("", 1)] * 4, + Worker=Nanny, + client=True, + worker_kwargs={"memory_limit": "2 GiB"}, + config=demo_config("replicate", n=1), +) +async def test_replicate_to_worker_with_most_free_memory(c, s, *nannies): + a1, a2, a3, a4 = s.workers.keys() + ws1, ws2, ws3, ws4 = s.workers.values() + + futures = await c.scatter({"x": 1}, workers=[a1]) + assert s.tasks["x"].who_has == {ws1} + # Allocate enough RAM to be safely more than unmanaged memory + clog2 = c.submit(lambda: "x" * 2 ** 29, workers=[a2]) # 512 MiB + clog4 = c.submit(lambda: "x" * 2 ** 29, workers=[a4]) # 512 MiB + 
# await wait(clog) is not enough; we need to wait for the heartbeats + for ws in (ws2, ws4): + while ws.memory.optimistic < 2 ** 29: + await asyncio.sleep(0.01) + s.extensions["amm"].run_once() + + while s.tasks["x"].who_has != {ws1, ws3}: + await asyncio.sleep(0.01) + + +@gen_cluster( + nthreads=[("", 1)] * 8, + client=True, + config=demo_config("replicate", n=1, candidates={5, 6}), +) +async def test_replicate_with_candidates(c, s, *workers): + wss = list(s.workers.values()) + futures = await c.scatter({"x": 1}, workers=[wss[0].address]) + s.extensions["amm"].run_once() + expect1 = {wss[0], wss[5]} + expect2 = {wss[0], wss[6]} + while s.tasks["x"].who_has not in (expect1, expect2): + await asyncio.sleep(0.01) + + +@gen_cluster(client=True, config=demo_config("replicate", candidates=set())) +async def test_replicate_with_empty_candidates(c, s, a, b): + """Key is not replicated as the plugin proposes an empty set of candidates, + not to be confused with None + """ + futures = await c.scatter({"x": 1}) + s.extensions["amm"].run_once() + await asyncio.sleep(0.2) + assert len(s.tasks["x"].who_has) == 1 + + +@gen_cluster(client=True, config=demo_config("replicate", candidates={0})) +async def test_replicate_to_candidates_with_key(c, s, a, b): + """Key is not replicated as all candidates already hold replicas""" + ws0, ws1 = s.workers.values() # Not necessarily a, b; it could be b, a! + futures = await c.scatter({"x": 1}, workers=[ws0.address]) + s.extensions["amm"].run_once() + await asyncio.sleep(0.2) + assert s.tasks["x"].who_has == {ws0} + + +@gen_cluster( + nthreads=[("", 1)] * 4, + client=True, + config={ + "distributed.scheduler.active-memory-manager.start": False, + "distributed.scheduler.active-memory-manager.policies": [ + {"class": "distributed.active_memory_manager.ReduceReplicas"}, + # Run two instances of the plugin in sequence, to emulate multiple plugins + # that issues drop suggestions for the same keys + {"class": "distributed.active_memory_manager.ReduceReplicas"}, + ], + }, +) +async def test_ReduceReplicas(c, s, *workers): + futures = await c.scatter({"x": 123}, broadcast=True) + assert len(s.tasks["x"].who_has) == 4 + s.extensions["amm"].run_once() + while len(s.tasks["x"].who_has) > 1: + await asyncio.sleep(0.01) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 260513e954..23f59431cf 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -2939,7 +2939,7 @@ async def test_gather_on_worker(c, s, a, b): assert x_ts not in b_ws.has_what assert x_ts.who_has == {a_ws} - out = await s._gather_on_worker(b.address, {x.key: [a.address]}) + out = await s.gather_on_worker(b.address, {x.key: [a.address]}) assert out == set() assert a.data[x.key] == "x" assert b.data[x.key] == "x" @@ -2955,14 +2955,14 @@ async def test_gather_on_worker_bad_recipient(c, s, a, b): x = await c.scatter("x") await b.close() assert s.workers.keys() == {a.address} - out = await s._gather_on_worker(b.address, {x.key: [a.address]}) + out = await s.gather_on_worker(b.address, {x.key: [a.address]}) assert out == {x.key} @gen_cluster(client=True, worker_kwargs={"timeout": "100ms"}) async def test_gather_on_worker_bad_sender(c, s, a, b): """The only sender for a key is missing""" - out = await s._gather_on_worker(a.address, {"x": ["tcp://127.0.0.1:12345"]}) + out = await s.gather_on_worker(a.address, {"x": ["tcp://127.0.0.1:12345"]}) assert out == {"x"} @@ -2974,7 +2974,7 @@ async def test_gather_on_worker_bad_sender_replicated(c, s, 
a, b, missing_first) bad_addr = "tcp://127.0.0.1:12345" # Order matters; test both addrs = [bad_addr, a.address] if missing_first else [a.address, bad_addr] - out = await s._gather_on_worker(b.address, {x.key: addrs}) + out = await s.gather_on_worker(b.address, {x.key: addrs}) assert out == set() assert a.data[x.key] == "x" assert b.data[x.key] == "x" @@ -2983,7 +2983,7 @@ async def test_gather_on_worker_bad_sender_replicated(c, s, a, b, missing_first) @gen_cluster(client=True) async def test_gather_on_worker_key_not_on_sender(c, s, a, b): """The only sender for a key does not actually hold it""" - out = await s._gather_on_worker(a.address, {"x": [b.address]}) + out = await s.gather_on_worker(a.address, {"x": [b.address]}) assert out == {"x"} @@ -2998,7 +2998,7 @@ async def test_gather_on_worker_key_not_on_sender_replicated( x = await client.scatter("x", workers=[a.address]) # Order matters; test both addrs = [b.address, a.address] if missing_first else [a.address, b.address] - out = await s._gather_on_worker(c.address, {x.key: addrs}) + out = await s.gather_on_worker(c.address, {x.key: addrs}) assert out == set() assert a.data[x.key] == "x" assert c.data[x.key] == "x" @@ -3015,8 +3015,8 @@ async def test_gather_on_worker_duplicate_task(client, s, a, b, c): assert x.key not in c.data out = await asyncio.gather( - s._gather_on_worker(c.address, {x.key: [a.address]}), - s._gather_on_worker(c.address, {x.key: [b.address]}), + s.gather_on_worker(c.address, {x.key: [a.address]}), + s.gather_on_worker(c.address, {x.key: [b.address]}), ) assert out == [set(), set()] assert c.data[x.key] == "x" @@ -3064,7 +3064,7 @@ async def test_delete_worker_data(c, s, a, b): assert b.data == {y.key: "y"} assert s.tasks.keys() == {x.key, y.key, z.key} - await s._delete_worker_data(a.address, [x.key, y.key]) + await s.delete_worker_data(a.address, [x.key, y.key]) assert a.data == {z.key: "z"} assert b.data == {y.key: "y"} assert s.tasks.keys() == {y.key, z.key} @@ -3078,8 +3078,8 @@ async def test_delete_worker_data_double_delete(c, s, a): """ x, y = await c.scatter(["x", "y"]) await asyncio.gather( - s._delete_worker_data(a.address, [x.key]), - s._delete_worker_data(a.address, [x.key]), + s.delete_worker_data(a.address, [x.key]), + s.delete_worker_data(a.address, [x.key]), ) assert a.data == {y.key: "y"} a_ws = s.workers[a.address] @@ -3094,7 +3094,7 @@ async def test_delete_worker_data_bad_worker(s, a, b): """ await a.close() assert s.workers.keys() == {b.address} - await s._delete_worker_data(a.address, ["x"]) + await s.delete_worker_data(a.address, ["x"]) @pytest.mark.parametrize("bad_first", [False, True]) @@ -3109,7 +3109,7 @@ async def test_delete_worker_data_bad_task(c, s, a, bad_first): assert s.tasks.keys() == {x.key, y.key} keys = ["notexist", x.key] if bad_first else [x.key, "notexist"] - await s._delete_worker_data(a.address, keys) + await s.delete_worker_data(a.address, keys) assert a.data == {y.key: "y"} assert s.tasks.keys() == {y.key} assert s.workers[a.address].nbytes == s.tasks[y.key].nbytes
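
A minimal usage sketch (not part of the patch itself) of how the new extension could be driven from a client, assuming only the configuration keys and the ``amm_start``/``amm_stop``/``amm_run_once`` handlers added above; the in-process cluster setup and policy choice are illustrative:

```python
# Illustrative sketch: enable the AMM via dask config and drive it through the
# scheduler handlers introduced in this diff. Cluster setup is an assumption.
import asyncio

import dask
from distributed import Client


async def main() -> None:
    # Configure the AMM before the scheduler starts so the extension picks it up
    dask.config.set(
        {
            "distributed.scheduler.active-memory-manager.start": False,
            "distributed.scheduler.active-memory-manager.interval": "2s",
            "distributed.scheduler.active-memory-manager.policies": [
                {"class": "distributed.active_memory_manager.ReduceReplicas"},
            ],
        }
    )
    async with Client(processes=False, asynchronous=True) as client:
        # Run a single AMM iteration by hand (recommendations are fire-and-forget)
        await client.scheduler.amm_run_once()
        # Or start the periodic callback and stop it later
        await client.scheduler.amm_start()
        await client.scheduler.amm_stop()


asyncio.run(main())
```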