diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py
index d89483472da..cd38f364c8f 100644
--- a/distributed/dashboard/components/scheduler.py
+++ b/distributed/dashboard/components/scheduler.py
@@ -1783,7 +1783,7 @@ def __init__(self, scheduler, **kwargs):
         self.last = 0
         self.source = ColumnDataSource(
             {
-                "time": [time() - 20, time()],
+                "time": [time() - 60, time()],
                 "level": [0, 15],
                 "color": ["white", "white"],
                 "duration": [0, 0],
@@ -1828,7 +1828,7 @@ def convert(self, msgs):
         """Convert a log message to a glyph"""
         total_duration = 0
         for msg in msgs:
-            time, level, key, duration, sat, occ_sat, idl, occ_idl = msg
+            time, level, key, duration, sat, occ_sat, idl, occ_idl = msg[:8]
             total_duration += duration
 
             try:
diff --git a/distributed/stealing.py b/distributed/stealing.py
index b9858f4656a..cedd08cba61 100644
--- a/distributed/stealing.py
+++ b/distributed/stealing.py
@@ -223,18 +223,16 @@ def steal_time_ratio(self, ts):
 
         ws = ts.processing_on
         compute_time = ws.processing[ts]
-        if compute_time < 0.005:  # 5ms, just give up
-            return None, None
 
         nbytes = ts.get_nbytes_deps()
         transfer_time = nbytes / self.scheduler.bandwidth + LATENCY
         cost_multiplier = transfer_time / compute_time
-        if cost_multiplier > 100:
-            return None, None
 
         level = int(round(log2(cost_multiplier) + 6))
         if level < 1:
             level = 1
+        elif level > len(self.cost_multipliers):
+            return None, None
 
         return cost_multiplier, level
 
diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py
index 3128f902b8c..f3b4a1a3d05 100644
--- a/distributed/tests/test_steal.py
+++ b/distributed/tests/test_steal.py
@@ -1341,3 +1341,36 @@ def test_steal_worker_state(ws_with_running_task):
     assert "x" not in ws.tasks
     assert "x" not in ws.data
     assert ws.available_resources == {"R": 1}
+
+
+@pytest.mark.slow()
+@gen_cluster(nthreads=[("", 1)] * 4, client=True)
+async def test_steal_very_fast_tasks(c, s, *workers):
+    # Ensure that very fast tasks (2ms each) are still stolen, so that work
+    # ends up reasonably balanced across all workers.
+    root = dask.delayed(lambda n: "x" * n)(
+        dask.utils.parse_bytes("1MiB"), dask_key_name="root"
+    )
+
+    @dask.delayed
+    def func(*args):
+        import time
+
+        time.sleep(0.002)
+
+    ntasks = 1000
+    results = [func(root, i) for i in range(ntasks)]
+    futs = c.compute(results)
+    await c.gather(futs)
+
+    # Keep the per-worker counts under their own name; rebinding ``ntasks``
+    # here would make the fairness check compare a count with itself.
+    counts = [len(w.data) for w in workers]
+    ideal = ntasks / len(workers)
+    assert all(n > ideal * 0.5 for n in counts), (ideal, counts)
+
+    # The fullest worker must hold less than 2/3 of the combined total of
+    # all the other workers.
+    max_ = max(counts)
+    rest = sum(counts) - max_
+    assert max_ < rest * 2 / 3, counts