Refactor gather_dep (#6388)
crusaderky authored Jun 10, 2022
1 parent 6272e20 commit df1eaba
Showing 5 changed files with 263 additions and 95 deletions.
5 changes: 2 additions & 3 deletions distributed/tests/test_stories.py
@@ -138,8 +138,7 @@ async def test_worker_story_with_deps(c, s, a, b):
# Story now includes randomized stimulus_ids and timestamps.
story = b.story("res")
stimulus_ids = {ev[-2].rsplit("-", 1)[0] for ev in story}
assert stimulus_ids == {"compute-task", "task-finished"}

assert stimulus_ids == {"compute-task", "gather-dep-success", "task-finished"}
# This is a simple transition log
expected = [
("res", "compute-task", "released"),
@@ -153,7 +152,7 @@ async def test_worker_story_with_deps(c, s, a, b):

story = b.story("dep")
stimulus_ids = {ev[-2].rsplit("-", 1)[0] for ev in story}
assert stimulus_ids == {"compute-task"}
assert stimulus_ids == {"compute-task", "gather-dep-success"}
expected = [
("dep", "ensure-task-exists", "released"),
("dep", "released", "fetch", "fetch", {}),
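For context on the assertions above: a worker story row is the transition tuple with the stimulus_id and a timestamp appended, and the stimulus_ids minted by this commit embed a timestamp suffix (f"gather-dep-success-{time()}" in worker.py below), which is why the test strips it with rsplit. A purely illustrative sketch, not a literal row from the test:

# Hypothetical story row; only the shape and the [-2]/rsplit convention matter.
row = ("dep", "flight", "memory", "memory", {}, "gather-dep-success-1654847000.123", 1654847000.456)
assert row[-2].rsplit("-", 1)[0] == "gather-dep-success"
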
1 change: 0 additions & 1 deletion distributed/tests/test_worker.py
@@ -2928,7 +2928,6 @@ async def test_who_has_consistent_remove_replicas(c, s, *workers):
coming_from.handle_stimulus(RemoveReplicasEvent(keys=[f1.key], stimulus_id="test"))
await f2

assert_story(a.story(f1.key), [(f1.key, "missing-dep")])
assert a.tasks[f1.key].suspicious_count == 0
assert s.tasks[f1.key].suspicious == 0

35 changes: 35 additions & 0 deletions distributed/tests/test_worker_state_machine.py
@@ -647,3 +647,38 @@ async def test_fetch_to_missing_on_refresh_who_has(c, s, w1, w2, w3):
assert w3.tasks["x"].state == "missing"
assert w3.tasks["y"].state == "flight"
assert w3.tasks["y"].who_has == {w2.address}


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_fetch_to_missing_on_network_failure(c, s, a):
"""
1. Two tasks, x and y, are respectively in flight and fetch state from the same
worker, which holds the only replica of both.
2. gather_dep for x returns GatherDepNetworkFailureEvent
3. The event empties has_what, x.who_has, and y.who_has.
4. The same event invokes _ensure_communicating, which pops y from data_needed
- but y has an empty who_has, which is an exceptional situation.
_ensure_communicating recommends a transition to missing for y.
5. The fetch->missing transition is executed, but y is no longer in data_needed -
another exceptional situation.
"""
block_get_data = asyncio.Event()

class BlockedBreakingWorker(Worker):
async def get_data(self, comm, *args, **kwargs):
await block_get_data.wait()
raise OSError("fake error")

async with BlockedBreakingWorker(s.address) as b:
x = c.submit(inc, 1, key="x", workers=[b.address])
y = c.submit(inc, 2, key="y", workers=[b.address])
await wait([x, y])
s.request_acquire_replicas(a.address, ["x"], stimulus_id="test_x")
await wait_for_state("x", "flight", a)
s.request_acquire_replicas(a.address, ["y"], stimulus_id="test_y")
await wait_for_state("y", "fetch", a)

block_get_data.set()

await wait_for_state("x", "missing", a)
await wait_for_state("y", "missing", a)
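
A rough sketch of the same scenario driven directly through the worker's stimulus handler instead of a live cluster. Everything in it is assumed for illustration: a worker a whose state machine already has "x" in flight and "y" in fetch from address b_addr, and a booked_nbytes matching what was accounted when the fetch started. The event fields themselves mirror the constructor call in worker.py below.

# Hypothetical sketch, not part of this commit's test suite.
a.handle_stimulus(
    GatherDepNetworkFailureEvent(
        worker=b_addr,                # the peer that just failed
        total_nbytes=booked_nbytes,   # what was added to comm_nbytes for this fetch
        stimulus_id="test-network-failure",
    )
)
# The handler drops b_addr from has_what and from both who_has sets; "x" leaves
# flight via _gather_dep_done_common and "y" is recommended missing by
# _ensure_communicating, so both keys end up in state "missing".
assert a.tasks["x"].state == "missing"
assert a.tasks["y"].state == "missing"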
238 changes: 148 additions & 90 deletions distributed/worker.py
@@ -21,6 +21,7 @@
Collection,
Container,
Iterable,
Iterator,
Mapping,
MutableMapping,
)
@@ -122,7 +123,11 @@
FindMissingEvent,
FreeKeysEvent,
GatherDep,
GatherDepBusyEvent,
GatherDepDoneEvent,
GatherDepFailureEvent,
GatherDepNetworkFailureEvent,
GatherDepSuccessEvent,
Instructions,
InvalidTaskState,
InvalidTransition,
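
The four new imports are the outcome-specific events that replace the single catch-all GatherDepDoneEvent previously returned by gather_dep; each gets its own handler further down in this diff. A minimal sketch of how they relate, assuming they are importable from distributed.worker_state_machine as in this release, with field names taken from the constructor calls below:

# Sketch only; the authoritative definitions live in worker_state_machine.py.
from time import time
from distributed.worker_state_machine import GatherDepDoneEvent, GatherDepSuccessEvent

ev = GatherDepSuccessEvent(
    worker="tcp://127.0.0.1:1234",
    total_nbytes=512,
    data={"x": 2},
    stimulus_id=f"gather-dep-success-{time()}",
)
# All four outcome events specialize GatherDepDoneEvent (see _gather_dep_done_common).
assert isinstance(ev, GatherDepDoneEvent)
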
@@ -2185,13 +2190,7 @@ def transition_fetch_flight(
def transition_fetch_missing(
self, ts: TaskState, *, stimulus_id: str
) -> RecsInstrs:
# There's a use case where ts won't be found in self.data_needed, so
# `self.data_needed.remove(ts)` would crash:
# 1. An event handler empties who_has and pushes a recommendation to missing
# 2. The same event handler calls _ensure_communicating, which pops the task
# from data_needed
# 3. The recommendation is enacted
# See matching code in _ensure_communicating.
# _ensure_communicating could have just popped this task out of data_needed
self.data_needed.discard(ts)
return self.transition_generic_missing(ts, stimulus_id=stimulus_id)

@@ -3017,11 +3016,7 @@ def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs:
assert self.address not in ts.who_has

if not ts.who_has:
# An event handler just emptied who_has and recommended a fetch->missing
# transition. Then, the same handler called _ensure_communicating. The
# transition hasn't been enacted yet, so the task is still in fetch
# state and in data_needed.
# See matching code in transition_fetch_missing.
recommendations[ts] = "missing"
continue

workers = [
@@ -3293,13 +3288,6 @@ async def gather_dep(
if self.status not in WORKER_ANY_RUNNING:
return None

recommendations: Recs = {}
instructions: Instructions = []
response = {}

def done_event():
return GatherDepDoneEvent(stimulus_id=f"gather-dep-done-{time()}")

try:
self.log.append(("request-dep", worker, to_gather, stimulus_id, time()))
logger.debug("Request %d keys from %s", len(to_gather), worker)
@@ -3310,8 +3298,14 @@ def done_event():
)
stop = time()
if response["status"] == "busy":
return done_event()
self.log.append(("busy-gather", worker, to_gather, stimulus_id, time()))
return GatherDepBusyEvent(
worker=worker,
total_nbytes=total_nbytes,
stimulus_id=f"gather-dep-busy-{time()}",
)

assert response["status"] == "OK"
cause = self._get_cause(to_gather)
self._update_metrics_received_data(
start=start,
@@ -3323,86 +3317,156 @@ def done_event():
self.log.append(
("receive-dep", worker, set(response["data"]), stimulus_id, time())
)
return done_event()
return GatherDepSuccessEvent(
worker=worker,
total_nbytes=total_nbytes,
data=response["data"],
stimulus_id=f"gather-dep-success-{time()}",
)

except OSError:
logger.exception("Worker stream died during communication: %s", worker)
has_what = self.has_what.pop(worker)
self.data_needed_per_worker.pop(worker)
self.log.append(
("receive-dep-failed", worker, has_what, stimulus_id, time())
("receive-dep-failed", worker, to_gather, stimulus_id, time())
)
return GatherDepNetworkFailureEvent(
worker=worker,
total_nbytes=total_nbytes,
stimulus_id=f"gather-dep-network-failure-{time()}",
)
for d in has_what:
ts = self.tasks[d]
ts.who_has.remove(worker)
if not ts.who_has and ts.state in (
"fetch",
"flight",
"resumed",
"cancelled",
):
recommendations[ts] = "missing"
self.log.append(
("missing-who-has", worker, ts.key, stimulus_id, time())
)
return done_event()

except Exception as e:
# e.g. data failed to deserialize
logger.exception(e)
if self.batched_stream and LOG_PDB:
import pdb

pdb.set_trace()
msg = error_message(e)
for k in self.in_flight_workers[worker]:
ts = self.tasks[k]
recommendations[ts] = tuple(msg.values())
return done_event()

finally:
self.comm_nbytes -= total_nbytes
busy = response.get("status", "") == "busy"
data = response.get("data", {})
return GatherDepFailureEvent.from_exception(
e,
worker=worker,
total_nbytes=total_nbytes,
stimulus_id=f"gather-dep-failure-{time()}",
)

if busy:
self.log.append(("busy-gather", worker, to_gather, stimulus_id, time()))
# Avoid hammering the worker. If there are multiple replicas
# available, immediately try fetching from a different worker.
self.busy_workers.add(worker)
instructions.append(
RetryBusyWorkerLater(worker=worker, stimulus_id=stimulus_id)
)
refresh_who_has = []

for d in self.in_flight_workers.pop(worker):
ts = self.tasks[d]
ts.done = True
if d in data:
recommendations[ts] = ("memory", data[d])
elif busy:
recommendations[ts] = "fetch"
if not ts.who_has - self.busy_workers:
refresh_who_has.append(d)
elif ts not in recommendations:
ts.who_has.discard(worker)
self.has_what[worker].discard(ts.key)
self.data_needed_per_worker[worker].discard(ts)
self.log.append((d, "missing-dep", stimulus_id, time()))
recommendations[ts] = "fetch"

if refresh_who_has:
# All workers that hold known replicas of our tasks are busy.
# Try querying the scheduler for unknown ones.
instructions.append(
RequestRefreshWhoHasMsg(
keys=refresh_who_has,
stimulus_id=f"gather-dep-busy-{time()}",
)
)
self.transitions(recommendations, stimulus_id=stimulus_id)
self._handle_instructions(instructions)

def _gather_dep_done_common(self, ev: GatherDepDoneEvent) -> Iterator[TaskState]:
"""Common code for all subclasses of GatherDepDoneEvent.

Yields the tasks that need to transition out of flight.
"""
self.comm_nbytes -= ev.total_nbytes
keys = self.in_flight_workers.pop(ev.worker)
for key in keys:
ts = self.tasks[key]
ts.done = True
yield ts

@_handle_event.register
def _handle_gather_dep_success(self, ev: GatherDepSuccessEvent) -> RecsInstrs:
"""gather_dep terminated successfully.
The response may contain less keys than the request.
"""
recommendations: Recs = {}
for ts in self._gather_dep_done_common(ev):
if ts.key in ev.data:
recommendations[ts] = ("memory", ev.data[ts.key])
else:
self.log.append((ts.key, "missing-dep", ev.stimulus_id, time()))
if self.validate:
assert ts.state != "fetch"
assert ts not in self.data_needed_per_worker[ev.worker]
ts.who_has.discard(ev.worker)
self.has_what[ev.worker].discard(ts.key)
recommendations[ts] = "fetch"

return merge_recs_instructions(
(recommendations, []),
self._ensure_communicating(stimulus_id=ev.stimulus_id),
)

@_handle_event.register
def _handle_gather_dep_busy(self, ev: GatherDepBusyEvent) -> RecsInstrs:
"""gather_dep terminated: remote worker is busy"""
# Avoid hammering the worker. If there are multiple replicas
# available, immediately try fetching from a different worker.
self.busy_workers.add(ev.worker)

recommendations: Recs = {}
refresh_who_has = []
for ts in self._gather_dep_done_common(ev):
recommendations[ts] = "fetch"
if not ts.who_has - self.busy_workers:
refresh_who_has.append(ts.key)

instructions: Instructions = [
RetryBusyWorkerLater(worker=ev.worker, stimulus_id=ev.stimulus_id),
]

if refresh_who_has:
# All workers that hold known replicas of our tasks are busy.
# Try querying the scheduler for unknown ones.
instructions.append(
RequestRefreshWhoHasMsg(
keys=refresh_who_has, stimulus_id=ev.stimulus_id
)
)

return merge_recs_instructions(
(recommendations, instructions),
self._ensure_communicating(stimulus_id=ev.stimulus_id),
)

@_handle_event.register
def _handle_gather_dep_network_failure(
self, ev: GatherDepNetworkFailureEvent
) -> RecsInstrs:
"""gather_dep terminated: network failure while trying to
communicate with remote worker
Though the network failure could be transient, we assume it is not, and
preemptively act as though the other worker has died (including removing all
keys from it, even ones we did not fetch).
This optimization leads to faster completion of the fetch, since we immediately
either retry a different worker, or ask the scheduler to inform us of a new
worker if no other worker is available.
"""
self.data_needed_per_worker.pop(ev.worker)
for key in self.has_what.pop(ev.worker):
ts = self.tasks[key]
ts.who_has.discard(ev.worker)

recommendations: Recs = {}
for ts in self._gather_dep_done_common(ev):
self.log.append((ts.key, "missing-dep", ev.stimulus_id, time()))
recommendations[ts] = "fetch"

return merge_recs_instructions(
(recommendations, []),
self._ensure_communicating(stimulus_id=ev.stimulus_id),
)

@_handle_event.register
def _handle_gather_dep_failure(self, ev: GatherDepFailureEvent) -> RecsInstrs:
"""gather_dep terminated: generic error raised (not a network failure);
e.g. data failed to deserialize.
"""
recommendations: Recs = {
ts: (
"error",
ev.exception,
ev.traceback,
ev.exception_text,
ev.traceback_text,
)
for ts in self._gather_dep_done_common(ev)
}

return merge_recs_instructions(
(recommendations, []),
self._ensure_communicating(stimulus_id=ev.stimulus_id),
)
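
Taken together, gather_dep is now a pure I/O coroutine: it performs the fetch and reports one of the four events, and all state mutation happens in the handlers above. A hand-wavy sketch of that round trip; the driver function and its names are illustrative (the real plumbing is the worker's instruction-handling loop), while the gather_dep arguments and the handle_stimulus call match what this diff uses:

# Illustrative driver, not code from this commit.
async def run_gather_dep(worker, *, peer, to_gather, total_nbytes, stimulus_id):
    ev = await worker.gather_dep(peer, to_gather, total_nbytes, stimulus_id=stimulus_id)
    if ev is not None:  # gather_dep returns None when the worker is no longer running
        # singledispatch on the event type picks one of
        # _handle_gather_dep_{success,busy,network_failure,failure} above
        worker.handle_stimulus(ev)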

async def retry_busy_worker_later(self, worker: str) -> StateMachineEvent | None:
await asyncio.sleep(0.15)
@@ -3841,11 +3905,6 @@ def _handle_unpause(self, ev: UnpauseEvent) -> RecsInstrs:
self._ensure_communicating(stimulus_id=ev.stimulus_id),
)

@_handle_event.register
def _handle_gather_dep_done(self, ev: GatherDepDoneEvent) -> RecsInstrs:
"""Temporary hack - to be removed"""
return self._ensure_communicating(stimulus_id=ev.stimulus_id)

@_handle_event.register
def _handle_retry_busy_worker(self, ev: RetryBusyWorkerEvent) -> RecsInstrs:
self.busy_workers.discard(ev.worker)
@@ -4181,8 +4240,7 @@ def validate_task_fetch(self, ts):
assert self.address not in ts.who_has
assert not ts.done
assert ts in self.data_needed
assert ts.who_has

# Note: ts.who_has may be empty; see GatherDepNetworkFailureEvent
for w in ts.who_has:
assert ts.key in self.has_what[w]
assert ts in self.data_needed_per_worker[w]