Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix regression where unknown tasks were allowed to be stolen #5392

Merged
merged 2 commits into from
Oct 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 74 additions & 75 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,10 +399,15 @@ class WorkerState:
.. attribute:: processing: {TaskState: cost}

A dictionary of tasks that have been submitted to this worker.
Each task state is asssociated with the expected cost in seconds
Each task state is associated with the expected cost in seconds
of running that task, summing both the task's expected computation
time and the expected communication time of its result.

If a task is already executing on the worker and the excecution time is
twice the learned average TaskGroup duration, this will be set to twice
the current executing time. If the task is unknown, the default task
duration is used instead of the TaskGroup average.

Multiple tasks may be submitted to a worker in advance and the worker
will run them eventually, depending on its execution resources
(but see :doc:`work-stealing`).
Expand Down Expand Up @@ -900,6 +905,18 @@ def name(self) -> str:
def all_durations(self) -> "defaultdict[str, float]":
return self._all_durations

@ccall
@exceptval(check=False)
def add_duration(self, action: str, start: double, stop: double):
duration = stop - start
self._all_durations[action] += duration
if action == "compute":
old = self._duration_average
if old < 0:
self._duration_average = duration
else:
self._duration_average = 0.5 * duration + 0.5 * old

@property
def duration_average(self) -> double:
return self._duration_average
Expand Down Expand Up @@ -1062,6 +1079,18 @@ def nbytes_total(self):
def duration(self) -> double:
return self._duration

@ccall
@exceptval(check=False)
def add_duration(self, action: str, start: double, stop: double):
duration = stop - start
self._all_durations[action] += duration
if action == "compute":
if self._stop < stop:
self._stop = stop
self._start = self._start or start
self._duration += duration
self._prefix.add_duration(action, start, stop)

@property
def types(self) -> set:
return self._types
Expand Down Expand Up @@ -2568,6 +2597,8 @@ def set_duration_estimate(self, ts: TaskState, ws: WorkerState) -> double:
If a task takes longer than twice the current average duration we
estimate the task duration to be 2x current-runtime, otherwise we set it
to be the average duration.

