Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix TCPConnector doing blocking I/O in the event loop to create the SSLContext #8672

Merged
merged 29 commits into from
Aug 10, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGES/8672.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Fixed :py:class:`aiohttp.TCPConnector` doing blocking I/O in the event loop to create the ``SSLContext`` -- by :user:`bdraco`.

The blocking I/O would only happen once per verify mode. However, it could cause the event loop to block for a long time if the ``SSLContext`` creation is slow, which is more likely during startup when the disk cache is not yet present.
104 changes: 64 additions & 40 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,14 @@
)
from .client_proto import ResponseHandler
from .client_reqrep import SSL_ALLOWED_TYPES, ClientRequest, Fingerprint
from .helpers import _SENTINEL, ceil_timeout, is_ip_address, sentinel, set_result
from .helpers import (
_SENTINEL,
ceil_timeout,
is_ip_address,
sentinel,
set_exception,
set_result,
)
from .locks import EventResultOrError
from .resolver import DefaultResolver

Expand Down Expand Up @@ -752,6 +759,7 @@ class TCPConnector(BaseConnector):
"""

allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"})
_made_ssl_context: Dict[bool, asyncio.Lock] = {}
bdraco marked this conversation as resolved.
Show resolved Hide resolved

def __init__(
self,
Expand Down Expand Up @@ -946,29 +954,24 @@ async def _create_connection(
return proto

@staticmethod
@functools.lru_cache(None)
def _make_ssl_context(verified: bool) -> SSLContext:
"""Create SSL context.

This method is not async-friendly and should be called from a thread
because it will load certificates from disk and do other blocking I/O.
"""
if verified:
return ssl.create_default_context()
else:
sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
sslcontext.options |= ssl.OP_NO_SSLv2
sslcontext.options |= ssl.OP_NO_SSLv3
sslcontext.check_hostname = False
sslcontext.verify_mode = ssl.CERT_NONE
try:
sslcontext.options |= ssl.OP_NO_COMPRESSION
except AttributeError as attr_err:
warnings.warn(
"{!s}: The Python interpreter is compiled "
"against OpenSSL < 1.0.0. Ref: "
"https://docs.python.org/3/library/ssl.html"
"#ssl.OP_NO_COMPRESSION".format(attr_err),
)
sslcontext.set_default_verify_paths()
return sslcontext

def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]:
sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
sslcontext.options |= ssl.OP_NO_SSLv2
sslcontext.options |= ssl.OP_NO_SSLv3
sslcontext.check_hostname = False
sslcontext.verify_mode = ssl.CERT_NONE
sslcontext.options |= ssl.OP_NO_COMPRESSION
sslcontext.set_default_verify_paths()
return sslcontext

async def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]:
"""Logic to get the correct SSL context

