diff --git a/distributed/stealing.py b/distributed/stealing.py index cc9737796a..101a228ce0 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -2,6 +2,7 @@ import asyncio import logging +import uuid from collections import defaultdict, deque from math import log2 from time import time @@ -233,7 +234,9 @@ def move_task_request(self, ts, victim, thief) -> str: try: if ts in self.in_flight: return "in-flight" - stimulus_id = f"steal-{time()}" + # Stimulus IDs are used to verify the response, see + # `move_task_confirm`. Therefore, this must be truly unique. + stimulus_id = f"steal-{uuid.uuid4().hex}" key = ts.key self.remove_key_from_stealable(ts) @@ -291,7 +294,7 @@ async def move_task_confirm(self, *, key, state, stimulus_id, worker=None): self.in_flight[ts] = d return except KeyError: - self.log(("already-aborted", key, state, stimulus_id)) + self.log(("already-aborted", key, state, worker, stimulus_id)) return thief = d["thief"] diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 3b6cd8497b..72bf63c1af 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -13,7 +13,7 @@ import dask -from distributed import Nanny, Worker, wait, worker_client +from distributed import Lock, Nanny, Worker, wait, worker_client from distributed.compatibility import LINUX, WINDOWS from distributed.config import config from distributed.metrics import time @@ -1193,3 +1193,35 @@ async def test_correct_bad_time_estimate(c, s, *workers): assert any(s.tasks[f.key] in steal.key_stealable for f in futures) await wait(futures) assert all(w.data for w in workers), [sorted(w.data) for w in workers] + + +@gen_cluster(client=True) +async def test_steal_stimulus_id_unique(c, s, *workers): + steal = s.extensions["stealing"] + num_futs = 1_000 + async with Lock() as lock: + + def blocked(x, lock): + lock.acquire() + + # Setup all tasks on worker 0 such that victim/thief relation is the + # same for all tasks. + futures = c.map( + blocked, range(num_futs), lock=lock, workers=[workers[0].address] + ) + # Ensure all tasks are assigned to the worker since otherwise the + # move_task_request fails. + while len(workers[0].tasks) != num_futs: + await asyncio.sleep(0.1) + tasks = [s.tasks[f.key] for f in futures] + w0 = s.workers[workers[0].address] + w1 = s.workers[workers[1].address] + # Generating the move task requests as fast as possible increases the + # chance of duplicates if the uniqueness is not guaranteed. + for ts in tasks: + steal.move_task_request(ts, w0, w1) + # Values stored in in_flight are used for response verification. + # Therefore all stimulus IDs are stored here and must be unique + stimulus_ids = {dct["stimulus_id"] for dct in steal.in_flight.values()} + assert len(stimulus_ids) == num_futs + await c.cancel(futures)