
Commit

Ensure client session is quiet after cluster.close() or `client.shutdown()` (#7429)
jrbourbeau authored Jan 12, 2023
1 parent 6d3182e commit aca9a5e
Showing 2 changed files with 90 additions and 44 deletions.
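For context, a minimal sketch of the user-facing behavior this commit targets (the cluster size, dashboard setting, and the short sleep are illustrative assumptions, not part of the change): closing the cluster while the client is still alive should leave the client quiet, rather than logging reconnect attempts or sending heartbeats to a dead scheduler.

```python
import asyncio

from distributed import Client, LocalCluster


async def main():
    async with LocalCluster(
        n_workers=1, processes=False, asynchronous=True, dashboard_address=":0"
    ) as cluster:
        async with Client(cluster, asynchronous=True) as client:
            # Close the cluster first. With this commit, the client notices the
            # cluster is closed and closes itself instead of reconnecting.
            await cluster.close()
            await asyncio.sleep(0.2)


asyncio.run(main())
```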
102 changes: 59 additions & 43 deletions distributed/client.py
@@ -19,7 +19,7 @@
from collections.abc import Collection, Iterator
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures._base import DoneAndNotDoneFutures
from contextlib import contextmanager, suppress
from contextlib import asynccontextmanager, contextmanager, suppress
from contextvars import ContextVar
from functools import partial
from importlib.metadata import PackageNotFoundError, version
@@ -1441,7 +1441,10 @@ def wait_for_workers(
return self.sync(self._wait_for_workers, n_workers, timeout=timeout)

def _heartbeat(self):
if self.scheduler_comm:
# Don't send a heartbeat if the scheduler comm or the cluster is already closed
if (self.scheduler_comm and not self.scheduler_comm.comm.closed()) or (
self.cluster and self.cluster.status not in (Status.closed, Status.closing)
):
self.scheduler_comm.send({"op": "heartbeat-client"})

def __enter__(self):
@@ -1505,6 +1508,14 @@ async def _handle_report(self):
if is_python_shutting_down():
return
if self.status == "running":
if self.cluster and self.cluster.status in (
Status.closed,
Status.closing,
):
# Don't attempt to reconnect if the cluster is already closed.
# Instead, close down the client.
await self._close()
return
logger.info("Client report stream closed to scheduler")
logger.info("Reconnecting...")
self.status = "connecting"
@@ -1538,7 +1549,7 @@ async def _handle_report(self):
logger.exception(e)
if breakout:
break
except CancelledError:
except (CancelledError, asyncio.CancelledError):
pass

def _handle_key_in_memory(self, key=None, type=None, workers=None):
@@ -1588,6 +1599,25 @@ def _handle_error(self, exception=None):
logger.warning("Scheduler exception:")
logger.exception(exception)

@asynccontextmanager
async def _wait_for_handle_report_task(self, fast=False):
current_task = asyncio.current_task()
handle_report_task = self._handle_report_task
# Give the scheduler 'stream-closed' message 100ms to come through
# This makes the shutdown slightly smoother and quieter
should_wait = (
handle_report_task is not None and handle_report_task is not current_task
)
if should_wait:
with suppress(asyncio.CancelledError, TimeoutError):
await asyncio.wait_for(asyncio.shield(handle_report_task), 0.1)

yield

if should_wait:
with suppress(TimeoutError, asyncio.CancelledError):
await asyncio.wait_for(handle_report_task, 0 if fast else 2)

async def _close(self, fast=False):
"""
Send close signal and wait until scheduler completes
@@ -1627,45 +1657,27 @@ async def _close(self, fast=False):
):
self._send_to_scheduler({"op": "close-client"})
self._send_to_scheduler({"op": "close-stream"})
async with self._wait_for_handle_report_task(fast=fast):
if (
self.scheduler_comm
and self.scheduler_comm.comm
and not self.scheduler_comm.comm.closed()
):
await self.scheduler_comm.close()

current_task = asyncio.current_task()
handle_report_task = self._handle_report_task
# Give the scheduler 'stream-closed' message 100ms to come through
# This makes the shutdown slightly smoother and quieter
if (
handle_report_task is not None
and handle_report_task is not current_task
):
with suppress(asyncio.CancelledError, TimeoutError):
await asyncio.wait_for(asyncio.shield(handle_report_task), 0.1)

if (
self.scheduler_comm
and self.scheduler_comm.comm
and not self.scheduler_comm.comm.closed()
):
await self.scheduler_comm.close()

for key in list(self.futures):
self._release_key(key=key)
for key in list(self.futures):
self._release_key(key=key)

if self._start_arg is None:
with suppress(AttributeError):
await self.cluster.close()
if self._start_arg is None:
with suppress(AttributeError):
await self.cluster.close()

await self.rpc.close()
await self.rpc.close()

self.status = "closed"
self.status = "closed"

if _get_global_client() is self:
_set_global_client(None)

if (
handle_report_task is not None
and handle_report_task is not current_task
):
with suppress(TimeoutError, asyncio.CancelledError):
await asyncio.wait_for(handle_report_task, 0 if fast else 2)
if _get_global_client() is self:
_set_global_client(None)

with suppress(AttributeError):
await self.scheduler.close_rpc()
@@ -1732,14 +1744,18 @@ async def _():
async def _shutdown(self):
logger.info("Shutting down scheduler from Client")

self.status = "closing"
for pc in self._periodic_callbacks.values():
pc.stop()
if self.cluster:
await self.cluster.close()
else:
with suppress(CommClosedError):
self.status = "closing"
await self.scheduler.terminate()

async with self._wait_for_handle_report_task():
if self.cluster:
await self.cluster.close()
else:
with suppress(CommClosedError):
await self.scheduler.terminate()

await self._close()

def shutdown(self):
"""Shut down the connected scheduler and workers
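As an aside on the new `_wait_for_handle_report_task` helper above, here is the shield-then-await pattern rebuilt as a standalone sketch (the function name and docstring are illustrative; only the 100ms/2s timeouts mirror the diff): shielding the first wait lets the short grace period expire without cancelling the background task, while the second wait gives the task time to finish after the wrapped teardown block runs.

```python
import asyncio
from contextlib import asynccontextmanager, suppress


@asynccontextmanager
async def wait_for_background_task(task: asyncio.Task, fast: bool = False):
    """Bracket a teardown step with short waits on a background task.

    The first wait is shielded so the 100ms timeout does not cancel the task;
    the second wait lets the task finish (or gives up immediately when ``fast``).
    """
    should_wait = task is not None and task is not asyncio.current_task()
    if should_wait:
        with suppress(asyncio.CancelledError, asyncio.TimeoutError):
            await asyncio.wait_for(asyncio.shield(task), 0.1)

    yield

    if should_wait:
        with suppress(asyncio.CancelledError, asyncio.TimeoutError):
            await asyncio.wait_for(task, 0 if fast else 2)
```

Used as `async with wait_for_background_task(task): ...`, it brackets a cleanup step the same way `_close` and `_shutdown` bracket theirs in the diff.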
32 changes: 31 additions & 1 deletion distributed/tests/test_client.py
@@ -6297,6 +6297,7 @@ async def test_shutdown():

assert s.status == Status.closed
assert w.status in {Status.closed, Status.closing}
assert c.status == "closed"


@gen_test()
@@ -6308,17 +6309,46 @@ async def test_shutdown_localcluster():
await c.shutdown()

assert lc.scheduler.status == Status.closed
assert lc.status == Status.closed
assert c.status == "closed"


@gen_test()
async def test_shutdown_is_clean():
async def test_shutdown_stops_callbacks():
async with Scheduler(dashboard_address=":0") as s:
async with Worker(s.address) as w:
async with Client(s.address, asynchronous=True) as c:
await c.shutdown()
assert not any(pc.is_running() for pc in c._periodic_callbacks.values())


@gen_test()
async def test_shutdown_is_quiet_with_cluster():
async with LocalCluster(
n_workers=1, asynchronous=True, processes=False, dashboard_address=":0"
) as cluster:
with captured_logger(logging.getLogger("distributed.client")) as logger:
timeout = 0.1
async with Client(cluster, asynchronous=True, timeout=timeout) as c:
await c.shutdown()
await asyncio.sleep(timeout)
msg = logger.getvalue().strip()
assert msg == "Shutting down scheduler from Client", msg


@gen_test()
async def test_client_is_quiet_cluster_close():
async with LocalCluster(
n_workers=1, asynchronous=True, processes=False, dashboard_address=":0"
) as cluster:
with captured_logger(logging.getLogger("distributed.client")) as logger:
timeout = 0.1
async with Client(cluster, asynchronous=True, timeout=timeout) as c:
await cluster.close()
await asyncio.sleep(timeout)
assert not logger.getvalue().strip()


@gen_test()
async def test_config_inherited_by_subprocess():
with dask.config.set(foo=100):
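The two new quietness tests lean on `captured_logger` from `distributed.utils_test`; a minimal standalone sketch of the same assertion pattern (the helper name and the synchronous callable are assumptions for illustration, not part of the test suite):

```python
import logging

from distributed.utils_test import captured_logger


def assert_client_logger_quiet(run):
    # Capture everything the distributed.client logger emits while ``run``
    # executes, then fail if anything at all was logged.
    with captured_logger(logging.getLogger("distributed.client")) as log:
        run()
    assert not log.getvalue().strip(), log.getvalue()
```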
