diff --git a/docs/advanced.md b/docs/advanced.md
index b2a07df371..88dc377c51 100644
--- a/docs/advanced.md
+++ b/docs/advanced.md
@@ -626,6 +626,23 @@ For instance this request sends 2 files, `foo.png` and `bar.png` in one request
 >>> r = httpx.post("https://httpbin.org/post", files=files)
 ```
 
+## Retries
+
+Client instances can retry on errors that occur while establishing new connections.
+
+Retries are disabled by default. They can be enabled by passing the maximum number of retries as `Client(retries=...)`. For example...
+
+```pycon
+>>> with httpx.Client(retries=3) as client:
+...     # If a connect error occurs, we will retry up to 3 times before failing.
+...     response = client.get("https://unstableserver.com/")
+```
+
+Note that:
+
+* HTTPX issues the first retry without waiting (transient errors are often resolved immediately), then issues further retries at exponentially increasing time intervals (0.5s, 1s, 2s, 4s, etc.).
+* The built-in retry functionality only applies to failures _while establishing new connections_ (effectively `ConnectError` and `ConnectTimeout` exceptions). In particular, errors that occur while interacting with existing connections (such as unexpected server-side connection closures, or read/write timeouts) will not be retried.
+
 ## Customizing authentication
 
 When issuing requests or instantiating a client, the `auth` argument can be used to pass an authentication scheme to use. The `auth` argument may be one of the following...
diff --git a/httpx/_client.py b/httpx/_client.py
index d6a0caf085..698a577295 100644
--- a/httpx/_client.py
+++ b/httpx/_client.py
@@ -1,4 +1,5 @@
 import functools
+import time
 import typing
 import warnings
 from types import TracebackType
@@ -10,7 +11,9 @@
 from ._config import (
     DEFAULT_LIMITS,
     DEFAULT_MAX_REDIRECTS,
+    DEFAULT_RETRIES,
     DEFAULT_TIMEOUT_CONFIG,
+    RETRIES_BACKOFF_FACTOR,
     UNSET,
     Limits,
     Proxy,
@@ -22,6 +25,8 @@
 from ._decoders import SUPPORTED_DECODERS
 from ._exceptions import (
     HTTPCORE_EXC_MAP,
+    ConnectError,
+    ConnectTimeout,
     InvalidURL,
     RemoteProtocolError,
     RequestBodyUnavailable,
@@ -48,9 +53,11 @@
 from ._utils import (
     NetRCInfo,
     URLPattern,
+    exponential_backoff,
     get_environment_proxies,
     get_logger,
     same_origin,
+    sleep,
     warn_deprecated,
 )
@@ -73,6 +80,7 @@ def __init__(
         cookies: CookieTypes = None,
         timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG,
         max_redirects: int = DEFAULT_MAX_REDIRECTS,
+        retries: int = DEFAULT_RETRIES,
         base_url: URLTypes = "",
         trust_env: bool = True,
     ):
@@ -84,6 +92,7 @@ def __init__(
         self._cookies = Cookies(cookies)
         self._timeout = Timeout(timeout)
         self.max_redirects = max_redirects
+        self.retries = retries
         self._trust_env = trust_env
         self._netrc = NetRCInfo()
         self._is_closed = True
@@ -506,6 +515,8 @@ class Client(BaseClient):
     * **limits** - *(optional)* The limits configuration to use.
     * **max_redirects** - *(optional)* The maximum number of redirect responses
     that should be followed.
+    * **retries** - *(optional)* The maximum number of retries when trying to
+    establish a connection.
     * **base_url** - *(optional)* A URL to use as the base when building request
     URLs.
     * **transport** - *(optional)* A transport class to use for sending requests
@@ -531,6 +542,7 @@ def __init__(
         limits: Limits = DEFAULT_LIMITS,
         pool_limits: Limits = None,
         max_redirects: int = DEFAULT_MAX_REDIRECTS,
+        retries: int = DEFAULT_RETRIES,
         base_url: URLTypes = "",
         transport: httpcore.SyncHTTPTransport = None,
         app: typing.Callable = None,
@@ -543,6 +555,7 @@ def __init__(
             cookies=cookies,
             timeout=timeout,
             max_redirects=max_redirects,
+            retries=retries,
             base_url=base_url,
             trust_env=trust_env,
         )
@@ -790,7 +803,7 @@ def _send_handling_auth(
         auth_flow = auth.auth_flow(request)
         request = next(auth_flow)
         while True:
-            response = self._send_single_request(request, timeout)
+            response = self._send_handling_retries(request, timeout)
             if auth.requires_response_body:
                 response.read()
             try:
@@ -806,6 +819,20 @@ def _send_handling_auth(
                 request = next_request
                 history.append(response)
 
+    def _send_handling_retries(self, request: Request, timeout: Timeout) -> Response:
+        retries_left = self.retries
+        delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR)
+
+        while True:
+            try:
+                return self._send_single_request(request, timeout)
+            except (ConnectError, ConnectTimeout):
+                if retries_left <= 0:
+                    raise
+                retries_left -= 1
+                delay = next(delays)
+                time.sleep(delay)
+
     def _send_single_request(self, request: Request, timeout: Timeout) -> Response:
         """
         Sends a single request, without handling any redirections.
@@ -1125,6 +1152,8 @@ class AsyncClient(BaseClient):
     * **limits** - *(optional)* The limits configuration to use.
     * **max_redirects** - *(optional)* The maximum number of redirect responses
     that should be followed.
+    * **retries** - *(optional)* The maximum number of retries when trying to
+    establish a connection.
     * **base_url** - *(optional)* A URL to use as the base when building request
     URLs.
     * **transport** - *(optional)* A transport class to use for sending requests
@@ -1150,6 +1179,7 @@ def __init__(
         limits: Limits = DEFAULT_LIMITS,
         pool_limits: Limits = None,
         max_redirects: int = DEFAULT_MAX_REDIRECTS,
+        retries: int = DEFAULT_RETRIES,
         base_url: URLTypes = "",
         transport: httpcore.AsyncHTTPTransport = None,
         app: typing.Callable = None,
@@ -1162,6 +1192,7 @@ def __init__(
             cookies=cookies,
             timeout=timeout,
             max_redirects=max_redirects,
+            retries=retries,
             base_url=base_url,
             trust_env=trust_env,
         )
@@ -1411,7 +1442,7 @@ async def _send_handling_auth(
         auth_flow = auth.auth_flow(request)
         request = next(auth_flow)
         while True:
-            response = await self._send_single_request(request, timeout)
+            response = await self._send_handling_retries(request, timeout)
             if auth.requires_response_body:
                 await response.aread()
             try:
@@ -1427,6 +1458,22 @@ async def _send_handling_auth(
                 request = next_request
                 history.append(response)
 
+    async def _send_handling_retries(
+        self, request: Request, timeout: Timeout
+    ) -> Response:
+        retries_left = self.retries
+        delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR)
+
+        while True:
+            try:
+                return await self._send_single_request(request, timeout)
+            except (ConnectError, ConnectTimeout):
+                if retries_left <= 0:
+                    raise
+                retries_left -= 1
+                delay = next(delays)
+                await sleep(delay)
+
     async def _send_single_request(
         self, request: Request, timeout: Timeout
     ) -> Response:
diff --git a/httpx/_config.py b/httpx/_config.py
index 8d589eadec..5bedff3874 100644
--- a/httpx/_config.py
+++ b/httpx/_config.py
@@ -413,3 +413,5 @@ def __repr__(self) -> str:
 DEFAULT_TIMEOUT_CONFIG = Timeout(timeout=5.0)
 DEFAULT_LIMITS = Limits(max_connections=100, max_keepalive_connections=20)
 DEFAULT_MAX_REDIRECTS = 20
+DEFAULT_RETRIES = 0
+RETRIES_BACKOFF_FACTOR = 0.5  # 0s, 0.5s, 1s, 2s, 4s, etc.
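As a quick illustration of what these constants mean in practice, the snippet below prints the first few delays produced by the `exponential_backoff` helper this patch adds to `httpx/_utils.py` (shown further down). It assumes the patch is applied so the helper is importable; the `itertools.islice` call is only for display and is not part of the patch.

```python
import itertools

from httpx._utils import exponential_backoff

# With the default RETRIES_BACKOFF_FACTOR of 0.5: retry immediately once,
# then wait 0.5s, 1s, 2s, 4s, ... between subsequent attempts.
print(list(itertools.islice(exponential_backoff(0.5), 5)))
# [0, 0.5, 1.0, 2.0, 4.0]
```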
diff --git a/httpx/_utils.py b/httpx/_utils.py
index 8080f63a46..0b47e108d9 100644
--- a/httpx/_utils.py
+++ b/httpx/_utils.py
@@ -1,5 +1,7 @@
+import asyncio
 import codecs
 import collections
+import itertools
 import logging
 import mimetypes
 import netrc
@@ -14,6 +16,8 @@
 from types import TracebackType
 from urllib.request import getproxies
 
+import sniffio
+
 from ._types import PrimitiveData
 
 if typing.TYPE_CHECKING:  # pragma: no cover
@@ -529,3 +533,24 @@ def __eq__(self, other: typing.Any) -> bool:
 
 def warn_deprecated(message: str) -> None:  # pragma: nocover
     warnings.warn(message, DeprecationWarning, stacklevel=2)
+
+
+async def sleep(seconds: float) -> None:
+    library = sniffio.current_async_library()
+    if library == "trio":
+        import trio
+
+        await trio.sleep(seconds)
+    elif library == "curio":  # pragma: no cover
+        import curio
+
+        await curio.sleep(seconds)
+    else:
+        assert library == "asyncio"
+        await asyncio.sleep(seconds)
+
+
+def exponential_backoff(factor: float) -> typing.Iterator[float]:
+    yield 0
+    for n in itertools.count(2):
+        yield factor * (2 ** (n - 2))
diff --git a/tests/client/test_retries.py b/tests/client/test_retries.py
new file mode 100644
index 0000000000..065b32c121
--- /dev/null
+++ b/tests/client/test_retries.py
@@ -0,0 +1,173 @@
+import collections
+from typing import Dict, List, Mapping, Optional, Tuple
+
+import httpcore
+import pytest
+
+import httpx
+
+
+def test_retries_config() -> None:
+    client = httpx.Client()
+    assert client.retries == 0
+
+    client = httpx.Client(retries=3)
+    assert client.retries == 3
+
+    client.retries = 1
+    assert client.retries == 1
+
+
+class BaseMockTransport:
+    def __init__(self, num_failures: int) -> None:
+        self._num_failures = num_failures
+        self._attempts_by_path: Dict[bytes, int] = collections.defaultdict(int)
+
+    def _request(
+        self,
+        url: Tuple[bytes, bytes, Optional[int], bytes],
+    ) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], httpcore.PlainByteStream]:
+        _, _, _, path = url
+
+        exc, is_retryable = {
+            b"/": (None, False),
+            b"/connect_timeout": (httpcore.ConnectTimeout, True),
+            b"/connect_error": (httpcore.ConnectError, True),
+            b"/read_timeout": (httpcore.ReadTimeout, False),
+            b"/network_error": (httpcore.NetworkError, False),
+        }[path]
+
+        if exc is None:
+            stream = httpcore.PlainByteStream(b"")
+            return (b"HTTP/1.1", 200, b"OK", [], stream)
+
+        if not is_retryable:
+            raise exc
+
+        if self._attempts_by_path[path] >= self._num_failures:
+            self._attempts_by_path.clear()
+            stream = httpcore.PlainByteStream(b"")
+            return (b"HTTP/1.1", 200, b"OK", [], stream)
+
+        self._attempts_by_path[path] += 1
+
+        raise exc
+
+
+class MockTransport(BaseMockTransport, httpcore.SyncHTTPTransport):
+    def __init__(self, num_failures: int) -> None:
+        super().__init__(num_failures)
+
+    def request(
+        self,
+        method: bytes,
+        url: Tuple[bytes, bytes, Optional[int], bytes],
+        headers: List[Tuple[bytes, bytes]] = None,
+        stream: httpcore.SyncByteStream = None,
+        timeout: Mapping[str, Optional[float]] = None,
+    ) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], httpcore.SyncByteStream]:
+        return self._request(url)
+
+
+class AsyncMockTransport(BaseMockTransport, httpcore.AsyncHTTPTransport):
+    def __init__(self, num_failures: int) -> None:
+        super().__init__(num_failures)
+
+    async def request(
+        self,
+        method: bytes,
+        url: Tuple[bytes, bytes, Optional[int], bytes],
+        headers: List[Tuple[bytes, bytes]] = None,
+        stream: httpcore.AsyncByteStream = None,
+        timeout: Mapping[str, Optional[float]] = None,
+    ) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], httpcore.AsyncByteStream]:
+        return self._request(url)
+
+
+def test_no_retries() -> None:
+    """
+    By default, connection failures are not retried.
+    """
+    transport = MockTransport(num_failures=1)
+    client = httpx.Client(transport=transport)
+
+    response = client.get("https://example.com")
+    assert response.status_code == 200
+
+    with pytest.raises(httpx.ConnectTimeout):
+        client.get("https://example.com/connect_timeout")
+
+    with pytest.raises(httpx.ConnectError):
+        client.get("https://example.com/connect_error")
+
+
+def test_retries_enabled() -> None:
+    """
+    When retries are enabled, connection failures are retried with
+    a fixed exponential backoff.
+    """
+    transport = MockTransport(num_failures=3)
+    client = httpx.Client(transport=transport, retries=3)
+    expected_elapsed_time = pytest.approx(0 + 0.5 + 1, rel=0.1)
+
+    response = client.get("https://example.com")
+    assert response.status_code == 200
+
+    response = client.get("https://example.com/connect_timeout")
+    assert response.status_code == 200
+    assert response.elapsed.total_seconds() == expected_elapsed_time
+
+    response = client.get("https://example.com/connect_error")
+    assert response.status_code == 200
+    assert response.elapsed.total_seconds() == expected_elapsed_time
+
+    with pytest.raises(httpx.ReadTimeout):
+        client.get("https://example.com/read_timeout")
+
+    with pytest.raises(httpx.NetworkError):
+        client.get("https://example.com/network_error")
+
+
+@pytest.mark.usefixtures("async_environment")
+async def test_retries_enabled_async() -> None:
+    # For test coverage purposes.
+    transport = AsyncMockTransport(num_failures=3)
+    client = httpx.AsyncClient(transport=transport, retries=3)
+    expected_elapsed_time = pytest.approx(0 + 0.5 + 1, rel=0.1)
+
+    # Connect exceptions are retried with a backoff.
+    response = await client.get("https://example.com/connect_timeout")
+    assert response.status_code == 200
+    assert response.elapsed.total_seconds() == expected_elapsed_time
+
+    # Non-connect errors are not retried.
+    with pytest.raises(httpx.ReadTimeout):
+        await client.get("https://example.com/read_timeout")
+
+
+def test_retries_exceeded() -> None:
+    """
+    When retries are enabled and connecting fails more often than the configured
+    number of retries allows, the connect exception is raised.
+    """
+    transport = MockTransport(num_failures=2)
+    client = httpx.Client(transport=transport, retries=1)
+
+    with pytest.raises(httpx.ConnectTimeout):
+        client.get("https://example.com/connect_timeout")
+
+    with pytest.raises(httpx.ConnectError):
+        client.get("https://example.com/connect_error")
+
+
+@pytest.mark.parametrize(
+    "method", ["HEAD", "GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS", "TRACE"]
+)
+def test_retries_methods(method: str) -> None:
+    """
+    The client retries on all HTTP methods.
+ """ + transport = MockTransport(num_failures=1) + client = httpx.Client(transport=transport, retries=1) + response = client.request(method, "https://example.com/connect_timeout") + assert response.status_code == 200 diff --git a/tests/test_utils.py b/tests/test_utils.py index d5dfb5819b..cb58300061 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,8 @@ import asyncio +import itertools import os import random +from typing import List import pytest @@ -9,6 +11,7 @@ ElapsedTimer, NetRCInfo, URLPattern, + exponential_backoff, get_ca_bundle_from_env, get_environment_proxies, guess_json_utf, @@ -270,3 +273,16 @@ def test_pattern_priority(): URLPattern("http://"), URLPattern("all://"), ] + + +@pytest.mark.parametrize( + "factor, expected", + [ + (0.1, [0, 0.1, 0.2, 0.4, 0.8]), + (0.2, [0, 0.2, 0.4, 0.8, 1.6]), + (0.5, [0, 0.5, 1.0, 2.0, 4.0]), + ], +) +def test_exponential_backoff(factor: float, expected: List[int]) -> None: + delays = list(itertools.islice(exponential_backoff(factor), 5)) + assert delays == expected