Skip to content

Commit

Permalink
Remove stealable_all
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed Sep 12, 2022
1 parent a69b602 commit c11b38c
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 47 deletions.
53 changes: 31 additions & 22 deletions distributed/stealing.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ class InFlightInfo(TypedDict):

class WorkStealing(SchedulerPlugin):
scheduler: Scheduler
# ({ task states for level 0}, ..., {task states for level 14})
stealable_all: tuple[set[TaskState], ...]
# {worker: ({ task states for level 0}, ..., {task states for level 14})}
stealable: dict[str, tuple[set[TaskState], ...]]
# { task state: (worker, level) }
Expand All @@ -78,12 +76,12 @@ class WorkStealing(SchedulerPlugin):
in_flight: dict[TaskState, InFlightInfo]
# { worker state: occupancy }
in_flight_occupancy: defaultdict[WorkerState, float]
in_flight_tasks: defaultdict[WorkerState, int]
_in_flight_event: asyncio.Event
_request_counter: int

def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.stealable_all = tuple(set() for _ in range(15))
self.stealable = {}
self.key_stealable = {}

Expand All @@ -103,6 +101,7 @@ def __init__(self, scheduler: Scheduler):
self.count = 0
self.in_flight = {}
self.in_flight_occupancy = defaultdict(lambda: 0)
self.in_flight_tasks = defaultdict(lambda: 0)
self._in_flight_event = asyncio.Event()
self._request_counter = 0
self.scheduler.stream_handlers["steal-response"] = self.move_task_confirm
Expand Down Expand Up @@ -181,6 +180,8 @@ def transition(
victim = d["victim"]
self.in_flight_occupancy[thief] -= d["thief_duration"]
self.in_flight_occupancy[victim] += d["victim_duration"]
self.in_flight_tasks[victim] += 1
self.in_flight_tasks[thief] -= 1
if not self.in_flight:
self.in_flight_occupancy.clear()
self._in_flight_event.set()
Expand All @@ -197,7 +198,6 @@ def put_key_in_stealable(self, ts: TaskState) -> None:
assert ts.processing_on
ws = ts.processing_on
worker = ws.address
self.stealable_all[level].add(ts)
self.stealable[worker][level].add(ts)
self.key_stealable[ts] = (worker, level)

Expand All @@ -211,10 +211,6 @@ def remove_key_from_stealable(self, ts: TaskState) -> None:
self.stealable[worker][level].remove(ts)
except KeyError:
pass
try:
self.stealable_all[level].remove(ts)
except KeyError:
pass

def steal_time_ratio(self, ts: TaskState) -> tuple[float, int] | tuple[None, None]:
"""The compute to communication time ratio of a key
Expand Down Expand Up @@ -295,6 +291,8 @@ def move_task_request(

self.in_flight_occupancy[victim] -= victim_duration
self.in_flight_occupancy[thief] += thief_duration
self.in_flight_tasks[victim] -= 1
self.in_flight_tasks[thief] += 1
return stimulus_id
except CommClosedError:
logger.info("Worker comm %r closed while stealing: %r", victim, ts)
Expand Down Expand Up @@ -400,13 +398,15 @@ def balance(self) -> None:
def combined_occupancy(ws: WorkerState) -> float:
return ws.occupancy + self.in_flight_occupancy[ws]

def combined_nprocessing(ws: WorkerState) -> float:
return ws.occupancy + self.in_flight_tasks[ws]

with log_errors():
i = 0
# Paused and closing workers must never become thieves
potential_thieves = set(s.idle.values())
if not potential_thieves or len(potential_thieves) == len(s.workers):
return

victim: WorkerState | None
potential_victims: set[WorkerState] | list[WorkerState] = s.saturated
if not potential_victims:
Expand All @@ -415,13 +415,23 @@ def combined_occupancy(ws: WorkerState) -> float:
ws
for ws in potential_victims
if combined_occupancy(ws) > 0.2
and len(ws.processing) > ws.nthreads
and combined_nprocessing(ws) > ws.nthreads
and ws not in potential_thieves
]
if not potential_victims:
# TODO: Unclear how to reach this and what the implications
# are. The return is only an optimization since the for-loop
# below would be a no op but we'd safe ourselves a few loop
# cycles. Unless any measurements about runtime, occupancy,
# etc. changes we'd not get out of this and may have an
# unbalanced cluster
return
if len(potential_victims) < 20:
potential_victims = sorted(
potential_victims, key=combined_occupancy, reverse=True
)
assert potential_victims
assert potential_thieves
avg_occ_per_threads = (
self.scheduler.total_occupancy / self.scheduler.total_nthreads
)
Expand All @@ -444,7 +454,7 @@ def combined_occupancy(ws: WorkerState) -> float:
stealable.discard(ts)
continue
i += 1
if not (thief := _pop_thief(s, ts, potential_thieves)):
if not (thief := _get_thief(s, ts, potential_thieves)):
continue
task_occ_on_victim = victim.processing.get(ts)
if task_occ_on_victim is None:
Expand Down Expand Up @@ -475,12 +485,15 @@ def combined_occupancy(ws: WorkerState) -> float:
)

occ_thief = combined_occupancy(thief)
p = len(thief.processing) + self.in_flight_tasks[thief]

nc = thief.nthreads
# TODO: this is replicating some logic of
# check_idle_saturated
if occ_thief >= thief.nthreads * avg_occ_per_threads / 2:
potential_thieves.add(thief)
else:
potential_thieves.add(thief)
# pending: float = occ_thief * (p - nc) / (p * nc)
if not (p < nc or occ_thief < nc * avg_occ_per_threads / 2):
potential_thieves.discard(thief)
stealable.discard(ts)
self.scheduler.check_idle_saturated(
victim, occ=combined_occupancy(victim)
)
Expand All @@ -497,8 +510,6 @@ def restart(self, scheduler: Any) -> None:
for s in stealable:
s.clear()

for s in self.stealable_all:
s.clear()
self.key_stealable.clear()

def story(self, *keys_or_ts: str | TaskState) -> list:
Expand All @@ -513,19 +524,17 @@ def story(self, *keys_or_ts: str | TaskState) -> list:
return out


def _pop_thief(
def _get_thief(
scheduler: SchedulerState, ts: TaskState, potential_thieves: set[WorkerState]
) -> WorkerState | None:
valid_workers = scheduler.valid_workers(ts)
if valid_workers:
subset = potential_thieves & valid_workers
if subset:
thief = subset.pop()
potential_thieves.discard(thief)
return thief
return next(iter(subset))
elif not ts.loose_restrictions:
return None
return potential_thieves.pop()
return next(iter(potential_thieves))


fast_tasks = {"split-shuffle"}
35 changes: 18 additions & 17 deletions distributed/tests/test_steal.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,25 +850,26 @@ async def test_steal_twice(c, s, a, b):
await asyncio.sleep(0.01)

# Army of new workers arrives to help
workers = await asyncio.gather(*(Worker(s.address) for _ in range(20)))
async with contextlib.AsyncExitStack() as stack:
# This is pretty timing sensitive
workers = [stack.enter_async_context(Worker(s.address)) for _ in range(10)]
workers = await asyncio.gather(*workers)

await wait(futures)
await wait(futures)

# Note: this includes a and b
empty_workers = [ws for ws in s.workers.values() if not ws.has_what]
assert (
len(empty_workers) < 3
), f"Too many workers without keys ({len(empty_workers)} out of {len(s.workers)})"
# This also tests that some tasks were stolen from b
# (see `while len(b.state.tasks) < 30` above)
# If queuing is enabled, then there was nothing to steal from b,
# so this just tests the queue was balanced not-terribly.
assert max(len(ws.has_what) for ws in s.workers.values()) < 30

assert a.state.in_flight_tasks_count == 0
assert b.state.in_flight_tasks_count == 0

await asyncio.gather(*(w.close() for w in workers))
# Note: this includes a and b
empty_workers = [ws for ws in s.workers.values() if not ws.has_what]
assert (
len(empty_workers) < 3
), f"Too many workers without keys ({len(empty_workers)} out of {len(s.workers)})"
# This also tests that some tasks were stolen from b
# (see `while len(b.state.tasks) < 30` above)
# If queuing is enabled, then there was nothing to steal from b,
# so this just tests the queue was balanced not-terribly.
assert max(len(ws.has_what) for ws in s.workers.values()) < 30

assert a.state.in_flight_tasks_count == 0
assert b.state.in_flight_tasks_count == 0


@gen_cluster(
Expand Down
23 changes: 18 additions & 5 deletions distributed/tests/test_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,8 +1029,21 @@ def test_ws_with_running_task(ws_with_running_task):
assert ts.state in ("executing", "long-running")


@pytest.mark.parametrize("nbytes", [0, 1, 1234.567])
def test_sizeof(nbytes):
assert sizeof(SizeOf(nbytes)) == nbytes
assert isinstance(gen_nbytes(nbytes), SizeOf)
assert sizeof(gen_nbytes(nbytes)) == nbytes
def test_sizeof():
assert sizeof(SizeOf(100)) == 100
assert isinstance(gen_nbytes(100), SizeOf)
assert sizeof(gen_nbytes(100)) == 100


@pytest.mark.parametrize(
"input, exc, msg",
[
(12345.0, TypeError, "Expected integer"),
(-1, ValueError, "larger than"),
(0, ValueError, "larger than"),
(10, ValueError, "larger than"),
],
)
def test_sizeof_error(input, exc, msg):
with pytest.raises(exc, match=msg):
SizeOf(input)
13 changes: 10 additions & 3 deletions distributed/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2466,13 +2466,20 @@ class SizeOf:
An object that returns exactly nbytes when inspected by dask.sizeof.sizeof
"""

def __init__(self, nbytes: float) -> None:
self._nbytes = nbytes - sizeof(object())
def __init__(self, nbytes: int) -> None:
if not isinstance(nbytes, int):
raise TypeError(f"Expected integer for nbytes but got {type(nbytes)}")
size_obj = sizeof(object())
if nbytes < size_obj:
raise ValueError(
f"Expected a value larger than {size_obj} integer but got {nbytes}."
)
self._nbytes = nbytes - size_obj

def __sizeof__(self) -> int:
return self._nbytes


def gen_nbytes(nbytes: float) -> SizeOf:
def gen_nbytes(nbytes: int) -> SizeOf:
"""A function that emulates exactly nbytes on the worker data structure."""
return SizeOf(nbytes)

0 comments on commit c11b38c

Please sign in to comment.