From dd175b6b89564dc74fba0692a8a5f9a9b38e528a Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Sun, 26 Nov 2023 15:22:31 +0000 Subject: [PATCH] Fix regression with connection upgrade (#7879) (#7908) Fixes #7867. (cherry picked from commit 48b15583305e692ce997ec6f5a6a2f88f23ace71) --- CHANGES/7879.bugfix | 1 + aiohttp/client_reqrep.py | 19 ++++++++----------- aiohttp/connector.py | 4 ++++ tests/test_client_functional.py | 19 +++++++++++++++++++ 4 files changed, 32 insertions(+), 11 deletions(-) create mode 100644 CHANGES/7879.bugfix diff --git a/CHANGES/7879.bugfix b/CHANGES/7879.bugfix new file mode 100644 index 00000000000..08baf85be42 --- /dev/null +++ b/CHANGES/7879.bugfix @@ -0,0 +1 @@ +Fixed a regression where connection may get closed during upgrade. -- by :user:`Dreamsorcerer` diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index 0ab84743658..1d946aea320 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -1006,19 +1006,14 @@ def _response_eof(self) -> None: if self._closed: return - if self._connection is not None: - # websocket, protocol could be None because - # connection could be detached - if ( - self._connection.protocol is not None - and self._connection.protocol.upgraded - ): - return - - self._release_connection() + # protocol could be None because connection could be detached + protocol = self._connection and self._connection.protocol + if protocol is not None and protocol.upgraded: + return self._closed = True self._cleanup_writer() + self._release_connection() @property def closed(self) -> bool: @@ -1113,7 +1108,9 @@ async def read(self) -> bytes: elif self._released: # Response explicitly released raise ClientConnectionError("Connection closed") - await self._wait_released() # Underlying connection released + protocol = self._connection and self._connection.protocol + if protocol is None or not protocol.upgraded: + await self._wait_released() # Underlying connection released return self._body # type: ignore[no-any-return] def get_encoding(self) -> str: diff --git a/aiohttp/connector.py b/aiohttp/connector.py index d85679f8bca..61c26430860 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -127,6 +127,10 @@ def __del__(self, _warnings: Any = warnings) -> None: context["source_traceback"] = self._source_traceback self._loop.call_exception_handler(context) + def __bool__(self) -> Literal[True]: + """Force subclasses to not be falsy, to make checks simpler.""" + return True + @property def loop(self) -> asyncio.AbstractEventLoop: warnings.warn( diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 6698ac6ef52..8a9a4e184be 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -173,6 +173,25 @@ async def handler(request): assert 1 == len(client._session.connector._conns) +async def test_upgrade_connection_not_released_after_read(aiohttp_client) -> None: + async def handler(request: web.Request) -> web.Response: + body = await request.read() + assert b"" == body + return web.Response( + status=101, headers={"Connection": "Upgrade", "Upgrade": "tcp"} + ) + + app = web.Application() + app.router.add_route("GET", "/", handler) + + client = await aiohttp_client(app) + + resp = await client.get("/") + await resp.read() + assert resp.connection is not None + assert not resp.closed + + async def test_keepalive_server_force_close_connection(aiohttp_client) -> None: async def handler(request): body = await request.read()