Skip to content

Commit

Permalink
Remove declassification of fast keys in steal_time_ratio
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed Sep 8, 2022
1 parent b133009 commit b613fe9
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 6 deletions.
4 changes: 2 additions & 2 deletions distributed/dashboard/components/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1783,7 +1783,7 @@ def __init__(self, scheduler, **kwargs):
self.last = 0
self.source = ColumnDataSource(
{
"time": [time() - 20, time()],
"time": [time() - 60, time()],
"level": [0, 15],
"color": ["white", "white"],
"duration": [0, 0],
Expand Down Expand Up @@ -1828,7 +1828,7 @@ def convert(self, msgs):
"""Convert a log message to a glyph"""
total_duration = 0
for msg in msgs:
time, level, key, duration, sat, occ_sat, idl, occ_idl = msg
time, level, key, duration, sat, occ_sat, idl, occ_idl = msg[:8]
total_duration += duration

try:
Expand Down
6 changes: 2 additions & 4 deletions distributed/stealing.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,18 +223,16 @@ def steal_time_ratio(self, ts):

ws = ts.processing_on
compute_time = ws.processing[ts]
if compute_time < 0.005: # 5ms, just give up
return None, None

nbytes = ts.get_nbytes_deps()
transfer_time = nbytes / self.scheduler.bandwidth + LATENCY
cost_multiplier = transfer_time / compute_time
if cost_multiplier > 100:
return None, None

level = int(round(log2(cost_multiplier) + 6))
if level < 1:
level = 1
elif level > len(self.cost_multipliers):
return None, None

return cost_multiplier, level

Expand Down
35 changes: 35 additions & 0 deletions distributed/tests/test_steal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1341,3 +1341,38 @@ def test_steal_worker_state(ws_with_running_task):
assert "x" not in ws.tasks
assert "x" not in ws.data
assert ws.available_resources == {"R": 1}


@pytest.mark.slow()
@gen_cluster(nthreads=[("", 1)] * 4, client=True)
async def test_steal_very_fast_tasks(c, s, *workers):
# Ensure that very fast tasks
root = dask.delayed(lambda n: "x" * n)(
dask.utils.parse_bytes("1MiB"), dask_key_name="root"
)

@dask.delayed
def func(*args):
import time

time.sleep(0.002)

ntasks = 1000
results = [func(root, i) for i in range(ntasks)]
futs = c.compute(results)
await c.gather(futs)

dat = {}
max_ = 0
rest = 0
for w in workers:
ntasks = len(w.data)
if ntasks > max_:
rest += max_
max_ = ntasks
else:
rest += ntasks
dat[w] = len(w.data)
assert ntasks > ntasks / len(workers) * 0.5

assert max_ < rest * 2 / 3

0 comments on commit b613fe9

Please sign in to comment.