diff --git a/distributed/active_memory_manager.py b/distributed/active_memory_manager.py
index 098773c224..c2bbe7ccd3 100644
--- a/distributed/active_memory_manager.py
+++ b/distributed/active_memory_manager.py
@@ -52,23 +52,20 @@ def __init__(
         interval: float | None = None,
     ):
         self.scheduler = scheduler
+        self.policies = set()
 
         if policies is None:
+            # Initialize policies from config
             policies = set()
             for kwargs in dask.config.get(
                 "distributed.scheduler.active-memory-manager.policies"
             ):
                 kwargs = kwargs.copy()
                 cls = import_term(kwargs.pop("class"))
-                if not issubclass(cls, ActiveMemoryManagerPolicy):
-                    raise TypeError(
-                        f"{cls}: Expected ActiveMemoryManagerPolicy; got {type(cls)}"
-                    )
                 policies.add(cls(**kwargs))
 
         for policy in policies:
-            policy.manager = self
-        self.policies = policies
+            self.add_policy(policy)
 
         if register:
             scheduler.extensions["amm"] = self
@@ -92,16 +89,28 @@
 
     def start(self, comm=None) -> None:
         """Start executing every ``self.interval`` seconds until scheduler shutdown"""
+        if self.started:
+            return
         pc = PeriodicCallback(self.run_once, self.interval * 1000.0)
-        self.scheduler.periodic_callbacks["amm"] = pc
+        self.scheduler.periodic_callbacks[f"amm-{id(self)}"] = pc
         pc.start()
 
     def stop(self, comm=None) -> None:
         """Stop periodic execution"""
-        pc = self.scheduler.periodic_callbacks.pop("amm", None)
+        pc = self.scheduler.periodic_callbacks.pop(f"amm-{id(self)}", None)
         if pc:
             pc.stop()
 
+    @property
+    def started(self) -> bool:
+        return f"amm-{id(self)}" in self.scheduler.periodic_callbacks
+
+    def add_policy(self, policy: ActiveMemoryManagerPolicy) -> None:
+        if not isinstance(policy, ActiveMemoryManagerPolicy):
+            raise TypeError(f"Expected ActiveMemoryManagerPolicy; got {policy!r}")
+        self.policies.add(policy)
+        policy.manager = self
+
     def run_once(self, comm=None) -> None:
         """Run all policies once and asynchronously (fire and forget) enact their
         recommendations to replicate/drop keys
diff --git a/distributed/tests/test_active_memory_manager.py b/distributed/tests/test_active_memory_manager.py
index afe6c11494..c7c747b850 100644
--- a/distributed/tests/test_active_memory_manager.py
+++ b/distributed/tests/test_active_memory_manager.py
@@ -101,6 +101,65 @@ async def test_auto_start(c, s, a, b):
     assert len(s.tasks["x"].who_has) == 1
 
 
+@gen_cluster(client=True, config=demo_config("drop", key="x"))
+async def test_add_policy(c, s, a, b):
+    p2 = DemoPolicy(action="drop", key="y", n=10, candidates=None)
+    p3 = DemoPolicy(action="drop", key="z", n=10, candidates=None)
+
+    # policies parameter can be:
+    # - None: get from config
+    # - explicit set, which can be empty
+    m1 = s.extensions["amm"]
+    m2 = ActiveMemoryManagerExtension(s, {p2}, register=False, start=False)
+    m3 = ActiveMemoryManagerExtension(s, set(), register=False, start=False)
+
+    assert len(m1.policies) == 1
+    assert len(m2.policies) == 1
+    assert len(m3.policies) == 0
+    m3.add_policy(p3)
+    assert len(m3.policies) == 1
+
+    futures = await c.scatter({"x": 1, "y": 2, "z": 3}, broadcast=True)
+    m1.run_once()
+    while len(s.tasks["x"].who_has) == 2:
+        await asyncio.sleep(0.01)
+
+    m2.run_once()
+    while len(s.tasks["y"].who_has) == 2:
+        await asyncio.sleep(0.01)
+
+    m3.run_once()
+    while len(s.tasks["z"].who_has) == 2:
+        await asyncio.sleep(0.01)
+
+
+@gen_cluster(client=True, config=demo_config("drop", key="x", start=False))
+async def test_multi_start(c, s, a, b):
+    """Multiple AMMs can be started in parallel"""
+    p2 = DemoPolicy(action="drop", key="y", n=10, candidates=None)
+    p3 = DemoPolicy(action="drop", key="z", n=10, candidates=None)
+
+    # policies parameter can be:
+    # - None: get from config
+    # - explicit set, which can be empty
+    m1 = s.extensions["amm"]
+    m2 = ActiveMemoryManagerExtension(s, {p2}, register=False, start=True, interval=0.1)
+    m3 = ActiveMemoryManagerExtension(s, {p3}, register=False, start=True, interval=0.1)
+
+    assert not m1.started
+    assert m2.started
+    assert m3.started
+
+    futures = await c.scatter({"x": 1, "y": 2, "z": 3}, broadcast=True)
+
+    # The AMMs should run within 0.1s of the broadcast.
+    # Add generous extra padding to prevent flakiness.
+    await asyncio.sleep(0.5)
+    assert len(s.tasks["x"].who_has) == 2
+    assert len(s.tasks["y"].who_has) == 1
+    assert len(s.tasks["z"].who_has) == 1
+
+
 @gen_cluster(client=True, config=NO_AMM_START)
 async def test_not_registered(c, s, a, b):
     futures = await c.scatter({"x": 1}, broadcast=True)