Skip to content

Commit

Permalink
Refactor some tests (#8908)
Browse files Browse the repository at this point in the history
* Skip test_deadlock_cancelled_after_inflight_before_gather_from_worker on Windows

* Refactor some tests
  • Loading branch information
fjetter authored Oct 25, 2024
1 parent 27ed3d2 commit 1205a70
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 8 deletions.
11 changes: 9 additions & 2 deletions distributed/shuffle/tests/test_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,16 +416,23 @@ async def test_closed_worker_during_transfer(c, s, a, b):
config={"distributed.scheduler.allowed-failures": 0},
)
async def test_restarting_during_transfer_raises_killed_worker(c, s, a, b):
await c.register_plugin(BlockedShuffleReceiveShuffleWorkerPlugin(), name="shuffle")
df = dask.datasets.timeseries(
start="2000-01-01",
end="2000-03-01",
end="2000-02-01",
dtypes={"x": float, "y": float},
freq="10 s",
)
shuffle_extA = a.plugins["shuffle"]
shuffle_extB = b.plugins["shuffle"]
with dask.config.set({"dataframe.shuffle.method": "p2p"}):
out = df.shuffle("x")
out = c.compute(out.x.size)
await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, b)
await asyncio.gather(
shuffle_extA.in_shuffle_receive.wait(), shuffle_extB.in_shuffle_receive.wait()
)
shuffle_extA.block_shuffle_receive.set()
shuffle_extB.block_shuffle_receive.set()
await assert_worker_cleanup(b, close=True)

with pytest.raises(KilledWorker):
Expand Down
38 changes: 33 additions & 5 deletions distributed/tests/test_cancelled_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
SecedeEvent,
TaskFinishedMsg,
UpdateDataEvent,
WorkerState,
)


Expand Down Expand Up @@ -825,8 +826,11 @@ async def release_all_futures():


@pytest.mark.parametrize("intermediate_state", ["resumed", "cancelled"])
@pytest.mark.parametrize("close_worker", [False, True])
@gen_cluster(client=True, config={"distributed.comm.timeouts.connect": "500ms"})
@pytest.mark.parametrize("close_worker", [True])
@gen_cluster(
client=True,
config={"distributed.comm.timeouts.connect": "500ms"},
)
async def test_deadlock_cancelled_after_inflight_before_gather_from_worker(
c, s, a, x, intermediate_state, close_worker
):
Expand All @@ -839,10 +843,34 @@ async def test_deadlock_cancelled_after_inflight_before_gather_from_worker(
fut2 = c.submit(sum, [fut1, fut1B], workers=[x.address], key="f2")
await fut2

async with BlockedGatherDep(s.address, name="b") as b:
class InstrumentedWorkerState(WorkerState):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.fut2_in_flight = asyncio.Event()
self.fut2_in_intermediate = asyncio.Event()

def _transition(self, ts, finish, *args, **kwargs):
def _verify_state(finish):
if ts.key == fut2.key:
if isinstance(finish, tuple) and finish[0] == "flight":
self.fut2_in_flight.set()
if self.fut2_in_flight.is_set() and finish == intermediate_state:
self.fut2_in_intermediate.set()

# The expected state might be either the requested one or the
# actual, final state
_verify_state(finish)
try:
return super()._transition(ts, finish, *args, **kwargs)
finally:
_verify_state(ts.state)

async with BlockedGatherDep(
s.address, name="b", WorkerStateClass=InstrumentedWorkerState
) as b:
fut3 = c.submit(inc, fut2, workers=[b.address], key="f3")

await wait_for_state(fut2.key, "flight", b)
await b.state.fut2_in_flight.wait()

s.set_restrictions(worker={fut1B.key: a.address, fut2.key: b.address})

Expand All @@ -855,7 +883,7 @@ async def test_deadlock_cancelled_after_inflight_before_gather_from_worker(
stimulus_id="remove-worker",
)

await wait_for_state(fut2.key, intermediate_state, b, interval=0)
await b.state.fut2_in_intermediate.wait()

b.block_gather_dep.set()
await fut3
Expand Down
3 changes: 2 additions & 1 deletion distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,7 @@ def __init__(
###################################
# Parameters to Server
scheduler_sni: str | None = None,
WorkerStateClass: type = WorkerState,
**kwargs,
):
if reconnect is not None:
Expand Down Expand Up @@ -788,7 +789,7 @@ def __init__(
transfer_incoming_bytes_limit = int(
self.memory_manager.memory_limit * transfer_incoming_bytes_fraction
)
state = WorkerState(
state = WorkerStateClass(
nthreads=nthreads,
data=self.memory_manager.data,
threads=self.threads,
Expand Down

0 comments on commit 1205a70

Please sign in to comment.