Skip to content

Commit

Permalink
Handle total timeouts
Browse files Browse the repository at this point in the history
  • Loading branch information
karpetrosyan committed Jul 13, 2024
1 parent 6d8116b commit f97c976
Show file tree
Hide file tree
Showing 14 changed files with 156 additions and 31 deletions.
9 changes: 7 additions & 2 deletions httpcore/_async/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from types import TracebackType
from typing import Iterable, Iterator, Optional, Type

from httpcore._utils import OverallTimeoutHandler

from .._backends.auto import AutoBackend
from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
from .._exceptions import ConnectError, ConnectTimeout
Expand Down Expand Up @@ -105,6 +107,8 @@ async def _connect(self, request: Request) -> AsyncNetworkStream:
sni_hostname = request.extensions.get("sni_hostname", None)
timeout = timeouts.get("connect", None)

overall_timeout = OverallTimeoutHandler(timeouts)

retries_left = self._retries
delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR)

Expand All @@ -115,11 +119,12 @@ async def _connect(self, request: Request) -> AsyncNetworkStream:
"host": self._origin.host.decode("ascii"),
"port": self._origin.port,
"local_address": self._local_address,
"timeout": timeout,
"timeout": overall_timeout.get_minimum_timeout(timeout),
"socket_options": self._socket_options,
}
async with Trace("connect_tcp", logger, request, kwargs) as trace:
stream = await self._network_backend.connect_tcp(**kwargs)
with overall_timeout:
stream = await self._network_backend.connect_tcp(**kwargs)
trace.return_value = stream
else:
kwargs = {
Expand Down
10 changes: 8 additions & 2 deletions httpcore/_async/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from types import TracebackType
from typing import AsyncIterable, AsyncIterator, Iterable, List, Optional, Type

from httpcore._utils import OverallTimeoutHandler

from .._backends.auto import AutoBackend
from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend
from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol
Expand Down Expand Up @@ -174,6 +176,7 @@ async def handle_async_request(self, request: Request) -> Response:

timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("pool", None)
overall_timeout = OverallTimeoutHandler(timeouts)

with self._optional_thread_lock:
# Add the incoming request to our request queue.
Expand All @@ -188,8 +191,11 @@ async def handle_async_request(self, request: Request) -> Response:
closing = self._assign_requests_to_connections()
await self._close_connections(closing)

# Wait until this request has an assigned connection.
connection = await pool_request.wait_for_connection(timeout=timeout)
with overall_timeout:
# Wait until this request has an assigned connection.
connection = await pool_request.wait_for_connection(
timeout=overall_timeout.get_minimum_timeout(timeout)
)

try:
# Send the request on the assigned connection.
Expand Down
31 changes: 26 additions & 5 deletions httpcore/_async/http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

import h11

from httpcore._utils import OverallTimeoutHandler

from .._backends.base import AsyncNetworkStream
from .._exceptions import (
ConnectionNotAvailable,
Expand Down Expand Up @@ -147,25 +149,37 @@ async def handle_async_request(self, request: Request) -> Response:
async def _send_request_headers(self, request: Request) -> None:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("write", None)
overall_timeout = OverallTimeoutHandler(timeouts)

with map_exceptions({h11.LocalProtocolError: LocalProtocolError}):
event = h11.Request(
method=request.method,
target=request.url.target,
headers=request.headers,
)
await self._send_event(event, timeout=timeout)
with overall_timeout:
await self._send_event(
event, timeout=overall_timeout.get_minimum_timeout(timeout)
)

async def _send_request_body(self, request: Request) -> None:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("write", None)
overall_timeout = OverallTimeoutHandler(timeouts)

assert isinstance(request.stream, AsyncIterable)
async for chunk in request.stream:
event = h11.Data(data=chunk)
await self._send_event(event, timeout=timeout)

await self._send_event(h11.EndOfMessage(), timeout=timeout)
with overall_timeout:
await self._send_event(
event, timeout=overall_timeout.get_minimum_timeout(timeout)
)

with overall_timeout:
await self._send_event(
h11.EndOfMessage(), timeout=overall_timeout.get_minimum_timeout(timeout)
)

async def _send_event(
self, event: h11.Event, timeout: Optional[float] = None
Expand All @@ -181,9 +195,13 @@ async def _receive_response_headers(
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], bytes]:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("read", None)
overall_timeout = OverallTimeoutHandler(timeouts)

while True:
event = await self._receive_event(timeout=timeout)
with overall_timeout:
event = await self._receive_event(
timeout=overall_timeout.get_minimum_timeout(timeout)
)
if isinstance(event, h11.Response):
break
if (
Expand All @@ -205,9 +223,12 @@ async def _receive_response_headers(
async def _receive_response_body(self, request: Request) -> AsyncIterator[bytes]:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("read", None)
overall_timeout = OverallTimeoutHandler(timeouts)

while True:
event = await self._receive_event(timeout=timeout)
event = await self._receive_event(
timeout=overall_timeout.get_minimum_timeout(timeout)
)
if isinstance(event, h11.Data):
yield bytes(event.data)
elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)):
Expand Down
12 changes: 10 additions & 2 deletions httpcore/_async/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import h2.exceptions
import h2.settings

from httpcore._utils import OverallTimeoutHandler

from .._backends.base import AsyncNetworkStream
from .._exceptions import (
ConnectionNotAvailable,
Expand Down Expand Up @@ -430,12 +432,16 @@ async def _read_incoming_data(
) -> typing.List[h2.events.Event]:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("read", None)
overall_timeout = OverallTimeoutHandler(timeouts)

if self._read_exception is not None:
raise self._read_exception # pragma: nocover

try:
data = await self._network_stream.read(self.READ_NUM_BYTES, timeout)
with overall_timeout:
data = await self._network_stream.read(
self.READ_NUM_BYTES, overall_timeout.get_minimum_timeout(timeout)
)
if data == b"":
raise RemoteProtocolError("Server disconnected")
except Exception as exc:
Expand All @@ -458,6 +464,7 @@ async def _read_incoming_data(
async def _write_outgoing_data(self, request: Request) -> None:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("write", None)
overall_timeout = OverallTimeoutHandler(timeouts)

async with self._write_lock:
data_to_send = self._h2_state.data_to_send()
Expand All @@ -466,7 +473,8 @@ async def _write_outgoing_data(self, request: Request) -> None:
raise self._write_exception # pragma: nocover

try:
await self._network_stream.write(data_to_send, timeout)
with overall_timeout:
await self._network_stream.write(data_to_send, timeout)
except Exception as exc: # pragma: nocover
# If we get a network error we should:
#
Expand Down
8 changes: 6 additions & 2 deletions httpcore/_async/http_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from base64 import b64encode
from typing import Iterable, List, Mapping, Optional, Sequence, Tuple, Union

from httpcore._utils import OverallTimeoutHandler

from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend
from .._exceptions import ProxyError
from .._models import (
Expand Down Expand Up @@ -266,6 +268,7 @@ def __init__(
async def handle_async_request(self, request: Request) -> Response:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("connect", None)
overall_timeout = OverallTimeoutHandler(timeouts)

async with self._connect_lock:
if not self._connected:
Expand Down Expand Up @@ -311,10 +314,11 @@ async def handle_async_request(self, request: Request) -> Response:
kwargs = {
"ssl_context": ssl_context,
"server_hostname": self._remote_origin.host.decode("ascii"),
"timeout": timeout,
"timeout": overall_timeout.get_minimum_timeout(timeout),
}
async with Trace("start_tls", logger, request, kwargs) as trace:
stream = await stream.start_tls(**kwargs)
with overall_timeout:
stream = await stream.start_tls(**kwargs)
trace.return_value = stream

# Determine if we should be using HTTP/1.1 or HTTP/2
Expand Down
8 changes: 6 additions & 2 deletions httpcore/_async/socks_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from socksio import socks5

from httpcore._utils import OverallTimeoutHandler

from .._backends.auto import AutoBackend
from .._backends.base import AsyncNetworkBackend, AsyncNetworkStream
from .._exceptions import ConnectionNotAvailable, ProxyError
Expand Down Expand Up @@ -218,6 +220,7 @@ async def handle_async_request(self, request: Request) -> Response:
timeouts = request.extensions.get("timeout", {})
sni_hostname = request.extensions.get("sni_hostname", None)
timeout = timeouts.get("connect", None)
overall_timeout = OverallTimeoutHandler(timeouts)

async with self._connect_lock:
if self._connection is None:
Expand All @@ -226,10 +229,11 @@ async def handle_async_request(self, request: Request) -> Response:
kwargs = {
"host": self._proxy_origin.host.decode("ascii"),
"port": self._proxy_origin.port,
"timeout": timeout,
"timeout": overall_timeout.get_minimum_timeout(timeout),
}
async with Trace("connect_tcp", logger, request, kwargs) as trace:
stream = await self._network_backend.connect_tcp(**kwargs)
with overall_timeout:
stream = await self._network_backend.connect_tcp(**kwargs)
trace.return_value = stream

# Connect to the remote host using socks5
Expand Down
9 changes: 7 additions & 2 deletions httpcore/_sync/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from types import TracebackType
from typing import Iterable, Iterator, Optional, Type

from httpcore._utils import OverallTimeoutHandler

from .._backends.sync import SyncBackend
from .._backends.base import SOCKET_OPTION, NetworkBackend, NetworkStream
from .._exceptions import ConnectError, ConnectTimeout
Expand Down Expand Up @@ -105,6 +107,8 @@ def _connect(self, request: Request) -> NetworkStream:
sni_hostname = request.extensions.get("sni_hostname", None)
timeout = timeouts.get("connect", None)

overall_timeout = OverallTimeoutHandler(timeouts)

retries_left = self._retries
delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR)

Expand All @@ -115,11 +119,12 @@ def _connect(self, request: Request) -> NetworkStream:
"host": self._origin.host.decode("ascii"),
"port": self._origin.port,
"local_address": self._local_address,
"timeout": timeout,
"timeout": overall_timeout.get_minimum_timeout(timeout),
"socket_options": self._socket_options,
}
with Trace("connect_tcp", logger, request, kwargs) as trace:
stream = self._network_backend.connect_tcp(**kwargs)
with overall_timeout:
stream = self._network_backend.connect_tcp(**kwargs)
trace.return_value = stream
else:
kwargs = {
Expand Down
10 changes: 8 additions & 2 deletions httpcore/_sync/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from types import TracebackType
from typing import Iterable, Iterator, Iterable, List, Optional, Type

from httpcore._utils import OverallTimeoutHandler

from .._backends.sync import SyncBackend
from .._backends.base import SOCKET_OPTION, NetworkBackend
from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol
Expand Down Expand Up @@ -174,6 +176,7 @@ def handle_request(self, request: Request) -> Response:

timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("pool", None)
overall_timeout = OverallTimeoutHandler(timeouts)

with self._optional_thread_lock:
# Add the incoming request to our request queue.
Expand All @@ -188,8 +191,11 @@ def handle_request(self, request: Request) -> Response:
closing = self._assign_requests_to_connections()
self._close_connections(closing)

# Wait until this request has an assigned connection.
connection = pool_request.wait_for_connection(timeout=timeout)
with overall_timeout:
# Wait until this request has an assigned connection.
connection = pool_request.wait_for_connection(
timeout=overall_timeout.get_minimum_timeout(timeout)
)

try:
# Send the request on the assigned connection.
Expand Down
31 changes: 26 additions & 5 deletions httpcore/_sync/http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

import h11

from httpcore._utils import OverallTimeoutHandler

from .._backends.base import NetworkStream
from .._exceptions import (
ConnectionNotAvailable,
Expand Down Expand Up @@ -147,25 +149,37 @@ def handle_request(self, request: Request) -> Response:
def _send_request_headers(self, request: Request) -> None:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("write", None)
overall_timeout = OverallTimeoutHandler(timeouts)

with map_exceptions({h11.LocalProtocolError: LocalProtocolError}):
event = h11.Request(
method=request.method,
target=request.url.target,
headers=request.headers,
)
self._send_event(event, timeout=timeout)
with overall_timeout:
self._send_event(
event, timeout=overall_timeout.get_minimum_timeout(timeout)
)

def _send_request_body(self, request: Request) -> None:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("write", None)
overall_timeout = OverallTimeoutHandler(timeouts)

assert isinstance(request.stream, Iterable)
for chunk in request.stream:
event = h11.Data(data=chunk)
self._send_event(event, timeout=timeout)

self._send_event(h11.EndOfMessage(), timeout=timeout)
with overall_timeout:
self._send_event(
event, timeout=overall_timeout.get_minimum_timeout(timeout)
)

with overall_timeout:
self._send_event(
h11.EndOfMessage(), timeout=overall_timeout.get_minimum_timeout(timeout)
)

def _send_event(
self, event: h11.Event, timeout: Optional[float] = None
Expand All @@ -181,9 +195,13 @@ def _receive_response_headers(
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], bytes]:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("read", None)
overall_timeout = OverallTimeoutHandler(timeouts)

while True:
event = self._receive_event(timeout=timeout)
with overall_timeout:
event = self._receive_event(
timeout=overall_timeout.get_minimum_timeout(timeout)
)
if isinstance(event, h11.Response):
break
if (
Expand All @@ -205,9 +223,12 @@ def _receive_response_headers(
def _receive_response_body(self, request: Request) -> Iterator[bytes]:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("read", None)
overall_timeout = OverallTimeoutHandler(timeouts)

while True:
event = self._receive_event(timeout=timeout)
event = self._receive_event(
timeout=overall_timeout.get_minimum_timeout(timeout)
)
if isinstance(event, h11.Data):
yield bytes(event.data)
elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)):
Expand Down
Loading

0 comments on commit f97c976

Please sign in to comment.