diff --git a/CHANGES/7188.feature b/CHANGES/7188.feature new file mode 100644 index 00000000000..777144aa0e2 --- /dev/null +++ b/CHANGES/7188.feature @@ -0,0 +1 @@ +Added a graceful shutdown period which allows pending tasks to complete before the application's cleanup is called. The period can be adjusted with the ``shutdown_timeout`` parameter -- by :user:`Dreamsorcerer`. diff --git a/aiohttp/test_utils.py b/aiohttp/test_utils.py index e0aacbb68fc..652a8d4713d 100644 --- a/aiohttp/test_utils.py +++ b/aiohttp/test_utils.py @@ -115,7 +115,7 @@ async def start_server(self, **kwargs: Any) -> None: if self.runner: return self._ssl = kwargs.pop("ssl", None) - self.runner = await self._make_runner(**kwargs) + self.runner = await self._make_runner(handler_cancellation=True, **kwargs) await self.runner.setup() if not self.port: self.port = 0 diff --git a/aiohttp/web_fileresponse.py b/aiohttp/web_fileresponse.py index cdb0fe58923..ce85eeb6a69 100644 --- a/aiohttp/web_fileresponse.py +++ b/aiohttp/web_fileresponse.py @@ -283,4 +283,4 @@ async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter try: return await self._sendfile(request, fobj, offset, count) finally: - await loop.run_in_executor(None, fobj.close) + await asyncio.shield(loop.run_in_executor(None, fobj.close)) diff --git a/aiohttp/web_runner.py b/aiohttp/web_runner.py index 4b1408c31a6..8dad92bab0e 100644 --- a/aiohttp/web_runner.py +++ b/aiohttp/web_runner.py @@ -2,6 +2,7 @@ import signal import socket from abc import ABC, abstractmethod +from contextlib import suppress from typing import Any, List, Optional, Set, Type from yarl import URL @@ -80,11 +81,26 @@ async def stop(self) -> None: # named pipes do not have wait_closed property if hasattr(self._server, "wait_closed"): await self._server.wait_closed() + + # Wait for pending tasks for a given time limit. + with suppress(asyncio.TimeoutError): + await asyncio.wait_for( + self._wait(asyncio.current_task()), timeout=self._shutdown_timeout + ) + await self._runner.shutdown() assert self._runner.server await self._runner.server.shutdown(self._shutdown_timeout) self._runner._unreg_site(self) + async def _wait(self, parent_task: Optional["asyncio.Task[object]"]) -> None: + exclude = self._runner.starting_tasks | {asyncio.current_task(), parent_task} + # TODO(PY38): while tasks := asyncio.all_tasks() - exclude: + tasks = asyncio.all_tasks() - exclude + while tasks: + await asyncio.wait(tasks) + tasks = asyncio.all_tasks() - exclude + class TCPSite(BaseSite): __slots__ = ("_host", "_port", "_reuse_address", "_reuse_port") @@ -247,7 +263,7 @@ async def start(self) -> None: class BaseRunner(ABC): - __slots__ = ("_handle_signals", "_kwargs", "_server", "_sites") + __slots__ = ("starting_tasks", "_handle_signals", "_kwargs", "_server", "_sites") def __init__(self, *, handle_signals: bool = False, **kwargs: Any) -> None: self._handle_signals = handle_signals @@ -287,6 +303,11 @@ async def setup(self) -> None: pass self._server = await self._make_server() + # On shutdown we want to avoid waiting on tasks which run forever. + # It's very likely that all tasks which run forever will have been created by + # the time we have completed the application startup (in self._make_server()), + # so we just record all running tasks here and exclude them later. + self.starting_tasks = asyncio.all_tasks() @abstractmethod async def shutdown(self) -> None: diff --git a/docs/web_advanced.rst b/docs/web_advanced.rst index bf71d11fce0..6055ddaf319 100644 --- a/docs/web_advanced.rst +++ b/docs/web_advanced.rst @@ -927,8 +927,14 @@ Graceful shutdown Stopping *aiohttp web server* by just closing all connections is not always satisfactory. -The problem is: if application supports :term:`websocket`\s or *data -streaming* it most likely has open connections at server +The first thing aiohttp will do is to stop listening on the sockets, +so new connections will be rejected. It will then wait a few +seconds to allow any pending tasks to complete before continuing +with application shutdown. The timeout can be adjusted with +``shutdown_timeout`` in :func:`run_app`. + +Another problem is if the application supports :term:`websockets ` or +*data streaming* it most likely has open connections at server shutdown time. The *library* has no knowledge how to close them gracefully but diff --git a/docs/web_reference.rst b/docs/web_reference.rst index f37c3da854b..05210f17199 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -2688,9 +2688,10 @@ application on specific TCP or Unix socket, e.g.:: :param int port: PORT to listed on, ``8080`` if ``None`` (default). - :param float shutdown_timeout: a timeout for closing opened - connections on :meth:`BaseSite.stop` - call. + :param float shutdown_timeout: a timeout used for both waiting on pending + tasks before application shutdown and for + closing opened connections on + :meth:`BaseSite.stop` call. :param ssl_context: a :class:`ssl.SSLContext` instance for serving SSL/TLS secure server, ``None`` for plain HTTP @@ -2723,9 +2724,10 @@ application on specific TCP or Unix socket, e.g.:: :param str path: PATH to UNIX socket to listen. - :param float shutdown_timeout: a timeout for closing opened - connections on :meth:`BaseSite.stop` - call. + :param float shutdown_timeout: a timeout used for both waiting on pending + tasks before application shutdown and for + closing opened connections on + :meth:`BaseSite.stop` call. :param ssl_context: a :class:`ssl.SSLContext` instance for serving SSL/TLS secure server, ``None`` for plain HTTP @@ -2745,9 +2747,10 @@ application on specific TCP or Unix socket, e.g.:: :param str path: PATH of named pipe to listen. - :param float shutdown_timeout: a timeout for closing opened - connections on :meth:`BaseSite.stop` - call. + :param float shutdown_timeout: a timeout used for both waiting on pending + tasks before application shutdown and for + closing opened connections on + :meth:`BaseSite.stop` call. .. class:: SockSite(runner, sock, *, \ shutdown_timeout=60.0, ssl_context=None, \ @@ -2759,9 +2762,10 @@ application on specific TCP or Unix socket, e.g.:: :param sock: A :ref:`socket instance ` to listen to. - :param float shutdown_timeout: a timeout for closing opened - connections on :meth:`BaseSite.stop` - call. + :param float shutdown_timeout: a timeout used for both waiting on pending + tasks before application shutdown and for + closing opened connections on + :meth:`BaseSite.stop` call. :param ssl_context: a :class:`ssl.SSLContext` instance for serving SSL/TLS secure server, ``None`` for plain HTTP @@ -2857,9 +2861,13 @@ Utilities shutdown before disconnecting all open client sockets hard way. + This is used as a delay to wait for + pending tasks to complete and then + again to close any pending connections. + A system with properly :ref:`aiohttp-web-graceful-shutdown` - implemented never waits for this + implemented never waits for the second timeout but closes a server in a few milliseconds. diff --git a/tests/test_run_app.py b/tests/test_run_app.py index 46b868c3815..d7ab4b9de4b 100644 --- a/tests/test_run_app.py +++ b/tests/test_run_app.py @@ -9,14 +9,15 @@ import ssl import subprocess import sys -from typing import Any +import time +from typing import Any, Callable, NoReturn from unittest import mock from uuid import uuid4 import pytest from conftest import needs_unix -from aiohttp import web +from aiohttp import ClientConnectorError, ClientSession, web from aiohttp.test_utils import make_mocked_coro from aiohttp.web_runner import BaseRunner @@ -926,3 +927,197 @@ async def init(): web.run_app(init(), print=stopper(patched_loop), loop=patched_loop) assert count == 3 + + +class TestShutdown: + def raiser(self) -> NoReturn: + raise KeyboardInterrupt + + async def stop(self, request: web.Request) -> web.Response: + asyncio.get_running_loop().call_soon(self.raiser) + return web.Response() + + def run_app(self, port: int, timeout: int, task, extra_test=None) -> asyncio.Task: + async def test() -> None: + await asyncio.sleep(1) + async with ClientSession() as sess: + async with sess.get(f"http://localhost:{port}/"): + pass + async with sess.get(f"http://localhost:{port}/stop"): + pass + + if extra_test: + await extra_test(sess) + + async def run_test(app: web.Application) -> None: + nonlocal test_task + test_task = asyncio.create_task(test()) + yield + await test_task + + async def handler(request: web.Request) -> web.Response: + nonlocal t + t = asyncio.create_task(task()) + return web.Response(text="FOO") + + t = test_task = None + app = web.Application() + app.cleanup_ctx.append(run_test) + app.router.add_get("/", handler) + app.router.add_get("/stop", self.stop) + + web.run_app(app, port=port, shutdown_timeout=timeout) + assert test_task.exception() is None + return t + + def test_shutdown_wait_for_task( + self, aiohttp_unused_port: Callable[[], int] + ) -> None: + port = aiohttp_unused_port() + finished = False + + async def task(): + nonlocal finished + await asyncio.sleep(2) + finished = True + + t = self.run_app(port, 3, task) + + assert finished is True + assert t.done() + assert not t.cancelled() + + def test_shutdown_timeout_task( + self, aiohttp_unused_port: Callable[[], int] + ) -> None: + port = aiohttp_unused_port() + finished = False + + async def task(): + nonlocal finished + await asyncio.sleep(2) + finished = True + + t = self.run_app(port, 1, task) + + assert finished is False + assert t.done() + assert t.cancelled() + + def test_shutdown_wait_for_spawned_task( + self, aiohttp_unused_port: Callable[[], int] + ) -> None: + port = aiohttp_unused_port() + finished = False + finished_sub = False + sub_t = None + + async def sub_task(): + nonlocal finished_sub + await asyncio.sleep(1.5) + finished_sub = True + + async def task(): + nonlocal finished, sub_t + await asyncio.sleep(0.5) + sub_t = asyncio.create_task(sub_task()) + finished = True + + t = self.run_app(port, 3, task) + + assert finished is True + assert t.done() + assert not t.cancelled() + assert finished_sub is True + assert sub_t.done() + assert not sub_t.cancelled() + + def test_shutdown_timeout_not_reached( + self, aiohttp_unused_port: Callable[[], int] + ) -> None: + port = aiohttp_unused_port() + finished = False + + async def task(): + nonlocal finished + await asyncio.sleep(1) + finished = True + + start_time = time.time() + t = self.run_app(port, 15, task) + + assert finished is True + assert t.done() + # Verify run_app has not waited for timeout. + assert time.time() - start_time < 10 + + def test_shutdown_new_conn_rejected( + self, aiohttp_unused_port: Callable[[], int] + ) -> None: + port = aiohttp_unused_port() + finished = False + + async def task() -> None: + nonlocal finished + await asyncio.sleep(9) + finished = True + + async def test(sess: ClientSession) -> None: + # Ensure we are in the middle of shutdown (waiting for task()). + await asyncio.sleep(1) + with pytest.raises(ClientConnectorError): + # Use a new session to try and open a new connection. + async with ClientSession() as sess: + async with sess.get(f"http://localhost:{port}/"): + pass + assert finished is False + + t = self.run_app(port, 10, task, test) + + assert finished is True + assert t.done() + + def test_shutdown_pending_handler_responds( + self, aiohttp_unused_port: Callable[[], int] + ) -> None: + port = aiohttp_unused_port() + finished = False + + async def test() -> None: + async def test_resp(sess): + async with sess.get(f"http://localhost:{port}/") as resp: + assert await resp.text() == "FOO" + + await asyncio.sleep(1) + async with ClientSession() as sess: + t = asyncio.create_task(test_resp(sess)) + await asyncio.sleep(1) + # Handler is in-progress while we trigger server shutdown. + async with sess.get(f"http://localhost:{port}/stop"): + pass + + assert finished is False + # Handler should still complete and produce a response. + await t + + async def run_test(app: web.Application) -> None: + nonlocal t + t = asyncio.create_task(test()) + yield + await t + + async def handler(request: web.Request) -> web.Response: + nonlocal finished + await asyncio.sleep(3) + finished = True + return web.Response(text="FOO") + + t = None + app = web.Application() + app.cleanup_ctx.append(run_test) + app.router.add_get("/", handler) + app.router.add_get("/stop", self.stop) + + web.run_app(app, port=port, shutdown_timeout=5) + assert t.exception() is None + assert finished is True