diff --git a/CHANGES/8672.bugfix.rst b/CHANGES/8672.bugfix.rst new file mode 100644 index 00000000000..a57ed16d5d2 --- /dev/null +++ b/CHANGES/8672.bugfix.rst @@ -0,0 +1,3 @@ +Fixed :py:class:`aiohttp.TCPConnector` doing blocking I/O in the event loop to create the ``SSLContext`` -- by :user:`bdraco`. + +The blocking I/O would only happen once per verify mode. However, it could cause the event loop to block for a long time if the ``SSLContext`` creation is slow, which is more likely during startup when the disk cache is not yet present. diff --git a/aiohttp/connector.py b/aiohttp/connector.py index c86855a361f..e4a0b98b2df 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -51,7 +51,14 @@ ) from .client_proto import ResponseHandler from .client_reqrep import SSL_ALLOWED_TYPES, ClientRequest, Fingerprint -from .helpers import _SENTINEL, ceil_timeout, is_ip_address, sentinel, set_result +from .helpers import ( + _SENTINEL, + ceil_timeout, + is_ip_address, + sentinel, + set_exception, + set_result, +) from .locks import EventResultOrError from .resolver import DefaultResolver @@ -752,6 +759,7 @@ class TCPConnector(BaseConnector): """ allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"}) + _made_ssl_context: Dict[bool, "asyncio.Future[SSLContext]"] = {} def __init__( self, @@ -946,29 +954,24 @@ async def _create_connection( return proto @staticmethod - @functools.lru_cache(None) def _make_ssl_context(verified: bool) -> SSLContext: + """Create SSL context. + + This method is not async-friendly and should be called from a thread + because it will load certificates from disk and do other blocking I/O. + """ if verified: return ssl.create_default_context() - else: - sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - sslcontext.options |= ssl.OP_NO_SSLv2 - sslcontext.options |= ssl.OP_NO_SSLv3 - sslcontext.check_hostname = False - sslcontext.verify_mode = ssl.CERT_NONE - try: - sslcontext.options |= ssl.OP_NO_COMPRESSION - except AttributeError as attr_err: - warnings.warn( - "{!s}: The Python interpreter is compiled " - "against OpenSSL < 1.0.0. Ref: " - "https://docs.python.org/3/library/ssl.html" - "#ssl.OP_NO_COMPRESSION".format(attr_err), - ) - sslcontext.set_default_verify_paths() - return sslcontext - - def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]: + sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + sslcontext.options |= ssl.OP_NO_SSLv2 + sslcontext.options |= ssl.OP_NO_SSLv3 + sslcontext.check_hostname = False + sslcontext.verify_mode = ssl.CERT_NONE + sslcontext.options |= ssl.OP_NO_COMPRESSION + sslcontext.set_default_verify_paths() + return sslcontext + + async def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]: """Logic to get the correct SSL context 0. if req.ssl is false, return None @@ -982,25 +985,46 @@ def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]: 3. if verify_ssl is False in req, generate a SSL context that won't verify """ - if req.is_ssl(): - if ssl is None: # pragma: no cover - raise RuntimeError("SSL is not supported.") - sslcontext = req.ssl - if isinstance(sslcontext, ssl.SSLContext): - return sslcontext - if sslcontext is not True: - # not verified or fingerprinted - return self._make_ssl_context(False) - sslcontext = self._ssl - if isinstance(sslcontext, ssl.SSLContext): - return sslcontext - if sslcontext is not True: - # not verified or fingerprinted - return self._make_ssl_context(False) - return self._make_ssl_context(True) - else: + if not req.is_ssl(): return None + if ssl is None: # pragma: no cover + raise RuntimeError("SSL is not supported.") + sslcontext = req.ssl + if isinstance(sslcontext, ssl.SSLContext): + return sslcontext + if sslcontext is not True: + # not verified or fingerprinted + return await self._make_or_get_ssl_context(False) + sslcontext = self._ssl + if isinstance(sslcontext, ssl.SSLContext): + return sslcontext + if sslcontext is not True: + # not verified or fingerprinted + return await self._make_or_get_ssl_context(False) + return await self._make_or_get_ssl_context(True) + + async def _make_or_get_ssl_context(self, verified: bool) -> SSLContext: + """Create or get cached SSL context.""" + try: + return await self._made_ssl_context[verified] + except KeyError: + loop = self._loop + future = loop.create_future() + self._made_ssl_context[verified] = future + try: + result = await loop.run_in_executor( + None, self._make_ssl_context, verified + ) + # BaseException is used since we might get CancelledError + except BaseException as ex: + del self._made_ssl_context[verified] + set_exception(future, ex) + raise + else: + set_result(future, result) + return result + def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]: ret = req.ssl if isinstance(ret, Fingerprint): @@ -1092,7 +1116,7 @@ async def _start_tls_connection( # `req.is_ssl()` evaluates to `False` which is never gonna happen # in this code path. Of course, it's rather fragile # maintainability-wise but this is to be solved separately. - sslcontext = cast(ssl.SSLContext, self._get_ssl_context(req)) + sslcontext = cast(ssl.SSLContext, await self._get_ssl_context(req)) try: async with ceil_timeout( @@ -1170,7 +1194,7 @@ async def _create_direct_connection( *, client_error: Type[Exception] = ClientConnectorError, ) -> Tuple[asyncio.Transport, ResponseHandler]: - sslcontext = self._get_ssl_context(req) + sslcontext = await self._get_ssl_context(req) fingerprint = self._get_fingerprint(req) host = req.url.raw_host diff --git a/tests/test_connector.py b/tests/test_connector.py index 335e2a1ebc0..0e5b7dde2c6 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -1709,67 +1709,109 @@ async def test_tcp_connector_clear_dns_cache_bad_args( conn.clear_dns_cache("localhost") -async def test_dont_recreate_ssl_context(loop: asyncio.AbstractEventLoop) -> None: +async def test_dont_recreate_ssl_context() -> None: conn = aiohttp.TCPConnector() - ctx = conn._make_ssl_context(True) - assert ctx is conn._make_ssl_context(True) + ctx = await conn._make_or_get_ssl_context(True) + assert ctx is await conn._make_or_get_ssl_context(True) -async def test_dont_recreate_ssl_context2(loop: asyncio.AbstractEventLoop) -> None: +async def test_dont_recreate_ssl_context2() -> None: conn = aiohttp.TCPConnector() - ctx = conn._make_ssl_context(False) - assert ctx is conn._make_ssl_context(False) + ctx = await conn._make_or_get_ssl_context(False) + assert ctx is await conn._make_or_get_ssl_context(False) -async def test___get_ssl_context1(loop: asyncio.AbstractEventLoop) -> None: +async def test___get_ssl_context1() -> None: conn = aiohttp.TCPConnector() req = mock.Mock() req.is_ssl.return_value = False - assert conn._get_ssl_context(req) is None + assert await conn._get_ssl_context(req) is None -async def test___get_ssl_context2(loop: asyncio.AbstractEventLoop) -> None: +async def test___get_ssl_context2() -> None: ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) conn = aiohttp.TCPConnector() req = mock.Mock() req.is_ssl.return_value = True req.ssl = ctx - assert conn._get_ssl_context(req) is ctx + assert await conn._get_ssl_context(req) is ctx -async def test___get_ssl_context3(loop: asyncio.AbstractEventLoop) -> None: +async def test___get_ssl_context3() -> None: ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) conn = aiohttp.TCPConnector(ssl=ctx) req = mock.Mock() req.is_ssl.return_value = True req.ssl = True - assert conn._get_ssl_context(req) is ctx + assert await conn._get_ssl_context(req) is ctx -async def test___get_ssl_context4(loop: asyncio.AbstractEventLoop) -> None: +async def test___get_ssl_context4() -> None: ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) conn = aiohttp.TCPConnector(ssl=ctx) req = mock.Mock() req.is_ssl.return_value = True req.ssl = False - assert conn._get_ssl_context(req) is conn._make_ssl_context(False) + assert await conn._get_ssl_context(req) is await conn._make_or_get_ssl_context( + False + ) -async def test___get_ssl_context5(loop: asyncio.AbstractEventLoop) -> None: +async def test___get_ssl_context5() -> None: ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) conn = aiohttp.TCPConnector(ssl=ctx) req = mock.Mock() req.is_ssl.return_value = True req.ssl = aiohttp.Fingerprint(hashlib.sha256(b"1").digest()) - assert conn._get_ssl_context(req) is conn._make_ssl_context(False) + assert await conn._get_ssl_context(req) is await conn._make_or_get_ssl_context( + False + ) -async def test___get_ssl_context6(loop: asyncio.AbstractEventLoop) -> None: +async def test___get_ssl_context6() -> None: conn = aiohttp.TCPConnector() req = mock.Mock() req.is_ssl.return_value = True req.ssl = True - assert conn._get_ssl_context(req) is conn._make_ssl_context(True) + assert await conn._get_ssl_context(req) is await conn._make_or_get_ssl_context(True) + + +async def test_ssl_context_once() -> None: + """Test the ssl context is created only once and shared between connectors.""" + conn1 = aiohttp.TCPConnector() + conn2 = aiohttp.TCPConnector() + conn3 = aiohttp.TCPConnector() + + req = mock.Mock() + req.is_ssl.return_value = True + req.ssl = True + assert await conn1._get_ssl_context(req) is await conn1._make_or_get_ssl_context( + True + ) + assert await conn2._get_ssl_context(req) is await conn1._make_or_get_ssl_context( + True + ) + assert await conn3._get_ssl_context(req) is await conn1._make_or_get_ssl_context( + True + ) + assert conn1._made_ssl_context is conn2._made_ssl_context is conn3._made_ssl_context + assert True in conn1._made_ssl_context + + +@pytest.mark.parametrize("exception", [OSError, ssl.SSLError, asyncio.CancelledError]) +async def test_ssl_context_creation_raises(exception: BaseException) -> None: + """Test that we try again if SSLContext creation fails the first time.""" + conn = aiohttp.TCPConnector() + conn._made_ssl_context.clear() + + with mock.patch.object( + conn, "_make_ssl_context", side_effect=exception + ), pytest.raises( # type: ignore[call-overload] + exception + ): + await conn._make_or_get_ssl_context(True) + + assert isinstance(await conn._make_or_get_ssl_context(True), ssl.SSLContext) async def test_close_twice(loop: asyncio.AbstractEventLoop, key: ConnectionKey) -> None: diff --git a/tests/test_proxy.py b/tests/test_proxy.py index 6488b17f17b..b0bfeeb362a 100644 --- a/tests/test_proxy.py +++ b/tests/test_proxy.py @@ -813,7 +813,7 @@ async def make_conn(): self.loop.start_tls.assert_called_with( mock.ANY, mock.ANY, - connector._make_ssl_context(True), + self.loop.run_until_complete(connector._make_or_get_ssl_context(True)), server_hostname="www.python.org", ssl_handshake_timeout=mock.ANY, )