Skip to content

Commit

Permalink
initial based on dask#3641
Browse files Browse the repository at this point in the history
  • Loading branch information
kumarprabhu1988 committed Jun 4, 2020
1 parent 1fe50c2 commit 41ad13b
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 0 deletions.
1 change: 1 addition & 0 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2028,6 +2028,7 @@ async def _scatter(
if w.scheduler.address == self.scheduler.address:
direct = True

out = {k: Future(k, self, inform=False) for k in data}
if local_worker: # running within task
local_worker.update_data(data=data, report=False)

Expand Down
13 changes: 13 additions & 0 deletions distributed/tests/test_worker_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,3 +315,16 @@ async def test_submit_different_names(s, a, b):
assert fut > 0
finally:
await c.close()


@gen_cluster(client=True)
async def test_task_unique_groups_scatter(c, s, a, b):
""" This test ensure that tasks are correctly deleted when using scatter/submit
"""
n = await c.scatter([0, 1], hash=True)
x = await c.submit(sum, n)
del n
del x
m = await c.scatter([0, 1], hash=True)
y = await c.submit(sum, m)
assert y in list(b.data.values()) or y in list(a.data.values())
15 changes: 15 additions & 0 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ def __init__(
):
self.tasks = dict()
self.task_state = dict()
self.task_refs = dict()
self.dep_state = dict()
self.dependencies = dict()
self.dependents = dict()
Expand Down Expand Up @@ -1326,6 +1327,11 @@ def update_data(self, comm=None, data=None, report=True, serializers=None):
if key in self.dep_state:
self.transition_dep(key, "memory", value=value)

if key in self.task_refs:
self.task_refs[key] += 1
else:
self.task_refs[key] = 1

self.log.append((key, "receive-from-scatter"))

if report:
Expand All @@ -1334,8 +1340,17 @@ def update_data(self, comm=None, data=None, report=True, serializers=None):
return info

async def delete_data(self, comm=None, keys=None, report=True):
keys = list(keys)
if keys:
for key in list(keys):
if key in self.task_refs:
self.task_refs[key] -= 1
if self.task_refs[key] > 0:
keys.remove(key)
continue
else:
del self.task_refs[key]

self.log.append((key, "delete"))
if key in self.task_state:
self.release_key(key)
Expand Down

0 comments on commit 41ad13b

Please sign in to comment.