Remove release handler from scheduler
fjetter committed Jul 19, 2021
1 parent: 30ad02b · commit: e9e42cb
Showing 3 changed files with 48 additions and 69 deletions.
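In effect the commit collapses two worker-to-scheduler reporting paths into one: Worker.release_key no longer special-cases tasks whose state is in PROCESSING (the work-stealing case) and always sends "release-worker-data". A minimal sketch of the two message shapes involved, both copied from the diff below (the values here are placeholders):

```python
# Placeholder values, for illustration only.
key, cause, address = "f-00", "f-01", "tcp://127.0.0.1:40000"

# Removed path: previously sent while ts.state was in PROCESSING and routed
# to Scheduler.handle_release_data via the "release" op, both deleted below.
old_msg = {"op": "release", "key": key, "cause": cause}

# Remaining path: now sent unconditionally when report=True and routed to
# Scheduler.release_worker_data, which stays registered in worker_handlers.
new_msg = {"op": "release-worker-data", "keys": [key], "worker": address}
```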
57 changes: 0 additions & 57 deletions distributed/scheduler.py
@@ -3570,7 +3570,6 @@ def __init__(
         worker_handlers = {
             "task-finished": self.handle_task_finished,
             "task-erred": self.handle_task_erred,
-            "release": self.handle_release_data,
             "release-worker-data": self.release_worker_data,
             "add-keys": self.add_keys,
             "missing-data": self.handle_missing_data,
@@ -4566,43 +4565,6 @@ def stimulus_task_erred(

         return recommendations, client_msgs, worker_msgs
 
-    def stimulus_missing_data(
-        self, cause=None, key=None, worker=None, ensure=True, **kwargs
-    ):
-        """Mark that certain keys have gone missing. Recover."""
-        parent: SchedulerState = cast(SchedulerState, self)
-        with log_errors():
-            logger.debug("Stimulus missing data %s, %s", key, worker)
-
-            recommendations: dict = {}
-            client_msgs: dict = {}
-            worker_msgs: dict = {}
-
-            ts: TaskState = parent._tasks.get(key)
-            if ts is None or ts._state == "memory":
-                return recommendations, client_msgs, worker_msgs
-            cts: TaskState = parent._tasks.get(cause)
-
-            if cts is not None and cts._state == "memory":  # couldn't find this
-                ws: WorkerState
-                cts_nbytes: Py_ssize_t = cts.get_nbytes()
-                for ws in cts._who_has:  # TODO: this behavior is extreme
-                    del ws._has_what[ts]
-                    ws._nbytes -= cts_nbytes
-                cts._who_has.clear()
-                recommendations[cause] = "released"
-
-            if key:
-                recommendations[key] = "released"
-
-            parent._transitions(recommendations, client_msgs, worker_msgs)
-            recommendations = {}
-
-            if parent._validate:
-                assert cause not in self.who_has
-
-            return recommendations, client_msgs, worker_msgs
-
     def stimulus_retry(self, comm=None, keys=None, client=None):
         parent: SchedulerState = cast(SchedulerState, self)
         logger.info("Client %s requests to retry %d keys", client, len(keys))
@@ -5143,25 +5105,6 @@ def handle_task_erred(self, key=None, **msg):

         self.send_all(client_msgs, worker_msgs)
 
-    def handle_release_data(self, key=None, worker=None, client=None, **msg):
-        parent: SchedulerState = cast(SchedulerState, self)
-        ts: TaskState = parent._tasks.get(key)
-        if ts is None:
-            return
-        ws: WorkerState = parent._workers_dv[worker]
-        if ts._processing_on != ws:
-            return
-
-        recommendations: dict
-        client_msgs: dict
-        worker_msgs: dict
-
-        r: tuple = self.stimulus_missing_data(key=key, ensure=False, **msg)
-        recommendations, client_msgs, worker_msgs = r
-        parent._transitions(recommendations, client_msgs, worker_msgs)
-
-        self.send_all(client_msgs, worker_msgs)
-
     def handle_missing_data(self, key=None, errant_worker=None, **kwargs):
         parent: SchedulerState = cast(SchedulerState, self)
         logger.debug("handle missing data key=%s worker=%s", key, errant_worker)
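The worker_handlers table edited in the first hunk is what routes these messages. A simplified sketch of that dispatch pattern, assuming only what the hunk shows (the real routing lives in the scheduler's batched-stream handling, not in a standalone function like this):

```python
from typing import Any, Callable, Dict


def dispatch(handlers: Dict[str, Callable[..., Any]], msg: dict) -> None:
    """Route one worker message to its registered handler by its "op" key."""
    msg = dict(msg)      # avoid mutating the caller's message
    op = msg.pop("op")
    # After this commit, a message with op="release" raises KeyError here,
    # since "release" was removed from worker_handlers.
    handlers[op](**msg)  # e.g. release_worker_data(keys=..., worker=...)
```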
42 changes: 42 additions & 0 deletions distributed/tests/test_steal.py
@@ -850,3 +850,45 @@ async def test_blacklist_shuffle_split(c, s, a, b):
assert "split" not in ts.prefix.name
await asyncio.sleep(0.001)
await res


@pytest.mark.parametrize("num_workers", [2, 3])
def test_steal_while_closing(num_workers):
@gen_cluster(client=True, nthreads=[("", 1)] * num_workers)
async def test(c, s, *workers):

futures = c.map(
slowinc,
range(50),
delay=0.01,
workers=workers[0].address,
allow_other_workers=True,
key=[f"f-{x:02d}" for x in range(50)],
)

while sum(len(w.tasks) for w in workers[1:]) < 10:
await asyncio.sleep(0.01)

await workers[-1].close()

await c.gather(futures)

# The scheduler should only initiate transitions for the closed worker
# but nothing else

## Ordinary transition if everything works as expected
#
# released -> waiting
# waiting -> processing
# processing -> memory

## For tasks on the dying worker everything is prepended with
## three more transitions
# memory -> released (sometimes)
# released -> waiting
# waiting -> processing
# processing -> released
# + standard transitions
assert max(len(s.story(f.key)) for f in futures) <= 7

test()
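The `<= 7` bound in the assertion above is just the sum of the two transition chains spelled out in the comments, assuming one story entry per transition of the key:

```python
# Ordinary life cycle of an undisturbed task: 3 transitions.
ordinary = ["released->waiting", "waiting->processing", "processing->memory"]

# Extra transitions for a task that lived on the closing worker: at most 4,
# since memory->released only happens if the result was already held there.
extra = [
    "memory->released",      # sometimes
    "released->waiting",
    "waiting->processing",
    "processing->released",  # moved off the closing worker
]

# The extra chain precedes the ordinary one, so the longest allowed story is:
assert len(extra) + len(ordinary) == 7
```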
18 changes: 6 additions & 12 deletions distributed/worker.py
@@ -2653,18 +2653,12 @@ def release_key(
                         self.available_resources[resource] += quantity
 
             if report:
-                # Inform the scheduler of keys which will have gone missing
-                # We are releasing them before they have completed
-                if ts.state in PROCESSING:
-                    # This path is only hit with work stealing
-                    msg = {"op": "release", "key": key, "cause": cause}
-                else:
-                    # This path is only hit when calling release_key manually
-                    msg = {
-                        "op": "release-worker-data",
-                        "keys": [key],
-                        "worker": self.address,
-                    }
+                # This path is only hit when calling release_key manually
+                msg = {
+                    "op": "release-worker-data",
+                    "keys": [key],
+                    "worker": self.address,
+                }
                 self.batched_stream.send(msg)
 
             self._notify_plugins("release_key", key, ts.state, cause, reason, report)
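For context, the surviving scheduler-side counterpart of this message is Scheduler.release_worker_data, still registered in worker_handlers in the first hunk. A hypothetical simplification of its semantics, inferred from the handler's name and the message shape rather than taken from this diff: drop the reporting worker's replica of each key, and release any task left with no replicas.

```python
def release_worker_data_sketch(scheduler, keys, worker):
    """Hypothetical simplification of Scheduler.release_worker_data."""
    ws = scheduler.workers[worker]
    recommendations = {}
    for key in keys:
        ts = scheduler.tasks.get(key)
        if ts is not None and ws in ts.who_has:
            ts.who_has.remove(ws)    # this worker no longer holds the data
            if not ts.who_has:       # no replica left anywhere
                recommendations[key] = "released"
    scheduler.transitions(recommendations)
```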
