diff --git a/httpcore/_async/connection.py b/httpcore/_async/connection.py index 2f439cf0..83e434ba 100644 --- a/httpcore/_async/connection.py +++ b/httpcore/_async/connection.py @@ -4,6 +4,8 @@ from types import TracebackType from typing import Iterable, Iterator, Optional, Type +from httpcore._utils import OverallTimeoutHandler + from .._backends.auto import AutoBackend from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream from .._exceptions import ConnectError, ConnectTimeout @@ -105,6 +107,8 @@ async def _connect(self, request: Request) -> AsyncNetworkStream: sni_hostname = request.extensions.get("sni_hostname", None) timeout = timeouts.get("connect", None) + overall_timeout = OverallTimeoutHandler(timeouts) + retries_left = self._retries delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR) @@ -115,11 +119,12 @@ async def _connect(self, request: Request) -> AsyncNetworkStream: "host": self._origin.host.decode("ascii"), "port": self._origin.port, "local_address": self._local_address, - "timeout": timeout, + "timeout": overall_timeout.get_minimum_timeout(timeout), "socket_options": self._socket_options, } async with Trace("connect_tcp", logger, request, kwargs) as trace: - stream = await self._network_backend.connect_tcp(**kwargs) + with overall_timeout: + stream = await self._network_backend.connect_tcp(**kwargs) trace.return_value = stream else: kwargs = { diff --git a/httpcore/_async/connection_pool.py b/httpcore/_async/connection_pool.py index 214dfc4b..8daabade 100644 --- a/httpcore/_async/connection_pool.py +++ b/httpcore/_async/connection_pool.py @@ -3,6 +3,8 @@ from types import TracebackType from typing import AsyncIterable, AsyncIterator, Iterable, List, Optional, Type +from httpcore._utils import OverallTimeoutHandler + from .._backends.auto import AutoBackend from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol @@ -174,6 +176,7 @@ async def handle_async_request(self, request: Request) -> Response: timeouts = request.extensions.get("timeout", {}) timeout = timeouts.get("pool", None) + overall_timeout = OverallTimeoutHandler(timeouts) with self._optional_thread_lock: # Add the incoming request to our request queue. @@ -188,8 +191,11 @@ async def handle_async_request(self, request: Request) -> Response: closing = self._assign_requests_to_connections() await self._close_connections(closing) - # Wait until this request has an assigned connection. - connection = await pool_request.wait_for_connection(timeout=timeout) + with overall_timeout: + # Wait until this request has an assigned connection. + connection = await pool_request.wait_for_connection( + timeout=overall_timeout.get_minimum_timeout(timeout) + ) try: # Send the request on the assigned connection. diff --git a/httpcore/_async/http11.py b/httpcore/_async/http11.py index 0493a923..8d5e0fc3 100644 --- a/httpcore/_async/http11.py +++ b/httpcore/_async/http11.py @@ -16,6 +16,8 @@ import h11 +from httpcore._utils import OverallTimeoutHandler + from .._backends.base import AsyncNetworkStream from .._exceptions import ( ConnectionNotAvailable, @@ -147,6 +149,7 @@ async def handle_async_request(self, request: Request) -> Response: async def _send_request_headers(self, request: Request) -> None: timeouts = request.extensions.get("timeout", {}) timeout = timeouts.get("write", None) + overall_timeout = OverallTimeoutHandler(timeouts) with map_exceptions({h11.LocalProtocolError: LocalProtocolError}): event = h11.Request( @@ -154,18 +157,29 @@ async def _send_request_headers(self, request: Request) -> None: target=request.url.target, headers=request.headers, ) - await self._send_event(event, timeout=timeout) + with overall_timeout: + await self._send_event( + event, timeout=overall_timeout.get_minimum_timeout(timeout) + ) async def _send_request_body(self, request: Request) -> None: timeouts = request.extensions.get("timeout", {}) timeout = timeouts.get("write", None) + overall_timeout = OverallTimeoutHandler(timeouts) assert isinstance(request.stream, AsyncIterable) async for chunk in request.stream: event = h11.Data(data=chunk) - await self._send_event(event, timeout=timeout) - await self._send_event(h11.EndOfMessage(), timeout=timeout) + with overall_timeout: + await self._send_event( + event, timeout=overall_timeout.get_minimum_timeout(timeout) + ) + + with overall_timeout: + await self._send_event( + h11.EndOfMessage(), timeout=overall_timeout.get_minimum_timeout(timeout) + ) async def _send_event( self, event: h11.Event, timeout: Optional[float] = None @@ -181,9 +195,13 @@ async def _receive_response_headers( ) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], bytes]: timeouts = request.extensions.get("timeout", {}) timeout = timeouts.get("read", None) + overall_timeout = OverallTimeoutHandler(timeouts) while True: - event = await self._receive_event(timeout=timeout) + with overall_timeout: + event = await self._receive_event( + timeout=overall_timeout.get_minimum_timeout(timeout) + ) if isinstance(event, h11.Response): break if ( @@ -205,9 +223,12 @@ async def _receive_response_headers( async def _receive_response_body(self, request: Request) -> AsyncIterator[bytes]: timeouts = request.extensions.get("timeout", {}) timeout = timeouts.get("read", None) + overall_timeout = OverallTimeoutHandler(timeouts) while True: - event = await self._receive_event(timeout=timeout) + event = await self._receive_event( + timeout=overall_timeout.get_minimum_timeout(timeout) + ) if isinstance(event, h11.Data): yield bytes(event.data) elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)): diff --git a/httpcore/_async/http2.py b/httpcore/_async/http2.py index c201ee4c..af1fbbb6 100644 --- a/httpcore/_async/http2.py +++ b/httpcore/_async/http2.py @@ -10,6 +10,8 @@ import h2.exceptions import h2.settings +from httpcore._utils import OverallTimeoutHandler + from .._backends.base import AsyncNetworkStream from .._exceptions import ( ConnectionNotAvailable, @@ -430,12 +432,16 @@ async def _read_incoming_data( ) -> typing.List[h2.events.Event]: timeouts = request.extensions.get("timeout", {}) timeout = timeouts.get("read", None) + overall_timeout = OverallTimeoutHandler(timeouts) if self._read_exception is not None: raise self._read_exception # pragma: nocover try: - data = await self._network_stream.read(self.READ_NUM_BYTES, timeout) + with overall_timeout: + data = await self._network_stream.read( + self.READ_NUM_BYTES, overall_timeout.get_minimum_timeout(timeout) + ) if data == b"": raise RemoteProtocolError("Server disconnected") except Exception as exc: @@ -458,6 +464,7 @@ async def _read_incoming_data( async def _write_outgoing_data(self, request: Request) -> None: timeouts = request.extensions.get("timeout", {}) timeout = timeouts.get("write", None) + overall_timeout = OverallTimeoutHandler(timeouts) async with self._write_lock: data_to_send = self._h2_state.data_to_send() @@ -466,7 +473,8 @@ async def _write_outgoing_data(self, request: Request) -> None: raise self._write_exception # pragma: nocover try: - await self._network_stream.write(data_to_send, timeout) + with overall_timeout: + await self._network_stream.write(data_to_send, timeout) except Exception as exc: # pragma: nocover # If we get a network error we should: # diff --git a/httpcore/_async/http_proxy.py b/httpcore/_async/http_proxy.py index 4aa7d874..07ac5e74 100644 --- a/httpcore/_async/http_proxy.py +++ b/httpcore/_async/http_proxy.py @@ -3,6 +3,8 @@ from base64 import b64encode from typing import Iterable, List, Mapping, Optional, Sequence, Tuple, Union +from httpcore._utils import OverallTimeoutHandler + from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend from .._exceptions import ProxyError from .._models import ( @@ -266,6 +268,7 @@ def __init__( async def handle_async_request(self, request: Request) -> Response: timeouts = request.extensions.get("timeout", {}) timeout = timeouts.get("connect", None) + overall_timeout = OverallTimeoutHandler(timeouts) async with self._connect_lock: if not self._connected: @@ -311,10 +314,11 @@ async def handle_async_request(self, request: Request) -> Response: kwargs = { "ssl_context": ssl_context, "server_hostname": self._remote_origin.host.decode("ascii"), - "timeout": timeout, + "timeout": overall_timeout.get_minimum_timeout(timeout), } async with Trace("start_tls", logger, request, kwargs) as trace: - stream = await stream.start_tls(**kwargs) + with overall_timeout: + stream = await stream.start_tls(**kwargs) trace.return_value = stream # Determine if we should be using HTTP/1.1 or HTTP/2 diff --git a/httpcore/_async/socks_proxy.py b/httpcore/_async/socks_proxy.py index f839603f..e03a2677 100644 --- a/httpcore/_async/socks_proxy.py +++ b/httpcore/_async/socks_proxy.py @@ -4,6 +4,8 @@ from socksio import socks5 +from httpcore._utils import OverallTimeoutHandler + from .._backends.auto import AutoBackend from .._backends.base import AsyncNetworkBackend, AsyncNetworkStream from .._exceptions import ConnectionNotAvailable, ProxyError @@ -218,6 +220,7 @@ async def handle_async_request(self, request: Request) -> Response: timeouts = request.extensions.get("timeout", {}) sni_hostname = request.extensions.get("sni_hostname", None) timeout = timeouts.get("connect", None) + overall_timeout = OverallTimeoutHandler(timeouts) async with self._connect_lock: if self._connection is None: @@ -226,10 +229,11 @@ async def handle_async_request(self, request: Request) -> Response: kwargs = { "host": self._proxy_origin.host.decode("ascii"), "port": self._proxy_origin.port, - "timeout": timeout, + "timeout": overall_timeout.get_minimum_timeout(timeout), } async with Trace("connect_tcp", logger, request, kwargs) as trace: - stream = await self._network_backend.connect_tcp(**kwargs) + with overall_timeout: + stream = await self._network_backend.connect_tcp(**kwargs) trace.return_value = stream # Connect to the remote host using socks5 diff --git a/httpcore/_sync/connection.py b/httpcore/_sync/connection.py index c3890f34..1e724c42 100644 --- a/httpcore/_sync/connection.py +++ b/httpcore/_sync/connection.py @@ -4,6 +4,8 @@ from types import TracebackType from typing import Iterable, Iterator, Optional, Type +from httpcore._utils import OverallTimeoutHandler + from .._backends.sync import SyncBackend from .._backends.base import SOCKET_OPTION, NetworkBackend, NetworkStream from .._exceptions import ConnectError, ConnectTimeout @@ -105,6 +107,8 @@ def _connect(self, request: Request) -> NetworkStream: sni_hostname = request.extensions.get("sni_hostname", None) timeout = timeouts.get("connect", None) + overall_timeout = OverallTimeoutHandler(timeouts) + retries_left = self._retries delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR) @@ -115,11 +119,12 @@ def _connect(self, request: Request) -> NetworkStream: "host": self._origin.host.decode("ascii"), "port": self._origin.port, "local_address": self._local_address, - "timeout": timeout, + "timeout": overall_timeout.get_minimum_timeout(timeout), "socket_options": self._socket_options, } with Trace("connect_tcp", logger, request, kwargs) as trace: - stream = self._network_backend.connect_tcp(**kwargs) + with overall_timeout: + stream = self._network_backend.connect_tcp(**kwargs) trace.return_value = stream else: kwargs = { diff --git a/httpcore/_sync/connection_pool.py b/httpcore/_sync/connection_pool.py index 01bec59e..7f2b36e2 100644 --- a/httpcore/_sync/connection_pool.py +++ b/httpcore/_sync/connection_pool.py @@ -3,6 +3,8 @@ from types import TracebackType from typing import Iterable, Iterator, Iterable, List, Optional, Type +from httpcore._utils import OverallTimeoutHandler + from .._backends.sync import SyncBackend from .._backends.base import SOCKET_OPTION, NetworkBackend from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol @@ -174,6 +176,7 @@ def handle_request(self, request: Request) -> Response: timeouts = request.extensions.get("timeout", {}) timeout = timeouts.get("pool", None) + overall_timeout = OverallTimeoutHandler(timeouts) with self._optional_thread_lock: # Add the incoming request to our request queue. @@ -188,8 +191,11 @@ def handle_request(self, request: Request) -> Response: closing = self._assign_requests_to_connections() self._close_connections(closing) - # Wait until this request has an assigned connection. - connection = pool_request.wait_for_connection(timeout=timeout) + with overall_timeout: + # Wait until this request has an assigned connection. + connection = pool_request.wait_for_connection( + timeout=overall_timeout.get_minimum_timeout(timeout) + ) try: # Send the request on the assigned connection. diff --git a/httpcore/_sync/http11.py b/httpcore/_sync/http11.py index a74ff8e8..07d52f9b 100644 --- a/httpcore/_sync/http11.py +++ b/httpcore/_sync/http11.py @@ -16,6 +16,8 @@ import h11 +from httpcore._utils import OverallTimeoutHandler + from .._backends.base import NetworkStream from .._exceptions import ( ConnectionNotAvailable, @@ -147,6 +149,7 @@ def handle_request(self, request: Request) -> Response: def _send_request_headers(self, request: Request) -> None: timeouts = request.extensions.get("timeout", {}) timeout = timeouts.get("write", None) + overall_timeout = OverallTimeoutHandler(timeouts) with map_exceptions({h11.LocalProtocolError: LocalProtocolError}): event = h11.Request( @@ -154,18 +157,29 @@ def _send_request_headers(self, request: Request) -> None: target=request.url.target, headers=request.headers, ) - self._send_event(event, timeout=timeout) + with overall_timeout: + self._send_event( + event, timeout=overall_timeout.get_minimum_timeout(timeout) + ) def _send_request_body(self, request: Request) -> None: timeouts = request.extensions.get("timeout", {}) timeout = timeouts.get("write", None) + overall_timeout = OverallTimeoutHandler(timeouts) assert isinstance(request.stream, Iterable) for chunk in request.stream: event = h11.Data(data=chunk) - self._send_event(event, timeout=timeout) - self._send_event(h11.EndOfMessage(), timeout=timeout) + with overall_timeout: + self._send_event( + event, timeout=overall_timeout.get_minimum_timeout(timeout) + ) + + with overall_timeout: + self._send_event( + h11.EndOfMessage(), timeout=overall_timeout.get_minimum_timeout(timeout) + ) def _send_event( self, event: h11.Event, timeout: Optional[float] = None @@ -181,9 +195,13 @@ def _receive_response_headers( ) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], bytes]: timeouts = request.extensions.get("timeout", {}) timeout = timeouts.get("read", None) + overall_timeout = OverallTimeoutHandler(timeouts) while True: - event = self._receive_event(timeout=timeout) + with overall_timeout: + event = self._receive_event( + timeout=overall_timeout.get_minimum_timeout(timeout) + ) if isinstance(event, h11.Response): break if ( @@ -205,9 +223,12 @@ def _receive_response_headers( def _receive_response_body(self, request: Request) -> Iterator[bytes]: timeouts = request.extensions.get("timeout", {}) timeout = timeouts.get("read", None) + overall_timeout = OverallTimeoutHandler(timeouts) while True: - event = self._receive_event(timeout=timeout) + event = self._receive_event( + timeout=overall_timeout.get_minimum_timeout(timeout) + ) if isinstance(event, h11.Data): yield bytes(event.data) elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)): diff --git a/httpcore/_sync/http2.py b/httpcore/_sync/http2.py index 1ee4bbb3..7619d4f0 100644 --- a/httpcore/_sync/http2.py +++ b/httpcore/_sync/http2.py @@ -10,6 +10,8 @@ import h2.exceptions import h2.settings +from httpcore._utils import OverallTimeoutHandler + from .._backends.base import NetworkStream from .._exceptions import ( ConnectionNotAvailable, @@ -430,12 +432,16 @@ def _read_incoming_data( ) -> typing.List[h2.events.Event]: timeouts = request.extensions.get("timeout", {}) timeout = timeouts.get("read", None) + overall_timeout = OverallTimeoutHandler(timeouts) if self._read_exception is not None: raise self._read_exception # pragma: nocover try: - data = self._network_stream.read(self.READ_NUM_BYTES, timeout) + with overall_timeout: + data = self._network_stream.read( + self.READ_NUM_BYTES, overall_timeout.get_minimum_timeout(timeout) + ) if data == b"": raise RemoteProtocolError("Server disconnected") except Exception as exc: @@ -458,6 +464,7 @@ def _read_incoming_data( def _write_outgoing_data(self, request: Request) -> None: timeouts = request.extensions.get("timeout", {}) timeout = timeouts.get("write", None) + overall_timeout = OverallTimeoutHandler(timeouts) with self._write_lock: data_to_send = self._h2_state.data_to_send() @@ -466,7 +473,8 @@ def _write_outgoing_data(self, request: Request) -> None: raise self._write_exception # pragma: nocover try: - self._network_stream.write(data_to_send, timeout) + with overall_timeout: + self._network_stream.write(data_to_send, timeout) except Exception as exc: # pragma: nocover # If we get a network error we should: # diff --git a/httpcore/_sync/http_proxy.py b/httpcore/_sync/http_proxy.py index 6acac9a7..6b9b38fc 100644 --- a/httpcore/_sync/http_proxy.py +++ b/httpcore/_sync/http_proxy.py @@ -3,6 +3,8 @@ from base64 import b64encode from typing import Iterable, List, Mapping, Optional, Sequence, Tuple, Union +from httpcore._utils import OverallTimeoutHandler + from .._backends.base import SOCKET_OPTION, NetworkBackend from .._exceptions import ProxyError from .._models import ( @@ -266,6 +268,7 @@ def __init__( def handle_request(self, request: Request) -> Response: timeouts = request.extensions.get("timeout", {}) timeout = timeouts.get("connect", None) + overall_timeout = OverallTimeoutHandler(timeouts) with self._connect_lock: if not self._connected: @@ -311,10 +314,11 @@ def handle_request(self, request: Request) -> Response: kwargs = { "ssl_context": ssl_context, "server_hostname": self._remote_origin.host.decode("ascii"), - "timeout": timeout, + "timeout": overall_timeout.get_minimum_timeout(timeout), } with Trace("start_tls", logger, request, kwargs) as trace: - stream = stream.start_tls(**kwargs) + with overall_timeout: + stream = stream.start_tls(**kwargs) trace.return_value = stream # Determine if we should be using HTTP/1.1 or HTTP/2 diff --git a/httpcore/_sync/socks_proxy.py b/httpcore/_sync/socks_proxy.py index 502e4d7f..55a1f7c7 100644 --- a/httpcore/_sync/socks_proxy.py +++ b/httpcore/_sync/socks_proxy.py @@ -4,6 +4,8 @@ from socksio import socks5 +from httpcore._utils import OverallTimeoutHandler + from .._backends.sync import SyncBackend from .._backends.base import NetworkBackend, NetworkStream from .._exceptions import ConnectionNotAvailable, ProxyError @@ -218,6 +220,7 @@ def handle_request(self, request: Request) -> Response: timeouts = request.extensions.get("timeout", {}) sni_hostname = request.extensions.get("sni_hostname", None) timeout = timeouts.get("connect", None) + overall_timeout = OverallTimeoutHandler(timeouts) with self._connect_lock: if self._connection is None: @@ -226,10 +229,11 @@ def handle_request(self, request: Request) -> Response: kwargs = { "host": self._proxy_origin.host.decode("ascii"), "port": self._proxy_origin.port, - "timeout": timeout, + "timeout": overall_timeout.get_minimum_timeout(timeout), } with Trace("connect_tcp", logger, request, kwargs) as trace: - stream = self._network_backend.connect_tcp(**kwargs) + with overall_timeout: + stream = self._network_backend.connect_tcp(**kwargs) trace.return_value = stream # Connect to the remote host using socks5 diff --git a/httpcore/_utils.py b/httpcore/_utils.py index df5dea8f..e7116868 100644 --- a/httpcore/_utils.py +++ b/httpcore/_utils.py @@ -1,9 +1,36 @@ import select import socket import sys +import time +import types import typing +class OverallTimeoutHandler: + def __init__(self, timeouts: typing.Dict[str, typing.Any]) -> None: + self.timeouts = timeouts + + def __enter__(self) -> None: + self.start_time = time.monotonic() + + def __exit__( + self, + exc_type: typing.Optional[typing.Type[BaseException]], + exc_value: typing.Optional[BaseException], + traceback: typing.Optional[types.TracebackType], + ) -> None: + elapsed_time = time.monotonic() - self.start_time + if self.timeouts.get("total") is not None: + self.timeouts["total"] -= elapsed_time + + def get_minimum_timeout(self, timeout: typing.Optional[float]) -> typing.Any: + if self.timeouts.get("total") is None: + return timeout + if timeout is None: + return self.timeouts["total"] + return min(timeout, self.timeouts["total"]) # pragma: nocover + + def is_socket_readable(sock: typing.Optional[socket.socket]) -> bool: """ Return whether a socket, as identifed by its file descriptor, is readable. diff --git a/tests/test_api.py b/tests/test_api.py index b29cf1a0..0daf7b5a 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -24,4 +24,6 @@ def test_request_with_content(httpbin): def test_total_timeout(httpbin): with pytest.raises(Exception): - httpcore.request("GET", httpbin.url + "/delay/1", timeout=0.1) + httpcore.request( + "GET", httpbin.url + "/delay/1", extensions={"timeout": {"total": 0.1}} + )