See also ``_remove_from_processing``
"""
exec_time: double = ws._executing.get(ts, 0)
duration: double = self.get_task_duration(ts)
Expand All @@ -2577,7 +2608,11 @@ def set_duration_estimate(self, ts: TaskState, ws: WorkerState) -> double:
else:
comm: double = self.get_comm_cost(ts, ws)
total_duration = duration + comm
old = ws._processing.get(ts, 0)
ws._processing[ts] = total_duration
self._total_occupancy += total_duration - old
ws._occupancy += total_duration - old

return total_duration

def transition_waiting_processing(self, key):
Expand All @@ -2602,10 +2637,8 @@ def transition_waiting_processing(self, key):
return recommendations, client_msgs, worker_msgs
worker = ws._address

duration_estimate = self.set_duration_estimate(ts, ws)
self.set_duration_estimate(ts, ws)
ts._processing_on = ws
ws._occupancy += duration_estimate
self._total_occupancy += duration_estimate
ts.state = "processing"
self.consume_resources(ts, ws)
self.check_idle_saturated(ws)
Expand Down Expand Up @@ -2684,7 +2717,6 @@ def transition_processing_memory(
worker_msgs: dict = {}
try:
ts: TaskState = self._tasks[key]
tg: TaskGroup = ts._group

assert worker
assert isinstance(worker, str)
Expand Down Expand Up @@ -2719,57 +2751,26 @@ def transition_processing_memory(
}
]

has_compute_startstop: bool = False
compute_start: double
compute_stop: double
#############################
# Update Timing Information #
#############################
if startstops:
startstop: dict
for startstop in startstops:
stop = startstop["stop"]
start = startstop["start"]
action = startstop["action"]
if not has_compute_startstop and action == "compute":
compute_start = start
compute_stop = stop
has_compute_startstop = True

# record timings of all actions -- a cheaper way of
# getting timing info compared with get_task_stream()
ts._prefix._all_durations[action] += stop - start
tg._all_durations[action] += stop - start
ts._group.add_duration(
stop=startstop["stop"],
start=startstop["start"],
action=startstop["action"],
)

#############################
# Update Timing Information #
#############################
if has_compute_startstop and ws._processing.get(ts, True):
# Update average task duration for worker
old_duration: double = ts._prefix._duration_average
new_duration: double = compute_stop - compute_start
avg_duration: double
if old_duration < 0:
avg_duration = new_duration
else:
avg_duration = 0.5 * old_duration + 0.5 * new_duration

ts._prefix._duration_average = avg_duration
tg._duration += new_duration
tg._start = tg._start or compute_start
if tg._stop < compute_stop:
tg._stop = compute_stop

s: set = self._unknown_durations.pop(ts._prefix._name, None)
tts: TaskState
if s:
for tts in s:
if tts._processing_on is not None:
wws = tts._processing_on
comm: double = self.get_comm_cost(tts, wws)
old: double = wws._processing[tts]
new: double = avg_duration + comm
diff: double = new - old
wws._processing[tts] = new
wws._occupancy += diff
self._total_occupancy += diff
Comment on lines -2762 to -2772
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This regression was partially introduced since the occupancy update here would not put a previously unknown task into the stealing whitelist. I didn't wan tto add on top of this function and decided to go for a refactoring instead

s: set = self._unknown_durations.pop(ts._prefix._name, set())
tts: TaskState
for tts in s:
if tts._processing_on:
self.set_duration_estimate(tts, tts._processing_on)
steal = self.extensions.get("stealing")
if steal:
steal.put_key_in_stealable(tts)

############################
# Update State Information #
Expand Down Expand Up @@ -3314,10 +3315,14 @@ def get_comm_cost(self, ts: TaskState, ws: WorkerState) -> double:
return nbytes / self._bandwidth

@ccall
def get_task_duration(self, ts: TaskState, default: double = -1) -> double:
"""
Get the estimated computation cost of the given task
(not including any communication cost).
def get_task_duration(self, ts: TaskState) -> double:
"""Get the estimated computation cost of the given task (not including
any communication cost).

If no data has been observed, value of
`distributed.scheduler.default-task-durations` are used. If none is set
for this task, `distributed.scheduler.unknown-task-duration` is used
instead.
"""
duration: double = ts._prefix._duration_average
if duration >= 0:
Expand All @@ -3327,7 +3332,7 @@ def get_task_duration(self, ts: TaskState, default: double = -1) -> double:
if s is None:
self._unknown_durations[ts._prefix._name] = s = set()
s.add(ts)
return default if default >= 0 else self.UNKNOWN_TASK_DURATION
return self.UNKNOWN_TASK_DURATION

@ccall
@exceptval(check=False)
Expand Down Expand Up @@ -7586,7 +7591,6 @@ def reevaluate_occupancy(self, worker_index: Py_ssize_t = 0):
try:
if self.status == Status.closed:
return

last = time()
next_time = timedelta(seconds=0.1)

Expand Down Expand Up @@ -7713,6 +7717,8 @@ def _remove_from_processing(
) -> str: # -> str | None
"""
Remove *ts* from the set of processing tasks.

See also ``Scheduler.set_duration_estimate``
"""
ws: WorkerState = ts._processing_on
ts._processing_on = None # type: ignore
Expand Down Expand Up @@ -7870,6 +7876,7 @@ def _task_to_msg(state: SchedulerState, ts: TaskState, duration: double = -1) ->
ws: WorkerState
dts: TaskState

# FIXME: The duration attribute is not used on worker. We could safe ourselves the time to compute and submit this
if duration < 0:
duration = state.get_task_duration(ts)

Expand Down Expand Up @@ -7942,27 +7949,19 @@ def _task_to_client_msgs(state: SchedulerState, ts: TaskState) -> dict:
@exceptval(check=False)
def _reevaluate_occupancy_worker(state: SchedulerState, ws: WorkerState):
"""See reevaluate_occupancy"""
old: double = ws._occupancy
new: double = 0
diff: double
ts: TaskState
est: double
old = ws._occupancy
for ts in ws._processing:
est = state.set_duration_estimate(ts, ws)
new += est
state.set_duration_estimate(ts, ws)

