diff --git a/distributed/client.py b/distributed/client.py index 170cac9322..cbcac387f9 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -2028,6 +2028,7 @@ async def _scatter( if w.scheduler.address == self.scheduler.address: direct = True + out = {k: Future(k, self, inform=False) for k in data} if local_worker: # running within task local_worker.update_data(data=data, report=False) diff --git a/distributed/tests/test_worker_client.py b/distributed/tests/test_worker_client.py index 09ae20e8f2..fbb98c04bd 100644 --- a/distributed/tests/test_worker_client.py +++ b/distributed/tests/test_worker_client.py @@ -315,3 +315,16 @@ async def test_submit_different_names(s, a, b): assert fut > 0 finally: await c.close() + + +@gen_cluster(client=True) +async def test_task_unique_groups_scatter(c, s, a, b): + """ This test ensures that tasks are correctly deleted when using scatter/submit + """ + n = await c.scatter([0, 1], hash=True) + x = await c.submit(sum, n) + del n + del x + m = await c.scatter([0, 1], hash=True) + y = await c.submit(sum, m) + assert y in list(b.data.values()) or y in list(a.data.values()) diff --git a/distributed/worker.py b/distributed/worker.py index 8dff85e653..19ed2407a4 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -334,6 +334,7 @@ def __init__( ): self.tasks = dict() self.task_state = dict() + self.task_refs = dict() self.dep_state = dict() self.dependencies = dict() self.dependents = dict() @@ -1326,6 +1327,11 @@ def update_data(self, comm=None, data=None, report=True, serializers=None): if key in self.dep_state: self.transition_dep(key, "memory", value=value) + if key in self.task_refs: + self.task_refs[key] += 1 + else: + self.task_refs[key] = 1 + self.log.append((key, "receive-from-scatter")) if report: @@ -1334,8 +1340,17 @@ def update_data(self, comm=None, data=None, report=True, serializers=None): return info async def delete_data(self, comm=None, keys=None, report=True): + keys = list(keys) if keys: for
key in list(keys): + if key in self.task_refs: + self.task_refs[key] -= 1 + if self.task_refs[key] > 0: + keys.remove(key) + continue + else: + del self.task_refs[key] + self.log.append((key, "delete")) if key in self.task_state: self.release_key(key)