From d32d5805e39fa8b75f7a688788de5b73b6ce0077 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 27 Sep 2024 12:41:34 -0500 Subject: [PATCH] [PR #9301/c240b52 backport][3.10] Replace code that can now be handled by yarl (#9314) --- CHANGES/9301.misc.rst | 1 + aiohttp/client_reqrep.py | 32 +++++++++++++++++++++----------- tests/test_client_request.py | 15 +++++++++++---- tests/test_proxy_functional.py | 24 ++++++++++++++++-------- 4 files changed, 49 insertions(+), 23 deletions(-) create mode 100644 CHANGES/9301.misc.rst diff --git a/CHANGES/9301.misc.rst b/CHANGES/9301.misc.rst new file mode 100644 index 00000000000..a751bdfc6dc --- /dev/null +++ b/CHANGES/9301.misc.rst @@ -0,0 +1 @@ +Replaced code that can now be handled by ``yarl`` -- by :user:`bdraco`. diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index aeabf0161b8..10682f57885 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -27,7 +27,7 @@ import attr from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy -from yarl import URL +from yarl import URL, __version__ as yarl_version from . import hdrs, helpers, http, multipart, payload from .abc import AbstractStreamWriter @@ -90,6 +90,10 @@ _CONTAINS_CONTROL_CHAR_RE = re.compile(r"[^-!#$%&'*+.^_`|~0-9a-zA-Z]") json_re = re.compile(r"^application/(?:[\w.+-]+?\+)?json") +_YARL_SUPPORTS_HOST_SUBCOMPONENT = tuple(map(int, yarl_version.split(".")[:2])) >= ( + 1, + 13, +) def _gen_default_accept_encoding() -> str: @@ -429,9 +433,13 @@ def update_headers(self, headers: Optional[LooseHeaders]) -> None: self.headers: CIMultiDict[str] = CIMultiDict() # add host - netloc = cast(str, self.url.raw_host) - if helpers.is_ipv6_address(netloc): - netloc = f"[{netloc}]" + if _YARL_SUPPORTS_HOST_SUBCOMPONENT: + netloc = self.url.host_subcomponent + assert netloc is not None + else: + netloc = cast(str, self.url.raw_host) + if helpers.is_ipv6_address(netloc): + netloc = f"[{netloc}]" # See https://github.com/aio-libs/aiohttp/issues/3636. netloc = netloc.rstrip(".") if self.url.port is not None and not self.url.is_default_port(): @@ -676,17 +684,19 @@ async def send(self, conn: "Connection") -> "ClientResponse": # - not CONNECT proxy must send absolute form URI # - most common is origin form URI if self.method == hdrs.METH_CONNECT: - connect_host = self.url.raw_host - assert connect_host is not None - if helpers.is_ipv6_address(connect_host): - connect_host = f"[{connect_host}]" + if _YARL_SUPPORTS_HOST_SUBCOMPONENT: + connect_host = self.url.host_subcomponent + assert connect_host is not None + else: + connect_host = self.url.raw_host + assert connect_host is not None + if helpers.is_ipv6_address(connect_host): + connect_host = f"[{connect_host}]" path = f"{connect_host}:{self.url.port}" elif self.proxy and not self.is_ssl(): path = str(self.url) else: - path = self.url.raw_path - if self.url.raw_query_string: - path += "?" + self.url.raw_query_string + path = self.url.raw_path_qs protocol = conn.protocol assert protocol is not None diff --git a/tests/test_client_request.py b/tests/test_client_request.py index f2eff019504..7853b541fc9 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -14,7 +14,7 @@ from yarl import URL import aiohttp -from aiohttp import BaseConnector, hdrs, helpers, payload +from aiohttp import BaseConnector, client_reqrep, hdrs, helpers, payload from aiohttp.client_exceptions import ClientConnectionError from aiohttp.client_reqrep import ( ClientRequest, @@ -280,9 +280,16 @@ def test_host_header_ipv4(make_request) -> None: assert req.headers["HOST"] == "127.0.0.2" -def test_host_header_ipv6(make_request) -> None: - req = make_request("get", "http://[::2]") - assert req.headers["HOST"] == "[::2]" +@pytest.mark.parametrize("yarl_supports_host_subcomponent", [True, False]) +def test_host_header_ipv6(make_request, yarl_supports_host_subcomponent: bool) -> None: + # Ensure the old path is tested for old yarl versions + with mock.patch.object( + client_reqrep, + "_YARL_SUPPORTS_HOST_SUBCOMPONENT", + yarl_supports_host_subcomponent, + ): + req = make_request("get", "http://[::2]") + assert req.headers["HOST"] == "[::2]" def test_host_header_ipv4_with_port(make_request) -> None: diff --git a/tests/test_proxy_functional.py b/tests/test_proxy_functional.py index c15ca326288..4b11d11e3a7 100644 --- a/tests/test_proxy_functional.py +++ b/tests/test_proxy_functional.py @@ -12,7 +12,7 @@ from yarl import URL import aiohttp -from aiohttp import web +from aiohttp import client_reqrep, web from aiohttp.client_exceptions import ClientConnectionError from aiohttp.helpers import IS_MACOS, IS_WINDOWS @@ -95,6 +95,7 @@ async def handler(*args, **kwargs): reason="asyncio on this python does not support TLS in TLS", ) @pytest.mark.parametrize("web_server_endpoint_type", ("http", "https")) +@pytest.mark.parametrize("yarl_supports_host_subcomponent", [True, False]) @pytest.mark.filterwarnings(r"ignore:.*ssl.OP_NO_SSL*") # Filter out the warning from # https://github.com/abhinavsingh/proxy.py/blob/30574fd0414005dfa8792a6e797023e862bdcf43/proxy/common/utils.py#L226 @@ -104,18 +105,25 @@ async def test_secure_https_proxy_absolute_path( secure_proxy_url: URL, web_server_endpoint_url: str, web_server_endpoint_payload: str, + yarl_supports_host_subcomponent: bool, ) -> None: """Ensure HTTP(S) sites are accessible through a secure proxy.""" conn = aiohttp.TCPConnector() sess = aiohttp.ClientSession(connector=conn) - async with sess.get( - web_server_endpoint_url, - proxy=secure_proxy_url, - ssl=client_ssl_ctx, # used for both proxy and endpoint connections - ) as response: - assert response.status == 200 - assert await response.text() == web_server_endpoint_payload + # Ensure the old path is tested for old yarl versions + with mock.patch.object( + client_reqrep, + "_YARL_SUPPORTS_HOST_SUBCOMPONENT", + yarl_supports_host_subcomponent, + ): + async with sess.get( + web_server_endpoint_url, + proxy=secure_proxy_url, + ssl=client_ssl_ctx, # used for both proxy and endpoint connections + ) as response: + assert response.status == 200 + assert await response.text() == web_server_endpoint_payload await sess.close() await conn.close()