From aca9a5e75ceb0c66faa6dc6db2f6bfd40ab2e8ee Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Thu, 12 Jan 2023 11:18:43 -0600 Subject: [PATCH] Ensure client session is quiet after `cluster.close()` or `client.shutdown()` (#7429) --- distributed/client.py | 102 ++++++++++++++++++------------- distributed/tests/test_client.py | 32 +++++++++- 2 files changed, 90 insertions(+), 44 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index f781960018..813ec3384f 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -19,7 +19,7 @@ from collections.abc import Collection, Iterator from concurrent.futures import ThreadPoolExecutor from concurrent.futures._base import DoneAndNotDoneFutures -from contextlib import contextmanager, suppress +from contextlib import asynccontextmanager, contextmanager, suppress from contextvars import ContextVar from functools import partial from importlib.metadata import PackageNotFoundError, version @@ -1441,7 +1441,10 @@ def wait_for_workers( return self.sync(self._wait_for_workers, n_workers, timeout=timeout) def _heartbeat(self): - if self.scheduler_comm: + # Don't send heartbeat if scheduler comm or cluster are already closed + if (self.scheduler_comm and not self.scheduler_comm.comm.closed()) or ( + self.cluster and self.cluster.status not in (Status.closed, Status.closing) + ): self.scheduler_comm.send({"op": "heartbeat-client"}) def __enter__(self): @@ -1505,6 +1508,14 @@ async def _handle_report(self): if is_python_shutting_down(): return if self.status == "running": + if self.cluster and self.cluster.status in ( + Status.closed, + Status.closing, + ): + # Don't attempt to reconnect if cluster are already closed. + # Instead close down the client. + await self._close() + return logger.info("Client report stream closed to scheduler") logger.info("Reconnecting...") self.status = "connecting" @@ -1538,7 +1549,7 @@ async def _handle_report(self): logger.exception(e) if breakout: break - except CancelledError: + except (CancelledError, asyncio.CancelledError): pass def _handle_key_in_memory(self, key=None, type=None, workers=None): @@ -1588,6 +1599,25 @@ def _handle_error(self, exception=None): logger.warning("Scheduler exception:") logger.exception(exception) + @asynccontextmanager + async def _wait_for_handle_report_task(self, fast=False): + current_task = asyncio.current_task() + handle_report_task = self._handle_report_task + # Give the scheduler 'stream-closed' message 100ms to come through + # This makes the shutdown slightly smoother and quieter + should_wait = ( + handle_report_task is not None and handle_report_task is not current_task + ) + if should_wait: + with suppress(asyncio.CancelledError, TimeoutError): + await asyncio.wait_for(asyncio.shield(handle_report_task), 0.1) + + yield + + if should_wait: + with suppress(TimeoutError, asyncio.CancelledError): + await asyncio.wait_for(handle_report_task, 0 if fast else 2) + async def _close(self, fast=False): """ Send close signal and wait until scheduler completes @@ -1627,45 +1657,27 @@ async def _close(self, fast=False): ): self._send_to_scheduler({"op": "close-client"}) self._send_to_scheduler({"op": "close-stream"}) + async with self._wait_for_handle_report_task(fast=fast): + if ( + self.scheduler_comm + and self.scheduler_comm.comm + and not self.scheduler_comm.comm.closed() + ): + await self.scheduler_comm.close() - current_task = asyncio.current_task() - handle_report_task = self._handle_report_task - # Give the scheduler 'stream-closed' message 100ms to come through - # This makes the shutdown slightly smoother and quieter - if ( - handle_report_task is not None - and handle_report_task is not current_task - ): - with suppress(asyncio.CancelledError, TimeoutError): - await asyncio.wait_for(asyncio.shield(handle_report_task), 0.1) - - if ( - self.scheduler_comm - and self.scheduler_comm.comm - and not self.scheduler_comm.comm.closed() - ): - await self.scheduler_comm.close() - - for key in list(self.futures): - self._release_key(key=key) + for key in list(self.futures): + self._release_key(key=key) - if self._start_arg is None: - with suppress(AttributeError): - await self.cluster.close() + if self._start_arg is None: + with suppress(AttributeError): + await self.cluster.close() - await self.rpc.close() + await self.rpc.close() - self.status = "closed" + self.status = "closed" - if _get_global_client() is self: - _set_global_client(None) - - if ( - handle_report_task is not None - and handle_report_task is not current_task - ): - with suppress(TimeoutError, asyncio.CancelledError): - await asyncio.wait_for(handle_report_task, 0 if fast else 2) + if _get_global_client() is self: + _set_global_client(None) with suppress(AttributeError): await self.scheduler.close_rpc() @@ -1732,14 +1744,18 @@ async def _(): async def _shutdown(self): logger.info("Shutting down scheduler from Client") + self.status = "closing" for pc in self._periodic_callbacks.values(): pc.stop() - if self.cluster: - await self.cluster.close() - else: - with suppress(CommClosedError): - self.status = "closing" - await self.scheduler.terminate() + + async with self._wait_for_handle_report_task(): + if self.cluster: + await self.cluster.close() + else: + with suppress(CommClosedError): + await self.scheduler.terminate() + + await self._close() def shutdown(self): """Shut down the connected scheduler and workers diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 57b070a1e2..529e884de7 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -6297,6 +6297,7 @@ async def test_shutdown(): assert s.status == Status.closed assert w.status in {Status.closed, Status.closing} + assert c.status == "closed" @gen_test() @@ -6308,10 +6309,12 @@ async def test_shutdown_localcluster(): await c.shutdown() assert lc.scheduler.status == Status.closed + assert lc.status == Status.closed + assert c.status == "closed" @gen_test() -async def test_shutdown_is_clean(): +async def test_shutdown_stops_callbacks(): async with Scheduler(dashboard_address=":0") as s: async with Worker(s.address) as w: async with Client(s.address, asynchronous=True) as c: @@ -6319,6 +6322,33 @@ async def test_shutdown_is_clean(): assert not any(pc.is_running() for pc in c._periodic_callbacks.values()) +@gen_test() +async def test_shutdown_is_quiet_with_cluster(): + async with LocalCluster( + n_workers=1, asynchronous=True, processes=False, dashboard_address=":0" + ) as cluster: + with captured_logger(logging.getLogger("distributed.client")) as logger: + timeout = 0.1 + async with Client(cluster, asynchronous=True, timeout=timeout) as c: + await c.shutdown() + await asyncio.sleep(timeout) + msg = logger.getvalue().strip() + assert msg == "Shutting down scheduler from Client", msg + + +@gen_test() +async def test_client_is_quiet_cluster_close(): + async with LocalCluster( + n_workers=1, asynchronous=True, processes=False, dashboard_address=":0" + ) as cluster: + with captured_logger(logging.getLogger("distributed.client")) as logger: + timeout = 0.1 + async with Client(cluster, asynchronous=True, timeout=timeout) as c: + await cluster.close() + await asyncio.sleep(timeout) + assert not logger.getvalue().strip() + + @gen_test() async def test_config_inherited_by_subprocess(): with dask.config.set(foo=100):