Worker key reference counting #3641

Closed
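
In short, the worker.py diff below adds a per-key reference count, Worker.task_refs: update_data (which receives scattered data) increments the count for a key, and delete_data decrements it and only releases the key once the count reaches zero. The new test and the review discussion further down cover the case this protects against: data is scattered, released, and then scattered again under the same keys (scatter with hash=True derives keys from the values), and the late delete from the first release must not wipe the re-scattered data. A minimal standalone sketch of the counting scheme, with illustrative names only (the real logic lives on the Worker in distributed/worker.py):

# Standalone sketch of the reference-counting idea; names here are illustrative,
# not the actual Worker attributes or method signatures.
task_refs = {}   # key -> number of outstanding scatters for that key
data = {}        # key -> value currently held in memory

def update_data(new):
    # Receiving scattered data counts one more reference per key.
    for key, value in new.items():
        data[key] = value
        task_refs[key] = task_refs.get(key, 0) + 1

def delete_data(keys):
    # A delete releases one reference; the key is only dropped at zero.
    for key in list(keys):
        if task_refs.get(key, 0) > 1:
            task_refs[key] -= 1          # key was scattered again: keep the data
            continue
        task_refs.pop(key, None)
        data.pop(key, None)              # last reference gone: actually delete

# The race this guards against, as seen by the worker:
update_data({"x": 1})   # first scatter
update_data({"x": 1})   # same data re-scattered before the delete arrives
delete_data(["x"])      # late delete from the first release
assert "x" in data      # the re-scattered value survives
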
2 changes: 1 addition & 1 deletion distributed/client.py
@@ -1985,6 +1985,7 @@ async def _scatter(
                 if w.scheduler.address == self.scheduler.address:
                     direct = True
 
+        out = {k: Future(k, self, inform=False) for k in data}
         if local_worker:  # running within task
             local_worker.update_data(data=data, report=False)
 
@@ -2024,7 +2025,6 @@ async def _scatter(
                 timeout=timeout,
             )
 
-        out = {k: Future(k, self, inform=False) for k in data}
         for key, typ in types.items():
             self.futures[key].finish(type=typ)
 
6 changes: 5 additions & 1 deletion distributed/scheduler.py
@@ -1815,7 +1815,11 @@ def update_graph(
             if any(
                 dep not in self.tasks and dep not in tasks for dep in deps
             ):  # bad key
-                logger.info("User asked for computation on lost data, %s", k)
+                logger.info(
+                    "User asked for computation %s lost data %s",
+                    k,
+                    [dep for dep in deps if dep not in self.tasks],
+                )
                 del tasks[k]
                 del dependencies[k]
                 if k in keys:
 
13 changes: 13 additions & 0 deletions distributed/tests/test_worker_client.py
@@ -315,3 +315,16 @@ async def test_submit_different_names(s, a, b):
         assert fut > 0
     finally:
         await c.close()
+
+
+@gen_cluster(client=True)
+async def test_task_unique_groups_scatter(c, s, a, b):

Member:

Not sure why, but occasionally when I run this I get a TimeoutError:

E tornado.util.TimeoutError: Operation timed out after 20 seconds

Member Author:

How occasionally? I'm not seeing this after 10 runs.

Member:

I was seeing it after 3 runs on my machine

Member Author:

I tried converting Corey's example into a test and now I see that TimeoutError every time. I wonder if pytest is masking the CancelledError somehow.

@gen_cluster(client=True)
async def test_task_unique_groups_scatter_tree_reduce(c, s, a, b):
    from toolz import first

    def tree_reduce(objs):
        while len(objs) > 1:
            new_objs = []
            n_objs = len(objs)
            for i in range(0, n_objs, 2):
                inputs = objs[i : i + 2]
                obj = c.submit(sum, inputs)
                new_objs.append(obj)
            wait(new_objs)
            objs = new_objs

        return first(objs)

    for n_parts in [1, 2, 5, 10, 15]:
        a = await c.scatter(range(n_parts))
        b = tree_reduce(a)

        assert sum(range(n_parts)) == await b

        del a


I've been seeing this timeout exception as well, and what's very strange is that even after increasing the timeout to 50s, I get the exception long before 50s has passed. This is specifically true in the case of dask/dask#6037.
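
(Presumably "increased the timeout to 50s" means passing a larger timeout to the gen_cluster decorator on the test; a sketch of what that would look like, with the 50 taken from the comment above and the body taken from the test added in this diff:)

from distributed.utils_test import gen_cluster

# Assumed way of raising the per-test timeout discussed above: gen_cluster
# accepts a timeout (in seconds) for the whole test.
@gen_cluster(client=True, timeout=50)
async def test_task_unique_groups_scatter(c, s, a, b):
    n = await c.scatter([0, 1], hash=True)
    x = await c.submit(sum, n)
    del n
    del x
    m = await c.scatter([0, 1], hash=True)
    y = await c.submit(sum, m)
    assert y in list(b.data.values()) or y in list(a.data.values())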

Member Author:

I'm getting it almost immediately. Within a second.

""" This test ensure that tasks are correctly deleted when using scatter/submit
"""
n = await c.scatter([0, 1], hash=True)
x = await c.submit(sum, n)
del n
del x
m = await c.scatter([0, 1], hash=True)
y = await c.submit(sum, m)
assert y in list(b.data.values()) or y in list(a.data.values())
17 changes: 16 additions & 1 deletion distributed/worker.py
@@ -332,6 +332,7 @@ def __init__(
     ):
         self.tasks = dict()
         self.task_state = dict()
+        self.task_refs = dict()
         self.dep_state = dict()
         self.dependencies = dict()
         self.dependents = dict()
@@ -1300,6 +1301,11 @@ def update_data(self, comm=None, data=None, report=True, serializers=None):
             if key in self.dep_state:
                 self.transition_dep(key, "memory", value=value)
 
+            if key in self.task_refs:
+                self.task_refs[key] += 1
+            else:
+                self.task_refs[key] = 1
+
             self.log.append((key, "receive-from-scatter"))
 
         if report:
@@ -1308,8 +1314,17 @@ def update_data(self, comm=None, data=None, report=True, serializers=None):
         return info
 
     async def delete_data(self, comm=None, keys=None, report=True):
+        keys = list(keys)
         if keys:
-            for key in list(keys):
+            for key in keys:
+                if key in self.task_refs:
+                    self.task_refs[key] -= 1
+                    if self.task_refs[key] > 0:
+                        keys.remove(key)
+                        continue
+                    else:
+                        del self.task_refs[key]
+
                 self.log.append((key, "delete"))
                 if key in self.task_state:
                     self.release_key(key)