Skip to content

Commit

Permalink
Run multiple AMMs in parallel (#5315) (#5339)
Browse files (browse the repository at this point in the history)
Propaedeutic to RetireWorker AMM policy
  • Loading branch information
crusaderky authored Sep 22, 2021
1 parent 23c3b4b commit e001822
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 8 deletions.
25 changes: 17 additions & 8 deletions distributed/active_memory_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,23 +52,20 @@ def __init__(
interval: float | None = None,
):
self.scheduler = scheduler
self.policies = set()

if policies is None:
# Initialize policies from config
policies = set()
for kwargs in dask.config.get(
"distributed.scheduler.active-memory-manager.policies"
):
kwargs = kwargs.copy()
cls = import_term(kwargs.pop("class"))
if not issubclass(cls, ActiveMemoryManagerPolicy):
raise TypeError(
f"{cls}: Expected ActiveMemoryManagerPolicy; got {type(cls)}"
)
policies.add(cls(**kwargs))

for policy in policies:
policy.manager = self
self.policies = policies
self.add_policy(policy)

if register:
scheduler.extensions["amm"] = self
Expand All @@ -92,16 +89,28 @@ def __init__(

def start(self, comm=None) -> None:
    """Start executing every ``self.interval`` seconds until scheduler shutdown.

    Calling this on a manager that is already started is a no-op.

    Parameters
    ----------
    comm
        Not used by this method.
    """
    if self.started:
        return
    # PeriodicCallback takes its period in milliseconds; the interval is in seconds.
    pc = PeriodicCallback(self.run_once, self.interval * 1000.0)
    # Register under a per-instance key only. A shared "amm" key would be
    # overwritten by every manager that starts and is not what ``started``
    # checks, so several managers could not coexist on one scheduler.
    self.scheduler.periodic_callbacks[f"amm-{id(self)}"] = pc
    pc.start()

def stop(self, comm=None) -> None:
    """Stop periodic execution.

    Removes this manager's periodic callback from the scheduler and stops it.
    Does nothing if the manager was not started.

    Parameters
    ----------
    comm
        Not used by this method.
    """
    # Pop only this instance's own key. Popping a shared "amm" entry would
    # unregister a callback that may belong to a different manager without
    # ever stopping it.
    pc = self.scheduler.periodic_callbacks.pop(f"amm-{id(self)}", None)
    if pc is not None:
        pc.stop()

@property
def started(self) -> bool:
    """True if this manager's periodic callback is registered on the scheduler."""
    callback_key = f"amm-{id(self)}"
    registered = self.scheduler.periodic_callbacks
    return callback_key in registered

def add_policy(self, policy: ActiveMemoryManagerPolicy) -> None:
    """Attach ``policy`` to this manager.

    The policy is added to ``self.policies`` and its ``manager`` attribute is
    set to this manager.

    Raises
    ------
    TypeError
        If ``policy`` is not an ``ActiveMemoryManagerPolicy`` instance.
    """
    if isinstance(policy, ActiveMemoryManagerPolicy):
        self.policies.add(policy)
        policy.manager = self
        return
    raise TypeError(f"Expected ActiveMemoryManagerPolicy; got {policy!r}")

def run_once(self, comm=None) -> None:
"""Run all policies once and asynchronously (fire and forget) enact their
recommendations to replicate/drop keys
Expand Down
59 changes: 59 additions & 0 deletions distributed/tests/test_active_memory_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,65 @@ async def test_auto_start(c, s, a, b):
assert len(s.tasks["x"].who_has) == 1


@gen_cluster(client=True, config=demo_config("drop", key="x"))
async def test_add_policy(c, s, a, b):
    """Policies can come from the config, from the constructor, or be added later."""
    drop_y = DemoPolicy(action="drop", key="y", n=10, candidates=None)
    drop_z = DemoPolicy(action="drop", key="z", n=10, candidates=None)

    # The ``policies`` constructor argument is either None (policies are read
    # from the dask config) or an explicit set, which is allowed to be empty.
    from_config = s.extensions["amm"]
    from_set = ActiveMemoryManagerExtension(s, {drop_y}, register=False, start=False)
    from_empty = ActiveMemoryManagerExtension(s, set(), register=False, start=False)

    assert len(from_config.policies) == 1
    assert len(from_set.policies) == 1
    assert len(from_empty.policies) == 0
    from_empty.add_policy(drop_z)
    assert len(from_empty.policies) == 1

    # Keep the futures referenced for the rest of the test.
    futures = await c.scatter({"x": 1, "y": 2, "z": 3}, broadcast=True)

    # Run each manager by hand, in turn, and wait until its key has lost a replica.
    for manager, key in ((from_config, "x"), (from_set, "y"), (from_empty, "z")):
        manager.run_once()
        while len(s.tasks[key].who_has) == 2:
            await asyncio.sleep(0.01)


@gen_cluster(client=True, config=demo_config("drop", key="x", start=False))
async def test_multi_start(c, s, a, b):
    """Multiple AMMs can be started in parallel"""
    drop_y = DemoPolicy(action="drop", key="y", n=10, candidates=None)
    drop_z = DemoPolicy(action="drop", key="z", n=10, candidates=None)

    # The config-defined manager (policy on "x") is expected not to be started;
    # the two extra managers are created unregistered and started right away.
    idle = s.extensions["amm"]
    running_y = ActiveMemoryManagerExtension(
        s, {drop_y}, register=False, start=True, interval=0.1
    )
    running_z = ActiveMemoryManagerExtension(
        s, {drop_z}, register=False, start=True, interval=0.1
    )

    assert not idle.started
    assert running_y.started
    assert running_z.started

    # Keep the futures referenced for the rest of the test.
    futures = await c.scatter({"x": 1, "y": 2, "z": 3}, broadcast=True)

    # The AMMs should run within 0.1s of the broadcast.
    # Add generous extra padding to prevent flakiness.
    await asyncio.sleep(0.5)
    # "x" keeps both replicas; "y" and "z" are down to a single replica each.
    for key, expected_replicas in (("x", 2), ("y", 1), ("z", 1)):
        assert len(s.tasks[key].who_has) == expected_replicas


@gen_cluster(client=True, config=NO_AMM_START)
async def test_not_registered(c, s, a, b):
futures = await c.scatter({"x": 1}, broadcast=True)
Expand Down

0 comments on commit e001822

Please sign in to comment.