From 5177268384bf94bc89b7f49c80faec6963aa0059 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 10 Aug 2024 09:44:10 -0500 Subject: [PATCH] Fix TCPConnector doing blocking I/O in the event loop to create the SSLContext (#8672) Co-authored-by: Sam Bull Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> (cherry picked from commit c3219bf88c2a9381c50cd18a0fc1ad701e39bb9a) --- CHANGES/8672.bugfix.rst | 3 ++ aiohttp/connector.py | 104 ++++++++++++++++++++++++---------------- tests/test_connector.py | 78 +++++++++++++++++++++++------- tests/test_proxy.py | 2 +- 4 files changed, 128 insertions(+), 59 deletions(-) create mode 100644 CHANGES/8672.bugfix.rst diff --git a/CHANGES/8672.bugfix.rst b/CHANGES/8672.bugfix.rst new file mode 100644 index 00000000000..a57ed16d5d2 --- /dev/null +++ b/CHANGES/8672.bugfix.rst @@ -0,0 +1,3 @@ +Fixed :py:class:`aiohttp.TCPConnector` doing blocking I/O in the event loop to create the ``SSLContext`` -- by :user:`bdraco`. + +The blocking I/O would only happen once per verify mode. However, it could cause the event loop to block for a long time if the ``SSLContext`` creation is slow, which is more likely during startup when the disk cache is not yet present. diff --git a/aiohttp/connector.py b/aiohttp/connector.py index d4691b10e6e..04115c36a24 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -50,7 +50,14 @@ ) from .client_proto import ResponseHandler from .client_reqrep import ClientRequest, Fingerprint, _merge_ssl_params -from .helpers import ceil_timeout, is_ip_address, noop, sentinel +from .helpers import ( + ceil_timeout, + is_ip_address, + noop, + sentinel, + set_exception, + set_result, +) from .locks import EventResultOrError from .resolver import DefaultResolver @@ -771,6 +778,7 @@ class TCPConnector(BaseConnector): """ allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"}) + _made_ssl_context: Dict[bool, "asyncio.Future[SSLContext]"] = {} def __init__( self, @@ -969,29 +977,24 @@ async def _create_connection( return proto @staticmethod - @functools.lru_cache(None) def _make_ssl_context(verified: bool) -> SSLContext: + """Create SSL context. + + This method is not async-friendly and should be called from a thread + because it will load certificates from disk and do other blocking I/O. + """ if verified: return ssl.create_default_context() - else: - sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - sslcontext.options |= ssl.OP_NO_SSLv2 - sslcontext.options |= ssl.OP_NO_SSLv3 - sslcontext.check_hostname = False - sslcontext.verify_mode = ssl.CERT_NONE - try: - sslcontext.options |= ssl.OP_NO_COMPRESSION - except AttributeError as attr_err: - warnings.warn( - "{!s}: The Python interpreter is compiled " - "against OpenSSL < 1.0.0. Ref: " - "https://docs.python.org/3/library/ssl.html" - "#ssl.OP_NO_COMPRESSION".format(attr_err), - ) - sslcontext.set_default_verify_paths() - return sslcontext - - def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]: + sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + sslcontext.options |= ssl.OP_NO_SSLv2 + sslcontext.options |= ssl.OP_NO_SSLv3 + sslcontext.check_hostname = False + sslcontext.verify_mode = ssl.CERT_NONE + sslcontext.options |= ssl.OP_NO_COMPRESSION + sslcontext.set_default_verify_paths() + return sslcontext + + async def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]: """Logic to get the correct SSL context 0. if req.ssl is false, return None @@ -1005,25 +1008,46 @@ def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]: 3. if verify_ssl is False in req, generate a SSL context that won't verify """ - if req.is_ssl(): - if ssl is None: # pragma: no cover - raise RuntimeError("SSL is not supported.") - sslcontext = req.ssl - if isinstance(sslcontext, ssl.SSLContext): - return sslcontext - if sslcontext is not True: - # not verified or fingerprinted - return self._make_ssl_context(False) - sslcontext = self._ssl - if isinstance(sslcontext, ssl.SSLContext): - return sslcontext - if sslcontext is not True: - # not verified or fingerprinted - return self._make_ssl_context(False) - return self._make_ssl_context(True) - else: + if not req.is_ssl(): return None + if ssl is None: # pragma: no cover + raise RuntimeError("SSL is not supported.") + sslcontext = req.ssl + if isinstance(sslcontext, ssl.SSLContext): + return sslcontext + if sslcontext is not True: + # not verified or fingerprinted + return await self._make_or_get_ssl_context(False) + sslcontext = self._ssl + if isinstance(sslcontext, ssl.SSLContext): + return sslcontext + if sslcontext is not True: + # not verified or fingerprinted + return await self._make_or_get_ssl_context(False) + return await self._make_or_get_ssl_context(True) + + async def _make_or_get_ssl_context(self, verified: bool) -> SSLContext: + """Create or get cached SSL context.""" + try: + return await self._made_ssl_context[verified] + except KeyError: + loop = self._loop + future = loop.create_future() + self._made_ssl_context[verified] = future + try: + result = await loop.run_in_executor( + None, self._make_ssl_context, verified + ) + # BaseException is used since we might get CancelledError + except BaseException as ex: + del self._made_ssl_context[verified] + set_exception(future, ex) + raise + else: + set_result(future, result) + return result + def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]: ret = req.ssl if isinstance(ret, Fingerprint): @@ -1180,7 +1204,7 @@ async def _start_tls_connection( # `req.is_ssl()` evaluates to `False` which is never gonna happen # in this code path. Of course, it's rather fragile # maintainability-wise but this is to be solved separately. - sslcontext = cast(ssl.SSLContext, self._get_ssl_context(req)) + sslcontext = cast(ssl.SSLContext, await self._get_ssl_context(req)) try: async with ceil_timeout( @@ -1258,7 +1282,7 @@ async def _create_direct_connection( *, client_error: Type[Exception] = ClientConnectorError, ) -> Tuple[asyncio.Transport, ResponseHandler]: - sslcontext = self._get_ssl_context(req) + sslcontext = await self._get_ssl_context(req) fingerprint = self._get_fingerprint(req) host = req.url.raw_host diff --git a/tests/test_connector.py b/tests/test_connector.py index d146fb4ee51..0d6ca18ef53 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -1540,23 +1540,23 @@ async def test_tcp_connector_clear_dns_cache_bad_args(loop) -> None: conn.clear_dns_cache("localhost") -async def test_dont_recreate_ssl_context(loop) -> None: - conn = aiohttp.TCPConnector(loop=loop) - ctx = conn._make_ssl_context(True) - assert ctx is conn._make_ssl_context(True) +async def test_dont_recreate_ssl_context() -> None: + conn = aiohttp.TCPConnector() + ctx = await conn._make_or_get_ssl_context(True) + assert ctx is await conn._make_or_get_ssl_context(True) -async def test_dont_recreate_ssl_context2(loop) -> None: - conn = aiohttp.TCPConnector(loop=loop) - ctx = conn._make_ssl_context(False) - assert ctx is conn._make_ssl_context(False) +async def test_dont_recreate_ssl_context2() -> None: + conn = aiohttp.TCPConnector() + ctx = await conn._make_or_get_ssl_context(False) + assert ctx is await conn._make_or_get_ssl_context(False) -async def test___get_ssl_context1(loop) -> None: - conn = aiohttp.TCPConnector(loop=loop) +async def test___get_ssl_context1() -> None: + conn = aiohttp.TCPConnector() req = mock.Mock() req.is_ssl.return_value = False - assert conn._get_ssl_context(req) is None + assert await conn._get_ssl_context(req) is None async def test___get_ssl_context2(loop) -> None: @@ -1565,7 +1565,7 @@ async def test___get_ssl_context2(loop) -> None: req = mock.Mock() req.is_ssl.return_value = True req.ssl = ctx - assert conn._get_ssl_context(req) is ctx + assert await conn._get_ssl_context(req) is ctx async def test___get_ssl_context3(loop) -> None: @@ -1574,7 +1574,7 @@ async def test___get_ssl_context3(loop) -> None: req = mock.Mock() req.is_ssl.return_value = True req.ssl = True - assert conn._get_ssl_context(req) is ctx + assert await conn._get_ssl_context(req) is ctx async def test___get_ssl_context4(loop) -> None: @@ -1583,7 +1583,9 @@ async def test___get_ssl_context4(loop) -> None: req = mock.Mock() req.is_ssl.return_value = True req.ssl = False - assert conn._get_ssl_context(req) is conn._make_ssl_context(False) + assert await conn._get_ssl_context(req) is await conn._make_or_get_ssl_context( + False + ) async def test___get_ssl_context5(loop) -> None: @@ -1592,15 +1594,55 @@ async def test___get_ssl_context5(loop) -> None: req = mock.Mock() req.is_ssl.return_value = True req.ssl = aiohttp.Fingerprint(hashlib.sha256(b"1").digest()) - assert conn._get_ssl_context(req) is conn._make_ssl_context(False) + assert await conn._get_ssl_context(req) is await conn._make_or_get_ssl_context( + False + ) -async def test___get_ssl_context6(loop) -> None: - conn = aiohttp.TCPConnector(loop=loop) +async def test___get_ssl_context6() -> None: + conn = aiohttp.TCPConnector() + req = mock.Mock() + req.is_ssl.return_value = True + req.ssl = True + assert await conn._get_ssl_context(req) is await conn._make_or_get_ssl_context(True) + + +async def test_ssl_context_once() -> None: + """Test the ssl context is created only once and shared between connectors.""" + conn1 = aiohttp.TCPConnector() + conn2 = aiohttp.TCPConnector() + conn3 = aiohttp.TCPConnector() + req = mock.Mock() req.is_ssl.return_value = True req.ssl = True - assert conn._get_ssl_context(req) is conn._make_ssl_context(True) + assert await conn1._get_ssl_context(req) is await conn1._make_or_get_ssl_context( + True + ) + assert await conn2._get_ssl_context(req) is await conn1._make_or_get_ssl_context( + True + ) + assert await conn3._get_ssl_context(req) is await conn1._make_or_get_ssl_context( + True + ) + assert conn1._made_ssl_context is conn2._made_ssl_context is conn3._made_ssl_context + assert True in conn1._made_ssl_context + + +@pytest.mark.parametrize("exception", [OSError, ssl.SSLError, asyncio.CancelledError]) +async def test_ssl_context_creation_raises(exception: BaseException) -> None: + """Test that we try again if SSLContext creation fails the first time.""" + conn = aiohttp.TCPConnector() + conn._made_ssl_context.clear() + + with mock.patch.object( + conn, "_make_ssl_context", side_effect=exception + ), pytest.raises( # type: ignore[call-overload] + exception + ): + await conn._make_or_get_ssl_context(True) + + assert isinstance(await conn._make_or_get_ssl_context(True), ssl.SSLContext) async def test_close_twice(loop) -> None: diff --git a/tests/test_proxy.py b/tests/test_proxy.py index f335e42c254..c5e98deb8a5 100644 --- a/tests/test_proxy.py +++ b/tests/test_proxy.py @@ -817,7 +817,7 @@ async def make_conn(): self.loop.start_tls.assert_called_with( mock.ANY, mock.ANY, - connector._make_ssl_context(True), + self.loop.run_until_complete(connector._make_or_get_ssl_context(True)), server_hostname="www.python.org", ssl_handshake_timeout=mock.ANY, )