Skip to content

Commit

Permalink
Implement happy eyeballs
Browse files Browse the repository at this point in the history
fixes #4451
  • Loading branch information
bdraco committed Dec 9, 2023
1 parent d807956 commit 8f2b3dc
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 25 deletions.
55 changes: 38 additions & 17 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import functools
import logging
import random
import socket
import sys
import traceback
import warnings
Expand All @@ -13,10 +14,9 @@
from itertools import cycle, islice
from time import monotonic
from types import TracebackType
from typing import ( # noqa
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
DefaultDict,
Dict,
Expand All @@ -31,6 +31,8 @@
cast,
)

import aiohappyeyeballs

from . import hdrs, helpers
from .abc import AbstractResolver
from .client_exceptions import (
Expand Down Expand Up @@ -956,16 +958,33 @@ def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]:
async def _wrap_create_connection(
self,
*args: Any,
addr_infos: List[aiohappyeyeballs.AddrInfoType],
req: ClientRequest,
timeout: "ClientTimeout",
client_error: Type[Exception] = ClientConnectorError,
**kwargs: Any,
) -> Tuple[asyncio.Transport, ResponseHandler]:
local_addrs_infos = None
if self._local_addr:
host, port = self._local_addr
is_ipv6 = helpers.is_ipv6_address(host)
family = socket.AF_INET6 if is_ipv6 else socket.AF_INET
if is_ipv6:
addr = (host, port, 0, 0)
else:
addr = (host, port)
local_addrs_infos = [(family, 0, 0, addr)]
try:
async with ceil_timeout(
timeout.sock_connect, ceil_threshold=timeout.ceil_threshold
):
return await self._loop.create_connection(*args, **kwargs)
sock = await aiohappyeyeballs.start_connection(
addr_infos=addr_infos,
local_addr_infos=local_addrs_infos,
happy_eyeballs_delay=0.25,
loop=self._loop,
)
return await self._loop.create_connection(*args, **kwargs, sock=sock)
except cert_errors as exc:
raise ClientConnectorCertificateError(req.connection_key, exc) from exc
except ssl_errors as exc:
Expand Down Expand Up @@ -1120,36 +1139,27 @@ def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None:
raise ClientConnectorError(req.connection_key, exc) from exc

last_exc: Optional[Exception] = None

for hinfo in hosts:
host = hinfo["host"]
port = hinfo["port"]

addr_infos = helpers.convert_hosts_to_addr_infos(hosts)

Check warning

Code scanning / CodeQL

Use of the return value of a procedure Warning

The result of
convert_hosts_to_addr_infos
is used even though it is always None.
while addr_infos:
# Strip trailing dots, certificates contain FQDN without dots.
# See https://github.com/aio-libs/aiohttp/issues/3636
server_hostname = (
(req.server_hostname or hinfo["hostname"]).rstrip(".")
if sslcontext
else None
(req.server_hostname or host).rstrip(".") if sslcontext else None
)

try:
transp, proto = await self._wrap_create_connection(
self._factory,
host,
port,
timeout=timeout,
ssl=sslcontext,
family=hinfo["family"],
proto=hinfo["proto"],
flags=hinfo["flags"],
addr_infos=addr_infos,
server_hostname=server_hostname,
local_addr=self._local_addr,
req=req,
client_error=client_error,
)
except ClientConnectorError as exc:
last_exc = exc
addr_infos.pop(0)
continue

if req.is_ssl() and fingerprint:
Expand All @@ -1160,6 +1170,17 @@ def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None:
if not self._cleanup_closed_disabled:
self._cleanup_closed_transports.append(transp)
last_exc = exc
sock: socket.socket = transp.get_extra_info("socket")
bad_peer = sock.getpeername()
bad_addrs_infos: List[aiohappyeyeballs.AddrInfoType] = []
for addr_info in addr_infos:
if addr_info[-1][0] == bad_peer[0]:
bad_addrs_infos.append(addr_info)
if bad_addrs_infos:
for bad_addr_info in bad_addrs_infos:
addr_infos.remove(bad_addr_info)
else:
addr_infos.pop(0)
continue

return transp, proto
Expand Down
21 changes: 21 additions & 0 deletions aiohttp/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import os
import platform
import re
import socket
import sys
import time
import warnings
Expand Down Expand Up @@ -1090,3 +1091,23 @@ def should_remove_content_length(method: str, code: int) -> bool:
or 100 <= code < 200
or (200 <= code < 300 and method.upper() == hdrs.METH_CONNECT)
)


def convert_hosts_to_addr_infos(hosts: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Converts the list of hosts to a list of addr_infos.
The list of hosts is the result of a DNS lookup. The list of
addr_infos is the result of a call to `socket.getaddrinfo()`.
"""
addr_infos: List[Dict[str, Any]] = []
for hinfo in hosts:
host = hinfo["host"]
is_ipv6 = is_ipv6_address(host)
family = socket.AF_INET6 if is_ipv6 else socket.AF_INET
if is_ipv6:
addr = (host, hinfo["port"], 0, 0)
else:
addr = (host, hinfo["port"])
addr_infos.append(
(family, hinfo["type"], hinfo["proto"], hinfo["hostname"], addr)
)
6 changes: 3 additions & 3 deletions requirements/constraints.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#
# This file is autogenerated by pip-compile with python 3.8
# To update, run:
# This file is autogenerated by pip-compile with Python 3.8
# by the following command:
#
# pip-compile --allow-unsafe --output-file=requirements/constraints.txt --resolver=backtracking --strip-extras requirements/constraints.in
# pip-compile --allow-unsafe --output-file=requirements/constraints.txt --strip-extras requirements/constraints.in
#
aiodns==3.1.1 ; sys_platform == "linux" or sys_platform == "darwin"
# via -r requirements/runtime-deps.in
Expand Down
6 changes: 3 additions & 3 deletions requirements/dev.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#
# This file is autogenerated by pip-compile with python 3.8
# To update, run:
# This file is autogenerated by pip-compile with Python 3.8
# by the following command:
#
# pip-compile --allow-unsafe --output-file=requirements/dev.txt --resolver=backtracking --strip-extras requirements/dev.in
# pip-compile --allow-unsafe --output-file=requirements/dev.txt --strip-extras requirements/dev.in
#
aiodns==3.1.1 ; sys_platform == "linux" or sys_platform == "darwin"
# via -r requirements/runtime-deps.in
Expand Down
4 changes: 2 additions & 2 deletions requirements/test.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#
# This file is autogenerated by pip-compile with python 3.8
# This file is autogenerated by pip-compile with Python 3.8
# by the following command:
#
# pip-compile --allow-unsafe --output-file=requirements/test.txt --resolver=backtracking --strip-extras requirements/test.in
# pip-compile --allow-unsafe --output-file=requirements/test.txt --strip-extras requirements/test.in
#
aiodns==3.1.1 ; sys_platform == "linux" or sys_platform == "darwin"
# via -r requirements/runtime-deps.in
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ zip_safe = False
include_package_data = True

install_requires =
aiohappyeyeballs >= 1.7.0
aiosignal >= 1.1.2
async-timeout >= 4.0, < 5.0 ; python_version < "3.11"
frozenlist >= 1.1.1
Expand Down

0 comments on commit 8f2b3dc

Please sign in to comment.