Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

No longer double count transfer cost in stealing #7026

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 77 additions & 138 deletions distributed/stealing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,19 @@
from time import time
from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast

import sortedcontainers
from tlz import topk
from tornado.ioloop import PeriodicCallback

import dask
from dask.utils import parse_timedelta

from distributed.comm.addressing import get_address_host
from distributed.core import CommClosedError, Status
from distributed.core import CommClosedError
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.utils import log_errors, recursive_to_dict

if TYPE_CHECKING:
# Recursive imports
from distributed.scheduler import Scheduler, TaskState, WorkerState
from distributed.scheduler import Scheduler, SchedulerState, TaskState, WorkerState

# Stealing requires multiple network bounces and if successful also task
# submission which may include code serialization. Therefore, be very
Expand Down Expand Up @@ -234,14 +232,13 @@ def steal_time_ratio(self, ts: TaskState) -> tuple[float, int] | tuple[None, Non
if not ts.dependencies: # no dependencies fast path
return 0, 0

assert ts.processing_on
ws = ts.processing_on
compute_time = ws.processing[ts]
compute_time = self.scheduler.get_task_duration(ts)

if not compute_time:
# occupancy/ws.proccessing[ts] is only allowed to be zero for
# long running tasks which cannot be stolen
assert ts in ws.long_running
assert ts.processing_on
assert ts in ts.processing_on.long_running
return None, None

nbytes = ts.get_nbytes_deps()
Expand Down Expand Up @@ -403,116 +400,90 @@ def balance(self) -> None:
def combined_occupancy(ws: WorkerState) -> float:
return ws.occupancy + self.in_flight_occupancy[ws]

def maybe_move_task(
level: int,
ts: TaskState,
victim: WorkerState,
thief: WorkerState,
duration: float,
cost_multiplier: float,
) -> None:
occ_thief = combined_occupancy(thief)
occ_victim = combined_occupancy(victim)

if occ_thief + cost_multiplier * duration <= occ_victim - duration / 2:
self.move_task_request(ts, victim, thief)
log.append(
(
start,
level,
ts.key,
duration,
victim.address,
occ_victim,
thief.address,
occ_thief,
)
)
s.check_idle_saturated(victim, occ=occ_victim)
s.check_idle_saturated(thief, occ=occ_thief)

with log_errors():
i = 0
# Paused and closing workers must never become thieves
idle = [ws for ws in s.idle.values() if ws.status == Status.running]
if not idle or len(idle) == len(s.workers):
potential_thieves = set(s.idle.values())
if not potential_thieves or len(potential_thieves) == len(s.workers):
return

victim: WorkerState | None
saturated: set[WorkerState] | list[WorkerState] = s.saturated
if not saturated:
saturated = topk(10, s.workers.values(), key=combined_occupancy)
saturated = [
potential_victims: set[WorkerState] | list[WorkerState] = s.saturated
if not potential_victims:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From what I understand, len(potential_thievens) == len(s.workers) and if not potential_victims should check the same thing, i.e., if everybody is idle then nobody is saturated. I suggest dropping the former check for simplicity.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if this comment is on the correct line but potential_thieves is basically a view on Scheduler.idle which is not the same as s.workers it's a filtered subset, see

idle = self.idle
saturated = self.saturated
if (
(p < nc or occ < nc * avg / 2)
if math.isinf(self.WORKER_SATURATION)
else not _worker_full(ws, self.WORKER_SATURATION)
):
if ws.status == Status.running:
idle[ws.address] = ws
saturated.discard(ws)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Never mind, I misread 412 as a stop condition.

potential_victims = topk(10, s.workers.values(), key=combined_occupancy)
potential_victims = [
ws
for ws in saturated
if combined_occupancy(ws) > 0.2 and len(ws.processing) > ws.nthreads
for ws in potential_victims
if combined_occupancy(ws) > 0.2
and len(ws.processing) > ws.nthreads
and ws not in potential_thieves
]
elif len(saturated) < 20:
saturated = sorted(saturated, key=combined_occupancy, reverse=True)
if len(idle) < 20:
idle = sorted(idle, key=combined_occupancy)

for level, cost_multiplier in enumerate(self.cost_multipliers):
if not idle:
if len(potential_victims) < 20:
potential_victims = sorted(
potential_victims, key=combined_occupancy, reverse=True
)
avg_occ_per_threads = (
self.scheduler.total_occupancy / self.scheduler.total_nthreads
)
for level, _ in enumerate(self.cost_multipliers):
if not potential_thieves:
break
for victim in list(saturated):
for victim in list(potential_victims):

stealable = self.stealable[victim.address][level]
if not stealable or not idle:
if not stealable or not potential_thieves:
continue

for ts in list(stealable):
if not potential_thieves:
break
if (
ts not in self.key_stealable
or ts.processing_on is not victim
):
stealable.discard(ts)
continue
i += 1
if not idle:
break

thieves = _potential_thieves_for(ts, idle)
if not thieves:
break
thief = thieves[i % len(thieves)]

duration = victim.processing.get(ts)
if duration is None:
stealable.discard(ts)
if not (thief := _pop_thief(s, ts, potential_thieves)):
continue

maybe_move_task(
level, ts, victim, thief, duration, cost_multiplier
)

if self.cost_multipliers[level] < 20: # don't steal from public at cost
stealable = self.stealable_all[level]
for ts in list(stealable):
if not idle:
break
if ts not in self.key_stealable:
task_occ_on_victim = victim.processing.get(ts)
if task_occ_on_victim is None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should need to add the thief back in

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be easier to use something like _pick_thief again and only worry about removing the thief at the correct points as opposed to first removing it and then worrying about adding it back in.

stealable.discard(ts)
continue

victim = ts.processing_on
if victim is None:
stealable.discard(ts)
continue
if combined_occupancy(victim) < 0.2:
continue
if len(victim.processing) <= victim.nthreads:
continue
occ_thief = combined_occupancy(thief)
occ_victim = combined_occupancy(victim)
comm_cost = self.scheduler.get_comm_cost(ts, thief)
compute = self.scheduler.get_task_duration(ts)

i += 1
thieves = _potential_thieves_for(ts, idle)
if not thieves:
continue
thief = thieves[i % len(thieves)]
duration = victim.processing[ts]

maybe_move_task(
level, ts, victim, thief, duration, cost_multiplier
)
if (
occ_thief + comm_cost + compute
<= occ_victim - task_occ_on_victim / 2
):
self.move_task_request(ts, victim, thief)
log.append(
(
start,
level,
ts.key,
task_occ_on_victim,
victim.address,
occ_victim,
thief.address,
occ_thief,
)
)

occ_thief = combined_occupancy(thief)
# TODO: this is replicating some logic of
# check_idle_saturated
if occ_thief >= thief.nthreads * avg_occ_per_threads / 2:
potential_thieves.add(thief)
else:
potential_thieves.add(thief)
self.scheduler.check_idle_saturated(
victim, occ=combined_occupancy(victim)
)

if log:
self.log(log)
Expand Down Expand Up @@ -542,51 +513,19 @@ def story(self, *keys_or_ts: str | TaskState) -> list:
return out


def _potential_thieves_for(
ts: TaskState,
idle: sortedcontainers.SortedValuesView[WorkerState] | list[WorkerState],
) -> sortedcontainers.SortedValuesView[WorkerState] | list[WorkerState]:
"""Return the list of workers from ``idle`` that could steal ``ts``."""
if _has_restrictions(ts):
return [ws for ws in idle if _can_steal(ws, ts)]
else:
return idle


def _can_steal(thief: WorkerState, ts: TaskState) -> bool:
"""Determine whether worker ``thief`` can steal task ``ts``.

Assumes that `ts` has some restrictions.
"""
if (
ts.host_restrictions
and get_address_host(thief.address) not in ts.host_restrictions
):
return False
elif ts.worker_restrictions and thief.address not in ts.worker_restrictions:
return False

if not ts.resource_restrictions:
return True

for resource, value in ts.resource_restrictions.items():
try:
supplied = thief.resources[resource]
except KeyError:
return False
else:
if supplied < value:
return False
return True


def _has_restrictions(ts: TaskState) -> bool:
"""Determine whether the given task has restrictions and whether these
restrictions are strict.
"""
return not ts.loose_restrictions and bool(
ts.host_restrictions or ts.worker_restrictions or ts.resource_restrictions
)
Comment on lines -545 to -589
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if I'm missing something, but have we completely dropped checking for restrictions?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this was duplicated code. I'm reusing Scheduler.valid_workers

def _pop_thief(
scheduler: SchedulerState, ts: TaskState, potential_thieves: set[WorkerState]
) -> WorkerState | None:
valid_workers = scheduler.valid_workers(ts)
if valid_workers:
subset = potential_thieves & valid_workers
if subset:
thief = subset.pop()
potential_thieves.discard(thief)
return thief
elif not ts.loose_restrictions:
return None
return potential_thieves.pop()


fast_tasks = {"split-shuffle"}
Loading