Remove release handler from scheduler
fjetter committed Jul 19, 2021
1 parent 30ad02b commit 2931a4f
Showing 4 changed files with 62 additions and 93 deletions.
57 changes: 0 additions & 57 deletions distributed/scheduler.py
@@ -3570,7 +3570,6 @@ def __init__(
         worker_handlers = {
             "task-finished": self.handle_task_finished,
             "task-erred": self.handle_task_erred,
-            "release": self.handle_release_data,
             "release-worker-data": self.release_worker_data,
             "add-keys": self.add_keys,
             "missing-data": self.handle_missing_data,
@@ -4566,43 +4565,6 @@ def stimulus_task_erred(

         return recommendations, client_msgs, worker_msgs

-    def stimulus_missing_data(
-        self, cause=None, key=None, worker=None, ensure=True, **kwargs
-    ):
-        """Mark that certain keys have gone missing. Recover."""
-        parent: SchedulerState = cast(SchedulerState, self)
-        with log_errors():
-            logger.debug("Stimulus missing data %s, %s", key, worker)
-
-            recommendations: dict = {}
-            client_msgs: dict = {}
-            worker_msgs: dict = {}
-
-            ts: TaskState = parent._tasks.get(key)
-            if ts is None or ts._state == "memory":
-                return recommendations, client_msgs, worker_msgs
-            cts: TaskState = parent._tasks.get(cause)
-
-            if cts is not None and cts._state == "memory":  # couldn't find this
-                ws: WorkerState
-                cts_nbytes: Py_ssize_t = cts.get_nbytes()
-                for ws in cts._who_has:  # TODO: this behavior is extreme
-                    del ws._has_what[ts]
-                    ws._nbytes -= cts_nbytes
-                cts._who_has.clear()
-                recommendations[cause] = "released"
-
-            if key:
-                recommendations[key] = "released"
-
-            parent._transitions(recommendations, client_msgs, worker_msgs)
-            recommendations = {}
-
-            if parent._validate:
-                assert cause not in self.who_has
-
-            return recommendations, client_msgs, worker_msgs
-
     def stimulus_retry(self, comm=None, keys=None, client=None):
         parent: SchedulerState = cast(SchedulerState, self)
         logger.info("Client %s requests to retry %d keys", client, len(keys))
@@ -5143,25 +5105,6 @@ def handle_task_erred(self, key=None, **msg):

         self.send_all(client_msgs, worker_msgs)

-    def handle_release_data(self, key=None, worker=None, client=None, **msg):
-        parent: SchedulerState = cast(SchedulerState, self)
-        ts: TaskState = parent._tasks.get(key)
-        if ts is None:
-            return
-        ws: WorkerState = parent._workers_dv[worker]
-        if ts._processing_on != ws:
-            return
-
-        recommendations: dict
-        client_msgs: dict
-        worker_msgs: dict
-
-        r: tuple = self.stimulus_missing_data(key=key, ensure=False, **msg)
-        recommendations, client_msgs, worker_msgs = r
-        parent._transitions(recommendations, client_msgs, worker_msgs)
-
-        self.send_all(client_msgs, worker_msgs)
-
     def handle_missing_data(self, key=None, errant_worker=None, **kwargs):
         parent: SchedulerState = cast(SchedulerState, self)
         logger.debug("handle missing data key=%s worker=%s", key, errant_worker)
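Both removed methods follow the scheduler's recommendation pattern: build a dict mapping task keys to target states ("released" above) and feed it to _transitions, which may in turn produce follow-up recommendations. A toy model of that loop, where transition is a hypothetical stand-in for the per-task transition machinery:

# Toy model of the recommendation-driven transition loop; not the actual
# SchedulerState._transitions implementation.
def run_transitions(recommendations: dict, transition) -> None:
    while recommendations:
        key, finish = recommendations.popitem()
        # Each transition may recommend further transitions, e.g. releasing
        # a cause key can cascade to tasks that depended on it.
        recommendations.update(transition(key, finish))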
34 changes: 11 additions & 23 deletions distributed/tests/test_client.py
@@ -825,24 +825,20 @@ async def test_long_tasks_dont_trigger_timeout(c, s, a, b):
     await x


-@pytest.mark.skip
 @gen_cluster(client=True)
 async def test_missing_data_heals(c, s, a, b):
-    a.validate = False
-    b.validate = False
     x = c.submit(inc, 1)
     y = c.submit(inc, x)
     z = c.submit(inc, y)

     await wait([x, y, z])

     # Secretly delete y's key
-    if y.key in a.data:
-        del a.data[y.key]
-        a.release_key(y.key)
-    if y.key in b.data:
-        del b.data[y.key]
-        b.release_key(y.key)
+    for w in [a, b]:
+        w.handle_free_keys(keys=(y.key,))
+        assert y.key not in w.data
+        assert y.key not in w.tasks
+
     await asyncio.sleep(0)

     w = c.submit(add, y, z)
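
For the assertion that follows: x = inc(1) is 2, y = inc(x) is 3, and z = inc(y) is 4, so after y's data is dropped the scheduler must recompute y from x before add(y, z) can return 3 + 4 = 7.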
@@ -851,30 +847,23 @@ async def test_missing_data_heals(c, s, a, b):
     assert result == 3 + 4


-@pytest.mark.skip
 @gen_cluster(client=True)
 async def test_gather_robust_to_missing_data(c, s, a, b):
-    a.validate = False
-    b.validate = False
     x, y, z = c.map(inc, range(3))
     await wait([x, y, z])  # everything computed

     for f in [x, y]:
         for w in [a, b]:
-            if f.key in w.data:
-                del w.data[f.key]
-                await asyncio.sleep(0)
-                w.release_key(f.key)
+            w.handle_free_keys(keys=(f.key,))
+            assert f.key not in w.data
+            assert f.key not in w.tasks

     xx, yy, zz = await c.gather([x, y, z])
     assert (xx, yy, zz) == (1, 2, 3)


-@pytest.mark.skip
 @gen_cluster(client=True)
 async def test_gather_robust_to_nested_missing_data(c, s, a, b):
-    a.validate = False
-    b.validate = False
     w = c.submit(inc, 1)
     x = c.submit(inc, w)
     y = c.submit(inc, x)
@@ -884,10 +873,9 @@ async def test_gather_robust_to_nested_missing_data(c, s, a, b):

     for worker in [a, b]:
         for datum in [y, z]:
-            if datum.key in worker.data:
-                del worker.data[datum.key]
-                await asyncio.sleep(0)
-                worker.release_key(datum.key)
+            worker.handle_free_keys(keys=(datum.key,))
+            assert datum.key not in worker.data
+            assert datum.key not in worker.tasks

     result = await c.gather([z])

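All three tests now drop data through the worker's free-keys path rather than mutating Worker.data and calling release_key directly. A rough model of what that path does, using plain dicts in place of the real worker state (not the actual Worker.handle_free_keys implementation):

# Rough model: forget the stored value and the task bookkeeping, leaving
# the scheduler to reschedule any dependents that still need the key.
def handle_free_keys(data: dict, tasks: dict, keys) -> None:
    for key in keys:
        data.pop(key, None)   # the tests assert key not in w.data
        tasks.pop(key, None)  # ...and key not in w.tasks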
42 changes: 42 additions & 0 deletions distributed/tests/test_steal.py
@@ -850,3 +850,45 @@ async def test_blacklist_shuffle_split(c, s, a, b):
         assert "split" not in ts.prefix.name
         await asyncio.sleep(0.001)
     await res
+
+
+@pytest.mark.parametrize("num_workers", [2, 3])
+def test_steal_while_closing(num_workers):
+    @gen_cluster(client=True, nthreads=[("", 1)] * num_workers)
+    async def test(c, s, *workers):
+
+        futures = c.map(
+            slowinc,
+            range(50),
+            delay=0.01,
+            workers=workers[0].address,
+            allow_other_workers=True,
+            key=[f"f-{x:02d}" for x in range(50)],
+        )
+
+        while sum(len(w.tasks) for w in workers[1:]) < 10:
+            await asyncio.sleep(0.01)
+
+        await workers[-1].close()
+
+        await c.gather(futures)
+
+        # The scheduler should only initiate transitions for the closed worker
+        # but nothing else
+
+        ## Ordinary transition if everything works as expected
+        #
+        # released -> waiting
+        # waiting -> processing
+        # processing -> memory
+
+        ## For tasks on the dying worker everything is prepended with
+        ## three more transitions
+        # memory -> released (sometimes)
+        # released -> waiting
+        # waiting -> processing
+        # processing -> released
+        # + standard transitions
+        assert max(len(s.story(f.key)) for f in futures) <= 7
+
+    test()
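
The bound of 7 follows from the comments above: three ordinary transitions per key, plus up to four recovery transitions for keys that lived on the closed worker. A small illustrative helper for inspecting this while debugging; Scheduler.story is the real API, the rest is assumed:

# Map each future's key to the length of its scheduler story, i.e. the
# number of recorded transition events for that key.
def story_lengths(scheduler, futures) -> dict:
    return {f.key: len(scheduler.story(f.key)) for f in futures}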
22 changes: 9 additions & 13 deletions distributed/worker.py
@@ -2653,19 +2653,15 @@ def release_key(
                 self.available_resources[resource] += quantity

         if report:
-            # Inform the scheduler of keys which will have gone missing
-            # We are releasing them before they have completed
-            if ts.state in PROCESSING:
-                # This path is only hit with work stealing
-                msg = {"op": "release", "key": key, "cause": cause}
-            else:
-                # This path is only hit when calling release_key manually
-                msg = {
-                    "op": "release-worker-data",
-                    "keys": [key],
-                    "worker": self.address,
-                }
-            self.batched_stream.send(msg)
+            # TODO: Is this conditional check for task state necessary?
+            if ts.state not in PROCESSING:
+                self.batched_stream.send(
+                    {
+                        "op": "release-worker-data",
+                        "keys": [key],
+                        "worker": self.address,
+                    }
+                )

         self._notify_plugins("release_key", key, ts.state, cause, reason, report)
         del self.tasks[key]
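With the "release" op gone, the only report sent from this path is "release-worker-data", whose scheduler-side handler (still registered in worker_handlers above) drops the worker from the key's replica set. A simplified sketch of that bookkeeping, assuming a plain who_has dict of key -> set of worker addresses rather than the real SchedulerState classes:

# Simplified model, not the actual Scheduler.release_worker_data: discard
# the worker's replica and recommend releasing the task once none remain.
def release_worker_data(who_has: dict, key: str, worker: str):
    replicas = who_has.get(key)
    if not replicas:
        return None
    replicas.discard(worker)
    if not replicas:
        del who_has[key]
        return "released"  # recommendation for the transition engine
    return None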