0. if req.ssl is false, return None
Expand All @@ -982,25 +985,46 @@ def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]:
3. if verify_ssl is False in req, generate a SSL context that
won't verify
"""
if req.is_ssl():
if ssl is None: # pragma: no cover
raise RuntimeError("SSL is not supported.")
sslcontext = req.ssl
if isinstance(sslcontext, ssl.SSLContext):
return sslcontext
if sslcontext is not True:
# not verified or fingerprinted
return self._make_ssl_context(False)
sslcontext = self._ssl
if isinstance(sslcontext, ssl.SSLContext):
return sslcontext
if sslcontext is not True:
# not verified or fingerprinted
return self._make_ssl_context(False)
return self._make_ssl_context(True)
else:
if not req.is_ssl():
return None

if ssl is None: # pragma: no cover
raise RuntimeError("SSL is not supported.")
sslcontext = req.ssl
if isinstance(sslcontext, ssl.SSLContext):
return sslcontext
if sslcontext is not True:
# not verified or fingerprinted
return await self._make_or_get_ssl_context(False)
sslcontext = self._ssl
if isinstance(sslcontext, ssl.SSLContext):
return sslcontext
if sslcontext is not True:
# not verified or fingerprinted
return await self._make_or_get_ssl_context(False)
return await self._make_or_get_ssl_context(True)

async def _make_or_get_ssl_context(self, verified: bool) -> SSLContext:
"""Create or get cached SSL context."""
try:
return await self._made_ssl_context[verified]
except KeyError:
loop = self._loop
future = loop.create_future()
self._made_ssl_context[verified] = future
try:
result = await self._loop.run_in_executor(
None, self._make_ssl_context, verified
)
# BaseException is used since we might get CancelledError
except BaseException as ex:
del self._made_ssl_context[verified]
set_exception(future, ex)
raise
else:
set_result(future, result)
return result

def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]:
ret = req.ssl
if isinstance(ret, Fingerprint):
Expand Down Expand Up @@ -1092,7 +1116,7 @@ async def _start_tls_connection(
# `req.is_ssl()` evaluates to `False` which is never gonna happen
# in this code path. Of course, it's rather fragile
# maintainability-wise but this is to be solved separately.
sslcontext = cast(ssl.SSLContext, self._get_ssl_context(req))
sslcontext = cast(ssl.SSLContext, await self._get_ssl_context(req))

try:
async with ceil_timeout(
Expand Down Expand Up @@ -1170,7 +1194,7 @@ async def _create_direct_connection(
*,
client_error: Type[Exception] = ClientConnectorError,
) -> Tuple[asyncio.Transport, ResponseHandler]:
sslcontext = self._get_ssl_context(req)
sslcontext = await self._get_ssl_context(req)
fingerprint = self._get_fingerprint(req)

host = req.url.raw_host
Expand Down
46 changes: 36 additions & 10 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1711,21 +1711,21 @@ async def test_tcp_connector_clear_dns_cache_bad_args(

async def test_dont_recreate_ssl_context(loop: asyncio.AbstractEventLoop) -> None:
conn = aiohttp.TCPConnector()
ctx = conn._make_ssl_context(True)
assert ctx is conn._make_ssl_context(True)
ctx = await conn._make_or_get_ssl_context(True)
assert ctx is await conn._make_or_get_ssl_context(True)


async def test_dont_recreate_ssl_context2(loop: asyncio.AbstractEventLoop) -> None:
conn = aiohttp.TCPConnector()
ctx = conn._make_ssl_context(False)
assert ctx is conn._make_ssl_context(False)
ctx = await conn._make_or_get_ssl_context(False)
assert ctx is await conn._make_or_get_ssl_context(False)


async def test___get_ssl_context1(loop: asyncio.AbstractEventLoop) -> None:
conn = aiohttp.TCPConnector()
req = mock.Mock()
req.is_ssl.return_value = False
assert conn._get_ssl_context(req) is None
assert await conn._get_ssl_context(req) is None


async def test___get_ssl_context2(loop: asyncio.AbstractEventLoop) -> None:
Expand All @@ -1734,7 +1734,7 @@ async def test___get_ssl_context2(loop: asyncio.AbstractEventLoop) -> None:
req = mock.Mock()
req.is_ssl.return_value = True
req.ssl = ctx
assert conn._get_ssl_context(req) is ctx
assert await conn._get_ssl_context(req) is ctx


async def test___get_ssl_context3(loop: asyncio.AbstractEventLoop) -> None:
Expand All @@ -1743,7 +1743,7 @@ async def test___get_ssl_context3(loop: asyncio.AbstractEventLoop) -> None:
req = mock.Mock()
req.is_ssl.return_value = True
req.ssl = True
assert conn._get_ssl_context(req) is ctx
assert await conn._get_ssl_context(req) is ctx


async def test___get_ssl_context4(loop: asyncio.AbstractEventLoop) -> None:
Expand All @@ -1752,7 +1752,9 @@ async def test___get_ssl_context4(loop: asyncio.AbstractEventLoop) -> None:
req = mock.Mock()
req.is_ssl.return_value = True
req.ssl = False
assert conn._get_ssl_context(req) is conn._make_ssl_context(False)
assert await conn._get_ssl_context(req) is await conn._make_or_get_ssl_context(
False
)


async def test___get_ssl_context5(loop: asyncio.AbstractEventLoop) -> None:
Expand All @@ -1761,15 +1763,39 @@ async def test___get_ssl_context5(loop: asyncio.AbstractEventLoop) -> None:
req = mock.Mock()
req.is_ssl.return_value = True
req.ssl = aiohttp.Fingerprint(hashlib.sha256(b"1").digest())
assert conn._get_ssl_context(req) is conn._make_ssl_context(False)
assert await conn._get_ssl_context(req) is await conn._make_or_get_ssl_context(
False
)


async def test___get_ssl_context6(loop: asyncio.AbstractEventLoop) -> None:
conn = aiohttp.TCPConnector()
req = mock.Mock()
req.is_ssl.return_value = True
req.ssl = True
assert conn._get_ssl_context(req) is conn._make_ssl_context(True)
assert await conn._get_ssl_context(req) is await conn._make_or_get_ssl_context(True)


async def test_ssl_context_once(loop: asyncio.AbstractEventLoop) -> None:
"""Test the ssl context is created only once and shared between connectors."""
conn1 = aiohttp.TCPConnector()
conn2 = aiohttp.TCPConnector()
conn3 = aiohttp.TCPConnector()

req = mock.Mock()
req.is_ssl.return_value = True
req.ssl = True
assert await conn1._get_ssl_context(req) is await conn1._make_or_get_ssl_context(
True
)
assert await conn2._get_ssl_context(req) is await conn1._make_or_get_ssl_context(
True
)
assert await conn3._get_ssl_context(req) is await conn1._make_or_get_ssl_context(
True
)
assert conn1._made_ssl_context is conn2._made_ssl_context is conn3._made_ssl_context
assert True in conn1._made_ssl_context


async def test_close_twice(loop: asyncio.AbstractEventLoop, key: ConnectionKey) -> None:
Expand Down
Loading