From 8f2b3dc66d267a1aa28aceb7622cac7ec54412da Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 9 Dec 2023 13:49:55 -1000 Subject: [PATCH] Implement happy eyeballs fixes #4451 --- aiohttp/connector.py | 55 +++++++++++++++++++++++++----------- aiohttp/helpers.py | 21 ++++++++++++++ requirements/constraints.txt | 6 ++-- requirements/dev.txt | 6 ++-- requirements/test.txt | 4 +-- setup.cfg | 1 + 6 files changed, 68 insertions(+), 25 deletions(-) diff --git a/aiohttp/connector.py b/aiohttp/connector.py index fa96c592f56..bd5f38a3beb 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -3,6 +3,7 @@ import functools import logging import random +import socket import sys import traceback import warnings @@ -13,10 +14,9 @@ from itertools import cycle, islice from time import monotonic from types import TracebackType -from typing import ( # noqa +from typing import ( TYPE_CHECKING, Any, - Awaitable, Callable, DefaultDict, Dict, @@ -31,6 +31,8 @@ cast, ) +import aiohappyeyeballs + from . import hdrs, helpers from .abc import AbstractResolver from .client_exceptions import ( @@ -956,16 +958,33 @@ def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]: async def _wrap_create_connection( self, *args: Any, + addr_infos: List[aiohappyeyeballs.AddrInfoType], req: ClientRequest, timeout: "ClientTimeout", client_error: Type[Exception] = ClientConnectorError, **kwargs: Any, ) -> Tuple[asyncio.Transport, ResponseHandler]: + local_addrs_infos = None + if self._local_addr: + host, port = self._local_addr + is_ipv6 = helpers.is_ipv6_address(host) + family = socket.AF_INET6 if is_ipv6 else socket.AF_INET + if is_ipv6: + addr = (host, port, 0, 0) + else: + addr = (host, port) + local_addrs_infos = [(family, 0, 0, addr)] try: async with ceil_timeout( timeout.sock_connect, ceil_threshold=timeout.ceil_threshold ): - return await self._loop.create_connection(*args, **kwargs) + sock = await aiohappyeyeballs.start_connection( + addr_infos=addr_infos, + local_addr_infos=local_addrs_infos, + happy_eyeballs_delay=0.25, + loop=self._loop, + ) + return await self._loop.create_connection(*args, **kwargs, sock=sock) except cert_errors as exc: raise ClientConnectorCertificateError(req.connection_key, exc) from exc except ssl_errors as exc: @@ -1120,36 +1139,27 @@ def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None: raise ClientConnectorError(req.connection_key, exc) from exc last_exc: Optional[Exception] = None - - for hinfo in hosts: - host = hinfo["host"] - port = hinfo["port"] - + addr_infos = helpers.convert_hosts_to_addr_infos(hosts) + while addr_infos: # Strip trailing dots, certificates contain FQDN without dots. # See https://github.com/aio-libs/aiohttp/issues/3636 server_hostname = ( - (req.server_hostname or hinfo["hostname"]).rstrip(".") - if sslcontext - else None + (req.server_hostname or host).rstrip(".") if sslcontext else None ) try: transp, proto = await self._wrap_create_connection( self._factory, - host, - port, timeout=timeout, ssl=sslcontext, - family=hinfo["family"], - proto=hinfo["proto"], - flags=hinfo["flags"], + addr_infos=addr_infos, server_hostname=server_hostname, - local_addr=self._local_addr, req=req, client_error=client_error, ) except ClientConnectorError as exc: last_exc = exc + addr_infos.pop(0) continue if req.is_ssl() and fingerprint: @@ -1160,6 +1170,17 @@ def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None: if not self._cleanup_closed_disabled: self._cleanup_closed_transports.append(transp) last_exc = exc + sock: socket.socket = transp.get_extra_info("socket") + bad_peer = sock.getpeername() + bad_addrs_infos: List[aiohappyeyeballs.AddrInfoType] = [] + for addr_info in addr_infos: + if addr_info[-1][0] == bad_peer[0]: + bad_addrs_infos.append(addr_info) + if bad_addrs_infos: + for bad_addr_info in bad_addrs_infos: + addr_infos.remove(bad_addr_info) + else: + addr_infos.pop(0) continue return transp, proto diff --git a/aiohttp/helpers.py b/aiohttp/helpers.py index 5435e2f9e07..388d7a630b3 100644 --- a/aiohttp/helpers.py +++ b/aiohttp/helpers.py @@ -13,6 +13,7 @@ import os import platform import re +import socket import sys import time import warnings @@ -1090,3 +1091,23 @@ def should_remove_content_length(method: str, code: int) -> bool: or 100 <= code < 200 or (200 <= code < 300 and method.upper() == hdrs.METH_CONNECT) ) + + +def convert_hosts_to_addr_infos(hosts: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Converts the list of hosts to a list of addr_infos. + + The list of hosts is the result of a DNS lookup. The list of + addr_infos is the result of a call to `socket.getaddrinfo()`. + """ + addr_infos: List[Dict[str, Any]] = [] + for hinfo in hosts: + host = hinfo["host"] + is_ipv6 = is_ipv6_address(host) + family = socket.AF_INET6 if is_ipv6 else socket.AF_INET + if is_ipv6: + addr = (host, hinfo["port"], 0, 0) + else: + addr = (host, hinfo["port"]) + addr_infos.append( + (family, hinfo["type"], hinfo["proto"], hinfo["hostname"], addr) + ) diff --git a/requirements/constraints.txt b/requirements/constraints.txt index f7b2d0e967e..3883129058e 100644 --- a/requirements/constraints.txt +++ b/requirements/constraints.txt @@ -1,8 +1,8 @@ # -# This file is autogenerated by pip-compile with python 3.8 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.8 +# by the following command: # -# pip-compile --allow-unsafe --output-file=requirements/constraints.txt --resolver=backtracking --strip-extras requirements/constraints.in +# pip-compile --allow-unsafe --output-file=requirements/constraints.txt --strip-extras requirements/constraints.in # aiodns==3.1.1 ; sys_platform == "linux" or sys_platform == "darwin" # via -r requirements/runtime-deps.in diff --git a/requirements/dev.txt b/requirements/dev.txt index ea0783703a6..fe2c1ae31ab 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -1,8 +1,8 @@ # -# This file is autogenerated by pip-compile with python 3.8 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.8 +# by the following command: # -# pip-compile --allow-unsafe --output-file=requirements/dev.txt --resolver=backtracking --strip-extras requirements/dev.in +# pip-compile --allow-unsafe --output-file=requirements/dev.txt --strip-extras requirements/dev.in # aiodns==3.1.1 ; sys_platform == "linux" or sys_platform == "darwin" # via -r requirements/runtime-deps.in diff --git a/requirements/test.txt b/requirements/test.txt index 3eba094b25f..0ad5da3c6b7 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,8 +1,8 @@ # -# This file is autogenerated by pip-compile with python 3.8 +# This file is autogenerated by pip-compile with Python 3.8 # by the following command: # -# pip-compile --allow-unsafe --output-file=requirements/test.txt --resolver=backtracking --strip-extras requirements/test.in +# pip-compile --allow-unsafe --output-file=requirements/test.txt --strip-extras requirements/test.in # aiodns==3.1.1 ; sys_platform == "linux" or sys_platform == "darwin" # via -r requirements/runtime-deps.in diff --git a/setup.cfg b/setup.cfg index 13efc0b7796..c2dbad60015 100644 --- a/setup.cfg +++ b/setup.cfg @@ -47,6 +47,7 @@ zip_safe = False include_package_data = True install_requires = + aiohappyeyeballs >= 1.7.0 aiosignal >= 1.1.2 async-timeout >= 4.0, < 5.0 ; python_version < "3.11" frozenlist >= 1.1.1