ws._occupancy = new
diff = new - old
state._total_occupancy += diff
state.check_idle_saturated(ws)

# significant increase in duration
if new > old * 1.3:
steal = state._extensions.get("stealing")
if steal is not None:
for ts in ws._processing:
steal.remove_key_from_stealable(ts)
steal.put_key_in_stealable(ts)
steal = state.extensions.get("stealing")
if not steal:
return
if ws._occupancy > old * 1.3 or old > ws._occupancy * 1.3:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are also cases where the occupancy actually significantly reduced but we might still want to reapply stealing. Until #5379 is fixed, this might increase the likelihood of that deadlock but we should allow deviations in both directions

for ts in ws._processing:
steal.remove_key_from_stealable(ts)
steal.put_key_in_stealable(ts)


@cfunc
Expand Down
8 changes: 4 additions & 4 deletions distributed/stealing.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,13 @@ def steal_time_ratio(self, ts):
For example a result of zero implies a task without dependencies.
level: The location within a stealable list to place this value
"""
if not ts.dependencies: # no dependencies fast path
return 0, 0

split = ts.prefix.name
if split in fast_tasks:
if split in fast_tasks or split in self.scheduler.unknown_durations:
return None, None

if not ts.dependencies: # no dependencies fast path
return 0, 0

ws = ts.processing_on
compute_time = ws.processing[ts]
if compute_time < 0.005: # 5ms, just give up
Expand Down
20 changes: 16 additions & 4 deletions distributed/tests/test_steal.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,20 +110,24 @@ async def test_worksteal_many_thieves(c, s, *workers):
assert sum(map(len, s.has_what.values())) < 150


@pytest.mark.flaky(reruns=10, reruns_delay=5, reason="GH#3574")
@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2)
async def test_dont_steal_unknown_functions(c, s, a, b):
futures = c.map(inc, range(100), workers=a.address, allow_other_workers=True)
await wait(futures)
assert len(a.data) >= 95, [len(a.data), len(b.data)]


@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2)
@gen_cluster(
client=True,
nthreads=[("127.0.0.1", 1)] * 2,
config={"distributed.scheduler.work-stealing-interval": "10ms"},
)
async def test_eventually_steal_unknown_functions(c, s, a, b):
futures = c.map(
slowinc, range(10), delay=0.1, workers=a.address, allow_other_workers=True
)
await wait(futures)
assert not s.unknown_durations
assert len(a.data) >= 3, [len(a.data), len(b.data)]
assert len(b.data) >= 3, [len(a.data), len(b.data)]

Expand Down Expand Up @@ -597,11 +601,15 @@ async def test(*args, **kwargs):
@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2, Worker=Nanny, timeout=60)
async def test_restart(c, s, a, b):
futures = c.map(
slowinc, range(100), delay=0.1, workers=a.address, allow_other_workers=True
slowinc, range(100), delay=0.01, workers=a.address, allow_other_workers=True
)
while not s.processing[b.worker_address]:
await asyncio.sleep(0.01)

# Unknown tasks are never stolen therefore wait for a measurement
while not any(s.tasks[f.key].state == "memory" for f in futures):
await asyncio.sleep(0.01)

steal = s.extensions["stealing"]
assert any(st for st in steal.stealable_all)
assert any(x for L in steal.stealable.values() for x in L)
Expand Down Expand Up @@ -821,9 +829,13 @@ async def test_balance_with_longer_task(c, s, a, b):
slowinc, 1, delay=5, workers=[a.address], priority=1
) # a surprisingly long task
z = c.submit(
inc, x, workers=[a.address], allow_other_workers=True, priority=0
slowadd, x, 1, workers=[a.address], allow_other_workers=True, priority=0
) # a task after y, suggesting a, but open to b

# Allow task to be learned, otherwise it will not be stolen
_ = c.submit(slowadd, x, 2, workers=[b.address])
await z
assert not y.done()
assert z.key in b.data


Expand Down