From cc27a659b8877880f5e2d40628ff7dc7d10d5160 Mon Sep 17 00:00:00 2001
From: florimondmanca
Date: Sat, 3 Oct 2020 12:14:41 +0200
Subject: [PATCH 01/17] Turn transport.request() into a context manager

---
 README.md | 18 +--
 docs/index.md | 16 +-
 httpcore/_async/base.py | 10 +-
 httpcore/_async/connection.py | 11 +-
 httpcore/_async/connection_pool.py | 109 ++++++-------
 httpcore/_async/http11.py | 9 +-
 httpcore/_async/http2.py | 18 ++-
 httpcore/_async/http_proxy.py | 177 ++++++++++++----------
 httpcore/_compat.py | 9 ++
 httpcore/_sync/base.py | 8 +-
 httpcore/_sync/connection.py | 11 +-
 httpcore/_sync/connection_pool.py | 109 ++++++-------
 httpcore/_sync/http11.py | 9 +-
 httpcore/_sync/http2.py | 18 ++-
 httpcore/_sync/http_proxy.py | 177 ++++++++++++----------
 setup.py | 8 +-
 tests/async_tests/test_connection_pool.py | 123 +++++++++------
 tests/async_tests/test_interfaces.py | 137 +++++++++--------
 tests/sync_tests/test_connection_pool.py | 123 +++++++++------
 tests/sync_tests/test_interfaces.py | 137 +++++++++--------
 unasync.py | 5 +
 21 files changed, 668 insertions(+), 574 deletions(-)
 create mode 100644 httpcore/_compat.py

diff --git a/README.md b/README.md
index 5178ada6..342863c0 100644
--- a/README.md
+++ b/README.md
@@ -43,16 +43,13 @@ Here's an example of making an HTTP GET request using `httpcore`...
 
 ```python
 with httpcore.SyncConnectionPool() as http:
-    status_code, headers, stream, ext = http.request(
+    with http.request(
         method=b'GET',
         url=(b'https', b'example.org', 443, b'/'),
         headers=[(b'host', b'example.org'), (b'user-agent', 'httpcore')]
-    )
-
-    try:
+    ) as response:
+        status_code, headers, stream, ext = response
         body = b''.join([chunk for chunk in stream])
-    finally:
-        stream.close()
 
     print(status_code, body)
 ```
 
 Or, using async...
 
 ```python
 async with httpcore.AsyncConnectionPool() as http:
-    status_code, headers, stream, ext = await http.arequest(
+    async with http.arequest(
         method=b'GET',
         url=(b'https', b'example.org', 443, b'/'),
         headers=[(b'host', b'example.org'), (b'user-agent', 'httpcore')]
-    )
-
-    try:
+    ) as response:
+        status_code, headers, stream, ext = response
         body = b''.join([chunk async for chunk in stream])
-    finally:
-        await stream.aclose()
 
     print(status_code, body)
 ```
diff --git a/docs/index.md b/docs/index.md
index 5178ada6..7eb222e7 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -43,16 +43,12 @@ Here's an example of making an HTTP GET request using `httpcore`...
 
 ```python
 with httpcore.SyncConnectionPool() as http:
-    status_code, headers, stream, ext = http.request(
+    with http.request(
         method=b'GET',
         url=(b'https', b'example.org', 443, b'/'),
         headers=[(b'host', b'example.org'), (b'user-agent', 'httpcore')]
-    )
-
-    try:
+    ) as (status_code, headers, stream, ext):
         body = b''.join([chunk for chunk in stream])
-    finally:
-        stream.close()
 
     print(status_code, body)
 ```
 
 Or, using async...
```python async with httpcore.AsyncConnectionPool() as http: - status_code, headers, stream, ext = await http.arequest( + async with http.arequest( method=b'GET', url=(b'https', b'example.org', 443, b'/'), headers=[(b'host', b'example.org'), (b'user-agent', 'httpcore')] - ) - - try: + ) as (status_code, headers, stream, ext): body = b''.join([chunk async for chunk in stream]) - finally: - await stream.aclose() print(status_code, body) ``` diff --git a/httpcore/_async/base.py b/httpcore/_async/base.py index cf449f42..3d6be7a2 100644 --- a/httpcore/_async/base.py +++ b/httpcore/_async/base.py @@ -1,6 +1,6 @@ import enum from types import TracebackType -from typing import AsyncIterator, Tuple, Type +from typing import AsyncContextManager, AsyncIterator, Tuple, Type from .._types import URL, Headers, T @@ -57,18 +57,18 @@ class AsyncHTTPTransport: """ The base interface for sending HTTP requests. - Concete implementations should subclass this class, and implement + Concrete implementations should subclass this class, and implement the `request` method, and optionally the `close` method. """ - async def arequest( + def arequest( self, method: bytes, url: URL, headers: Headers = None, stream: AsyncByteStream = None, ext: dict = None, - ) -> Tuple[int, Headers, AsyncByteStream, dict]: + ) -> AsyncContextManager[Tuple[int, Headers, AsyncByteStream, dict]]: """ The interface for sending a single HTTP request, and returning a response. @@ -84,7 +84,7 @@ async def arequest( ** Returns:** - A four-tuple of: + A context manager yielding a four-tuple of: * **status_code** - `int` - The HTTP status code, such as `200`. * **headers** - `List[Tuple[bytes, bytes]]` - Any HTTP headers included diff --git a/httpcore/_async/connection.py b/httpcore/_async/connection.py index 258d20d5..9578e246 100644 --- a/httpcore/_async/connection.py +++ b/httpcore/_async/connection.py @@ -1,7 +1,8 @@ from ssl import SSLContext -from typing import Optional, Tuple, cast +from typing import AsyncIterator, Optional, Tuple, cast from .._backends.auto import AsyncBackend, AsyncLock, AsyncSocketStream, AutoBackend +from .._compat import asynccontextmanager from .._types import URL, Headers, Origin, TimeoutDict from .._utils import get_logger, url_to_origin from .base import ( @@ -66,6 +67,7 @@ def request_lock(self) -> AsyncLock: self._request_lock = self.backend.create_lock() return self._request_lock + @asynccontextmanager async def arequest( self, method: bytes, @@ -73,7 +75,7 @@ async def arequest( headers: Headers = None, stream: AsyncByteStream = None, ext: dict = None, - ) -> Tuple[int, Headers, AsyncByteStream, dict]: + ) -> AsyncIterator[Tuple[int, Headers, AsyncByteStream, dict]]: assert url_to_origin(url) == self.origin ext = {} if ext is None else ext timeout = cast(TimeoutDict, ext.get("timeout", {})) @@ -97,7 +99,10 @@ async def arequest( logger.trace( "connection.arequest method=%r url=%r headers=%r", method, url, headers ) - return await self.connection.arequest(method, url, headers, stream, ext) + async with self.connection.arequest( + method, url, headers, stream, ext + ) as response: + yield response async def _open_socket(self, timeout: TimeoutDict = None) -> AsyncSocketStream: scheme, hostname, port = self.origin diff --git a/httpcore/_async/connection_pool.py b/httpcore/_async/connection_pool.py index 7bbc3ff3..82650a06 100644 --- a/httpcore/_async/connection_pool.py +++ b/httpcore/_async/connection_pool.py @@ -1,9 +1,11 @@ import warnings +from functools import partial from ssl import SSLContext -from typing 
import AsyncIterator, Callable, Dict, List, Optional, Set, Tuple, cast +from typing import AsyncIterator, Dict, List, Optional, Set, Tuple, cast from .._backends.auto import AsyncLock, AsyncSemaphore from .._backends.base import lookup_async_backend +from .._compat import AsyncExitStack, asynccontextmanager from .._exceptions import LocalProtocolError, PoolTimeout, UnsupportedProtocol from .._threadlock import ThreadLock from .._types import URL, Headers, Origin, TimeoutDict @@ -30,39 +32,6 @@ async def release(self) -> None: return -class ResponseByteStream(AsyncByteStream): - def __init__( - self, - stream: AsyncByteStream, - connection: AsyncHTTPConnection, - callback: Callable, - ) -> None: - """ - A wrapper around the response stream that we return from `.arequest()`. - - Ensures that when `stream.aclose()` is called, the connection pool - is notified via a callback. - """ - self.stream = stream - self.connection = connection - self.callback = callback - - async def __aiter__(self) -> AsyncIterator[bytes]: - async for chunk in self.stream: - yield chunk - - async def aclose(self) -> None: - try: - # Call the underlying stream close callback. - # This will be a call to `AsyncHTTP11Connection._response_closed()` - # or `AsyncHTTP2Stream._response_closed()`. - await self.stream.aclose() - finally: - # Call the connection pool close callback. - # This will be a call to `AsyncConnectionPool._response_closed()`. - await self.callback(self.connection) - - class AsyncConnectionPool(AsyncHTTPTransport): """ A connection pool for making HTTP requests. @@ -160,6 +129,7 @@ def _create_connection( backend=self._backend, ) + @asynccontextmanager async def arequest( self, method: bytes, @@ -167,7 +137,7 @@ async def arequest( headers: Headers = None, stream: AsyncByteStream = None, ext: dict = None, - ) -> Tuple[int, Headers, AsyncByteStream, dict]: + ) -> AsyncIterator[Tuple[int, Headers, AsyncByteStream, dict]]: if url[0] not in (b"http", b"https"): scheme = url[0].decode("latin-1") raise UnsupportedProtocol(f"Unsupported URL protocol {scheme!r}") @@ -180,38 +150,45 @@ async def arequest( await self._keepalive_sweep() - connection: Optional[AsyncHTTPConnection] = None - while connection is None: - async with self._connection_acquiry_lock: - # We get-or-create a connection as an atomic operation, to ensure - # that HTTP/2 requests issued in close concurrency will end up - # on the same connection. - logger.trace("get_connection_from_pool=%r", origin) - connection = await self._get_connection_from_pool(origin) - - if connection is None: - connection = self._create_connection(origin=origin) - logger.trace("created connection=%r", connection) - await self._add_to_pool(connection, timeout=timeout) - else: - logger.trace("reuse connection=%r", connection) - - try: - response = await connection.arequest( - method, url, headers=headers, stream=stream, ext=ext - ) - except NewConnectionRequired: - connection = None - except Exception: # noqa: PIE786 - logger.trace("remove from pool connection=%r", connection) - await self._remove_from_pool(connection) - raise + async with AsyncExitStack() as exit_stack: + connection: Optional[AsyncHTTPConnection] = None + while connection is None: + async with self._connection_acquiry_lock: + # We get-or-create a connection as an atomic operation, to ensure + # that HTTP/2 requests issued in close concurrency will end up + # on the same connection. 
+ logger.trace("get_connection_from_pool=%r", origin) + connection = await self._get_connection_from_pool(origin) + + if connection is None: + connection = self._create_connection(origin=origin) + logger.trace("created connection=%r", connection) + await self._add_to_pool(connection, timeout=timeout) + else: + logger.trace("reuse connection=%r", connection) + + try: + # Push this callback onto the stack *before* making the request, + # so that it's effectively executed *after* the response is closed. + exit_stack.push_async_callback( + partial(self._response_closed, connection) + ) + + response = await exit_stack.enter_async_context( + connection.arequest( + method, url, headers=headers, stream=stream, ext=ext + ) + ) + except NewConnectionRequired: + exit_stack.pop_all() # Drop any registered callbacks. + connection = None + except Exception: # noqa: PIE786 + logger.trace("remove from pool connection=%r", connection) + exit_stack.pop_all() # Drop any registered callbacks. + await self._remove_from_pool(connection) + raise - status_code, headers, stream, ext = response - wrapped_stream = ResponseByteStream( - stream, connection=connection, callback=self._response_closed - ) - return status_code, headers, wrapped_stream, ext + yield response async def _get_connection_from_pool( self, origin: Origin diff --git a/httpcore/_async/http11.py b/httpcore/_async/http11.py index 2e0e378d..b715798e 100644 --- a/httpcore/_async/http11.py +++ b/httpcore/_async/http11.py @@ -5,6 +5,7 @@ from .._backends.auto import AsyncSocketStream from .._bytestreams import AsyncIteratorByteStream, PlainByteStream +from .._compat import asynccontextmanager from .._exceptions import LocalProtocolError, RemoteProtocolError, map_exceptions from .._types import URL, Headers, TimeoutDict from .._utils import get_logger @@ -47,6 +48,7 @@ def mark_as_ready(self) -> None: if self.state == ConnectionState.IDLE: self.state = ConnectionState.READY + @asynccontextmanager async def arequest( self, method: bytes, @@ -54,7 +56,7 @@ async def arequest( headers: Headers = None, stream: AsyncByteStream = None, ext: dict = None, - ) -> Tuple[int, Headers, AsyncByteStream, dict]: + ) -> AsyncIterator[Tuple[int, Headers, AsyncByteStream, dict]]: headers = [] if headers is None else headers stream = PlainByteStream(b"") if stream is None else stream ext = {} if ext is None else ext @@ -78,7 +80,10 @@ async def arequest( "http_version": http_version.decode("ascii", errors="ignore"), "reason": reason_phrase.decode("ascii", errors="ignore"), } - return (status_code, headers, response_stream, ext) + try: + yield (status_code, headers, response_stream, ext) + finally: + await response_stream.aclose() async def start_tls( self, hostname: bytes, timeout: TimeoutDict = None diff --git a/httpcore/_async/http2.py b/httpcore/_async/http2.py index 6dd84f1d..42b9bad0 100644 --- a/httpcore/_async/http2.py +++ b/httpcore/_async/http2.py @@ -9,6 +9,7 @@ from .._backends.auto import AsyncBackend, AsyncLock, AsyncSemaphore, AsyncSocketStream from .._bytestreams import AsyncIteratorByteStream, PlainByteStream +from .._compat import asynccontextmanager from .._exceptions import PoolTimeout, RemoteProtocolError from .._types import URL, Headers, TimeoutDict from .._utils import get_logger @@ -85,6 +86,7 @@ def mark_as_ready(self) -> None: if self.state == ConnectionState.IDLE: self.state = ConnectionState.READY + @asynccontextmanager async def arequest( self, method: bytes, @@ -92,7 +94,7 @@ async def arequest( headers: Headers = None, stream: AsyncByteStream 
= None, ext: dict = None, - ) -> Tuple[int, Headers, AsyncByteStream, dict]: + ) -> AsyncIterator[Tuple[int, Headers, AsyncByteStream, dict]]: ext = {} if ext is None else ext timeout = cast(TimeoutDict, ext.get("timeout", {})) @@ -116,7 +118,10 @@ async def arequest( h2_stream = AsyncHTTP2Stream(stream_id=stream_id, connection=self) self.streams[stream_id] = h2_stream self.events[stream_id] = [] - return await h2_stream.arequest(method, url, headers, stream, ext) + async with h2_stream.arequest( + method, url, headers, stream, ext + ) as response: + yield response except Exception: # noqa: PIE786 await self.max_streams_semaphore.release() raise @@ -270,6 +275,7 @@ def __init__(self, stream_id: int, connection: AsyncHTTP2Connection) -> None: self.stream_id = stream_id self.connection = connection + @asynccontextmanager async def arequest( self, method: bytes, @@ -277,7 +283,7 @@ async def arequest( headers: Headers = None, stream: AsyncByteStream = None, ext: dict = None, - ) -> Tuple[int, Headers, AsyncByteStream, dict]: + ) -> AsyncIterator[Tuple[int, Headers, AsyncByteStream, dict]]: headers = [] if headers is None else [(k.lower(), v) for (k, v) in headers] stream = PlainByteStream(b"") if stream is None else stream ext = {} if ext is None else ext @@ -302,7 +308,11 @@ async def arequest( ext = { "http_version": "HTTP/2", } - return (status_code, headers, response_stream, ext) + + try: + yield (status_code, headers, response_stream, ext) + finally: + await response_stream.aclose() async def send_headers( self, diff --git a/httpcore/_async/http_proxy.py b/httpcore/_async/http_proxy.py index 8a9f33c2..f1ee8487 100644 --- a/httpcore/_async/http_proxy.py +++ b/httpcore/_async/http_proxy.py @@ -1,13 +1,14 @@ from http import HTTPStatus from ssl import SSLContext -from typing import Tuple, cast +from typing import AsyncIterator, Tuple, cast +from .._compat import AsyncExitStack, asynccontextmanager from .._exceptions import ProxyError from .._types import URL, Headers, TimeoutDict from .._utils import get_logger, url_to_origin from .base import AsyncByteStream from .connection import AsyncHTTPConnection -from .connection_pool import AsyncConnectionPool, ResponseByteStream +from .connection_pool import AsyncConnectionPool logger = get_logger(__name__) @@ -87,6 +88,7 @@ def __init__( max_keepalive=max_keepalive, ) + @asynccontextmanager async def arequest( self, method: bytes, @@ -94,7 +96,7 @@ async def arequest( headers: Headers = None, stream: AsyncByteStream = None, ext: dict = None, - ) -> Tuple[int, Headers, AsyncByteStream, dict]: + ) -> AsyncIterator[Tuple[int, Headers, AsyncByteStream, dict]]: if self._keepalive_expiry is not None: await self._keepalive_sweep() @@ -109,9 +111,10 @@ async def arequest( method, url, ) - return await self._forward_request( + async with self._forward_request( method, url, headers=headers, stream=stream, ext=ext - ) + ) as response: + yield response else: # By default HTTPS should be tunnelled. 
logger.trace( @@ -121,10 +124,12 @@ async def arequest( method, url, ) - return await self._tunnel_request( + async with self._tunnel_request( method, url, headers=headers, stream=stream, ext=ext - ) + ) as response: + yield response + @asynccontextmanager async def _forward_request( self, method: bytes, @@ -132,7 +137,7 @@ async def _forward_request( headers: Headers = None, stream: AsyncByteStream = None, ext: dict = None, - ) -> Tuple[int, Headers, AsyncByteStream, dict]: + ) -> AsyncIterator[Tuple[int, Headers, AsyncByteStream, dict]]: """ Forwarded proxy requests include the entire URL as the HTTP target, rather than just the path. @@ -162,16 +167,15 @@ async def _forward_request( url = self.proxy_origin + (target,) headers = merge_headers(self.proxy_headers, headers) - (status_code, headers, stream, ext) = await connection.arequest( + async with connection.arequest( method, url, headers=headers, stream=stream, ext=ext - ) - - wrapped_stream = ResponseByteStream( - stream, connection=connection, callback=self._response_closed - ) - - return status_code, headers, wrapped_stream, ext + ) as response: + try: + yield response + finally: + await self._response_closed(connection) + @asynccontextmanager async def _tunnel_request( self, method: bytes, @@ -179,7 +183,7 @@ async def _tunnel_request( headers: Headers = None, stream: AsyncByteStream = None, ext: dict = None, - ) -> Tuple[int, Headers, AsyncByteStream, dict]: + ) -> AsyncIterator[Tuple[int, Headers, AsyncByteStream, dict]]: """ Tunnelled proxy requests require an initial CONNECT request to establish the connection, and then send regular requests. @@ -189,72 +193,77 @@ async def _tunnel_request( origin = url_to_origin(url) connection = await self._get_connection_from_pool(origin) - if connection is None: - scheme, host, port = origin - - # First, create a connection to the proxy server - proxy_connection = AsyncHTTPConnection( - origin=self.proxy_origin, - http2=self._http2, - ssl_context=self._ssl_context, - ) - - # Issue a CONNECT request... - - # CONNECT www.example.org:80 HTTP/1.1 - # [proxy-headers] - target = b"%b:%d" % (host, port) - connect_url = self.proxy_origin + (target,) - connect_headers = [(b"Host", target), (b"Accept", b"*/*")] - connect_headers = merge_headers(connect_headers, self.proxy_headers) - (proxy_status_code, _, proxy_stream, _) = await proxy_connection.arequest( - b"CONNECT", connect_url, headers=connect_headers, ext=ext - ) - - proxy_reason = get_reason_phrase(proxy_status_code) - logger.trace( - "tunnel_response proxy_status_code=%r proxy_reason=%r ", - proxy_status_code, - proxy_reason, - ) - # Read the response data without closing the socket - async for _ in proxy_stream: - pass - - # See if the tunnel was successfully established. - if proxy_status_code < 200 or proxy_status_code > 299: - msg = "%d %s" % (proxy_status_code, proxy_reason) - raise ProxyError(msg) - - # Upgrade to TLS if required - # We assume the target speaks TLS on the specified port - if scheme == b"https": - await proxy_connection.start_tls(host, timeout) - - # The CONNECT request is successful, so we have now SWITCHED PROTOCOLS. - # This means the proxy connection is now unusable, and we must create - # a new one for regular requests, making sure to use the same socket to - # retain the tunnel. 
- connection = AsyncHTTPConnection( - origin=origin, - http2=self._http2, - ssl_context=self._ssl_context, - socket=proxy_connection.socket, + async with AsyncExitStack() as exit_stack: + if connection is None: + scheme, host, port = origin + + # First, create a connection to the proxy server + proxy_connection = AsyncHTTPConnection( + origin=self.proxy_origin, + http2=self._http2, + ssl_context=self._ssl_context, + ) + + # Issue a CONNECT request... + + # CONNECT www.example.org:80 HTTP/1.1 + # [proxy-headers] + target = b"%b:%d" % (host, port) + connect_url = self.proxy_origin + (target,) + connect_headers = [(b"Host", target), (b"Accept", b"*/*")] + connect_headers = merge_headers(connect_headers, self.proxy_headers) + + proxy_response = await exit_stack.enter_async_context( + proxy_connection.arequest( + b"CONNECT", connect_url, headers=connect_headers, ext=ext + ) + ) + proxy_status_code, _, proxy_stream, _ = proxy_response + proxy_reason = get_reason_phrase(proxy_status_code) + logger.trace( + "tunnel_response proxy_status_code=%r proxy_reason=%r ", + proxy_status_code, + proxy_reason, + ) + # Read the response data without closing the socket + async for _ in proxy_stream: + pass + + # See if the tunnel was successfully established. + if proxy_status_code < 200 or proxy_status_code > 299: + msg = "%d %s" % (proxy_status_code, proxy_reason) + raise ProxyError(msg) + + # Upgrade to TLS if required + # We assume the target speaks TLS on the specified port + if scheme == b"https": + await proxy_connection.start_tls(host, timeout) + + # The CONNECT request is successful, so we have now SWITCHED PROTOCOLS. + # This means the proxy connection is now unusable, and we must create + # a new one for regular requests, making sure to use the same socket to + # retain the tunnel. + connection = AsyncHTTPConnection( + origin=origin, + http2=self._http2, + ssl_context=self._ssl_context, + socket=proxy_connection.socket, + ) + await self._add_to_pool(connection, timeout) + + # Once the connection has been established we can send requests on + # it as normal. + response = await exit_stack.enter_async_context( + connection.arequest( + method, + url, + headers=headers, + stream=stream, + ext=ext, + ) ) - await self._add_to_pool(connection, timeout) - - # Once the connection has been established we can send requests on - # it as normal. - (status_code, headers, stream, ext) = await connection.arequest( - method, - url, - headers=headers, - stream=stream, - ext=ext, - ) - - wrapped_stream = ResponseByteStream( - stream, connection=connection, callback=self._response_closed - ) - return status_code, headers, wrapped_stream, ext + try: + yield response + finally: + await self._response_closed(connection) diff --git a/httpcore/_compat.py b/httpcore/_compat.py new file mode 100644 index 00000000..f62aa340 --- /dev/null +++ b/httpcore/_compat.py @@ -0,0 +1,9 @@ +try: + from contextlib import AsyncExitStack, asynccontextmanager +except ImportError: # pragma: no cover + # Python 3.6 + from async_exit_stack import AsyncExitStack # type: ignore # noqa: F401 + from async_generator import asynccontextmanager # type: ignore # noqa: F401 + +# These will be imported by the unasynced code. 
+from contextlib import ExitStack, contextmanager # noqa: F401 diff --git a/httpcore/_sync/base.py b/httpcore/_sync/base.py index 95a434eb..519735eb 100644 --- a/httpcore/_sync/base.py +++ b/httpcore/_sync/base.py @@ -1,6 +1,6 @@ import enum from types import TracebackType -from typing import Iterator, Tuple, Type +from typing import ContextManager, Iterator, Tuple, Type from .._types import URL, Headers, T @@ -57,7 +57,7 @@ class SyncHTTPTransport: """ The base interface for sending HTTP requests. - Concete implementations should subclass this class, and implement + Concrete implementations should subclass this class, and implement the `request` method, and optionally the `close` method. """ @@ -68,7 +68,7 @@ def request( headers: Headers = None, stream: SyncByteStream = None, ext: dict = None, - ) -> Tuple[int, Headers, SyncByteStream, dict]: + ) -> ContextManager[Tuple[int, Headers, SyncByteStream, dict]]: """ The interface for sending a single HTTP request, and returning a response. @@ -84,7 +84,7 @@ def request( ** Returns:** - A four-tuple of: + A context manager yielding a four-tuple of: * **status_code** - `int` - The HTTP status code, such as `200`. * **headers** - `List[Tuple[bytes, bytes]]` - Any HTTP headers included diff --git a/httpcore/_sync/connection.py b/httpcore/_sync/connection.py index 480acb47..ea469c6c 100644 --- a/httpcore/_sync/connection.py +++ b/httpcore/_sync/connection.py @@ -1,7 +1,8 @@ from ssl import SSLContext -from typing import Optional, Tuple, cast +from typing import Iterator, Optional, Tuple, cast from .._backends.sync import SyncBackend, SyncLock, SyncSocketStream, SyncBackend +from .._compat import contextmanager from .._types import URL, Headers, Origin, TimeoutDict from .._utils import get_logger, url_to_origin from .base import ( @@ -66,6 +67,7 @@ def request_lock(self) -> SyncLock: self._request_lock = self.backend.create_lock() return self._request_lock + @contextmanager def request( self, method: bytes, @@ -73,7 +75,7 @@ def request( headers: Headers = None, stream: SyncByteStream = None, ext: dict = None, - ) -> Tuple[int, Headers, SyncByteStream, dict]: + ) -> Iterator[Tuple[int, Headers, SyncByteStream, dict]]: assert url_to_origin(url) == self.origin ext = {} if ext is None else ext timeout = cast(TimeoutDict, ext.get("timeout", {})) @@ -97,7 +99,10 @@ def request( logger.trace( "connection.request method=%r url=%r headers=%r", method, url, headers ) - return self.connection.request(method, url, headers, stream, ext) + with self.connection.request( + method, url, headers, stream, ext + ) as response: + yield response def _open_socket(self, timeout: TimeoutDict = None) -> SyncSocketStream: scheme, hostname, port = self.origin diff --git a/httpcore/_sync/connection_pool.py b/httpcore/_sync/connection_pool.py index 91af75e5..3f5f566d 100644 --- a/httpcore/_sync/connection_pool.py +++ b/httpcore/_sync/connection_pool.py @@ -1,9 +1,11 @@ import warnings +from functools import partial from ssl import SSLContext -from typing import Iterator, Callable, Dict, List, Optional, Set, Tuple, cast +from typing import Iterator, Dict, List, Optional, Set, Tuple, cast from .._backends.sync import SyncLock, SyncSemaphore from .._backends.base import lookup_sync_backend +from .._compat import ExitStack, contextmanager from .._exceptions import LocalProtocolError, PoolTimeout, UnsupportedProtocol from .._threadlock import ThreadLock from .._types import URL, Headers, Origin, TimeoutDict @@ -30,39 +32,6 @@ def release(self) -> None: return -class 
ResponseByteStream(SyncByteStream): - def __init__( - self, - stream: SyncByteStream, - connection: SyncHTTPConnection, - callback: Callable, - ) -> None: - """ - A wrapper around the response stream that we return from `.request()`. - - Ensures that when `stream.close()` is called, the connection pool - is notified via a callback. - """ - self.stream = stream - self.connection = connection - self.callback = callback - - def __iter__(self) -> Iterator[bytes]: - for chunk in self.stream: - yield chunk - - def close(self) -> None: - try: - # Call the underlying stream close callback. - # This will be a call to `SyncHTTP11Connection._response_closed()` - # or `SyncHTTP2Stream._response_closed()`. - self.stream.close() - finally: - # Call the connection pool close callback. - # This will be a call to `SyncConnectionPool._response_closed()`. - self.callback(self.connection) - - class SyncConnectionPool(SyncHTTPTransport): """ A connection pool for making HTTP requests. @@ -160,6 +129,7 @@ def _create_connection( backend=self._backend, ) + @contextmanager def request( self, method: bytes, @@ -167,7 +137,7 @@ def request( headers: Headers = None, stream: SyncByteStream = None, ext: dict = None, - ) -> Tuple[int, Headers, SyncByteStream, dict]: + ) -> Iterator[Tuple[int, Headers, SyncByteStream, dict]]: if url[0] not in (b"http", b"https"): scheme = url[0].decode("latin-1") raise UnsupportedProtocol(f"Unsupported URL protocol {scheme!r}") @@ -180,38 +150,45 @@ def request( self._keepalive_sweep() - connection: Optional[SyncHTTPConnection] = None - while connection is None: - with self._connection_acquiry_lock: - # We get-or-create a connection as an atomic operation, to ensure - # that HTTP/2 requests issued in close concurrency will end up - # on the same connection. - logger.trace("get_connection_from_pool=%r", origin) - connection = self._get_connection_from_pool(origin) - - if connection is None: - connection = self._create_connection(origin=origin) - logger.trace("created connection=%r", connection) - self._add_to_pool(connection, timeout=timeout) - else: - logger.trace("reuse connection=%r", connection) - - try: - response = connection.request( - method, url, headers=headers, stream=stream, ext=ext - ) - except NewConnectionRequired: - connection = None - except Exception: # noqa: PIE786 - logger.trace("remove from pool connection=%r", connection) - self._remove_from_pool(connection) - raise + with ExitStack() as exit_stack: + connection: Optional[SyncHTTPConnection] = None + while connection is None: + with self._connection_acquiry_lock: + # We get-or-create a connection as an atomic operation, to ensure + # that HTTP/2 requests issued in close concurrency will end up + # on the same connection. + logger.trace("get_connection_from_pool=%r", origin) + connection = self._get_connection_from_pool(origin) + + if connection is None: + connection = self._create_connection(origin=origin) + logger.trace("created connection=%r", connection) + self._add_to_pool(connection, timeout=timeout) + else: + logger.trace("reuse connection=%r", connection) + + try: + # Push this callback onto the stack *before* making the request, + # so that it's effectively executed *after* the response is closed. + exit_stack.callback( + partial(self._response_closed, connection) + ) + + response = exit_stack.enter_context( + connection.request( + method, url, headers=headers, stream=stream, ext=ext + ) + ) + except NewConnectionRequired: + exit_stack.pop_all() # Drop any registered callbacks. 
+ connection = None + except Exception: # noqa: PIE786 + logger.trace("remove from pool connection=%r", connection) + exit_stack.pop_all() # Drop any registered callbacks. + self._remove_from_pool(connection) + raise - status_code, headers, stream, ext = response - wrapped_stream = ResponseByteStream( - stream, connection=connection, callback=self._response_closed - ) - return status_code, headers, wrapped_stream, ext + yield response def _get_connection_from_pool( self, origin: Origin diff --git a/httpcore/_sync/http11.py b/httpcore/_sync/http11.py index 067d6134..3735e341 100644 --- a/httpcore/_sync/http11.py +++ b/httpcore/_sync/http11.py @@ -5,6 +5,7 @@ from .._backends.sync import SyncSocketStream from .._bytestreams import IteratorByteStream, PlainByteStream +from .._compat import contextmanager from .._exceptions import LocalProtocolError, RemoteProtocolError, map_exceptions from .._types import URL, Headers, TimeoutDict from .._utils import get_logger @@ -47,6 +48,7 @@ def mark_as_ready(self) -> None: if self.state == ConnectionState.IDLE: self.state = ConnectionState.READY + @contextmanager def request( self, method: bytes, @@ -54,7 +56,7 @@ def request( headers: Headers = None, stream: SyncByteStream = None, ext: dict = None, - ) -> Tuple[int, Headers, SyncByteStream, dict]: + ) -> Iterator[Tuple[int, Headers, SyncByteStream, dict]]: headers = [] if headers is None else headers stream = PlainByteStream(b"") if stream is None else stream ext = {} if ext is None else ext @@ -78,7 +80,10 @@ def request( "http_version": http_version.decode("ascii", errors="ignore"), "reason": reason_phrase.decode("ascii", errors="ignore"), } - return (status_code, headers, response_stream, ext) + try: + yield (status_code, headers, response_stream, ext) + finally: + response_stream.close() def start_tls( self, hostname: bytes, timeout: TimeoutDict = None diff --git a/httpcore/_sync/http2.py b/httpcore/_sync/http2.py index 2d8b8d12..06bb72d6 100644 --- a/httpcore/_sync/http2.py +++ b/httpcore/_sync/http2.py @@ -9,6 +9,7 @@ from .._backends.sync import SyncBackend, SyncLock, SyncSemaphore, SyncSocketStream from .._bytestreams import IteratorByteStream, PlainByteStream +from .._compat import contextmanager from .._exceptions import PoolTimeout, RemoteProtocolError from .._types import URL, Headers, TimeoutDict from .._utils import get_logger @@ -85,6 +86,7 @@ def mark_as_ready(self) -> None: if self.state == ConnectionState.IDLE: self.state = ConnectionState.READY + @contextmanager def request( self, method: bytes, @@ -92,7 +94,7 @@ def request( headers: Headers = None, stream: SyncByteStream = None, ext: dict = None, - ) -> Tuple[int, Headers, SyncByteStream, dict]: + ) -> Iterator[Tuple[int, Headers, SyncByteStream, dict]]: ext = {} if ext is None else ext timeout = cast(TimeoutDict, ext.get("timeout", {})) @@ -116,7 +118,10 @@ def request( h2_stream = SyncHTTP2Stream(stream_id=stream_id, connection=self) self.streams[stream_id] = h2_stream self.events[stream_id] = [] - return h2_stream.request(method, url, headers, stream, ext) + with h2_stream.request( + method, url, headers, stream, ext + ) as response: + yield response except Exception: # noqa: PIE786 self.max_streams_semaphore.release() raise @@ -270,6 +275,7 @@ def __init__(self, stream_id: int, connection: SyncHTTP2Connection) -> None: self.stream_id = stream_id self.connection = connection + @contextmanager def request( self, method: bytes, @@ -277,7 +283,7 @@ def request( headers: Headers = None, stream: SyncByteStream = None, ext: dict = 
None, - ) -> Tuple[int, Headers, SyncByteStream, dict]: + ) -> Iterator[Tuple[int, Headers, SyncByteStream, dict]]: headers = [] if headers is None else [(k.lower(), v) for (k, v) in headers] stream = PlainByteStream(b"") if stream is None else stream ext = {} if ext is None else ext @@ -302,7 +308,11 @@ def request( ext = { "http_version": "HTTP/2", } - return (status_code, headers, response_stream, ext) + + try: + yield (status_code, headers, response_stream, ext) + finally: + response_stream.close() def send_headers( self, diff --git a/httpcore/_sync/http_proxy.py b/httpcore/_sync/http_proxy.py index aa3a1ae5..01745ce7 100644 --- a/httpcore/_sync/http_proxy.py +++ b/httpcore/_sync/http_proxy.py @@ -1,13 +1,14 @@ from http import HTTPStatus from ssl import SSLContext -from typing import Tuple, cast +from typing import Iterator, Tuple, cast +from .._compat import ExitStack, contextmanager from .._exceptions import ProxyError from .._types import URL, Headers, TimeoutDict from .._utils import get_logger, url_to_origin from .base import SyncByteStream from .connection import SyncHTTPConnection -from .connection_pool import SyncConnectionPool, ResponseByteStream +from .connection_pool import SyncConnectionPool logger = get_logger(__name__) @@ -87,6 +88,7 @@ def __init__( max_keepalive=max_keepalive, ) + @contextmanager def request( self, method: bytes, @@ -94,7 +96,7 @@ def request( headers: Headers = None, stream: SyncByteStream = None, ext: dict = None, - ) -> Tuple[int, Headers, SyncByteStream, dict]: + ) -> Iterator[Tuple[int, Headers, SyncByteStream, dict]]: if self._keepalive_expiry is not None: self._keepalive_sweep() @@ -109,9 +111,10 @@ def request( method, url, ) - return self._forward_request( + with self._forward_request( method, url, headers=headers, stream=stream, ext=ext - ) + ) as response: + yield response else: # By default HTTPS should be tunnelled. logger.trace( @@ -121,10 +124,12 @@ def request( method, url, ) - return self._tunnel_request( + with self._tunnel_request( method, url, headers=headers, stream=stream, ext=ext - ) + ) as response: + yield response + @contextmanager def _forward_request( self, method: bytes, @@ -132,7 +137,7 @@ def _forward_request( headers: Headers = None, stream: SyncByteStream = None, ext: dict = None, - ) -> Tuple[int, Headers, SyncByteStream, dict]: + ) -> Iterator[Tuple[int, Headers, SyncByteStream, dict]]: """ Forwarded proxy requests include the entire URL as the HTTP target, rather than just the path. @@ -162,16 +167,15 @@ def _forward_request( url = self.proxy_origin + (target,) headers = merge_headers(self.proxy_headers, headers) - (status_code, headers, stream, ext) = connection.request( + with connection.request( method, url, headers=headers, stream=stream, ext=ext - ) - - wrapped_stream = ResponseByteStream( - stream, connection=connection, callback=self._response_closed - ) - - return status_code, headers, wrapped_stream, ext + ) as response: + try: + yield response + finally: + self._response_closed(connection) + @contextmanager def _tunnel_request( self, method: bytes, @@ -179,7 +183,7 @@ def _tunnel_request( headers: Headers = None, stream: SyncByteStream = None, ext: dict = None, - ) -> Tuple[int, Headers, SyncByteStream, dict]: + ) -> Iterator[Tuple[int, Headers, SyncByteStream, dict]]: """ Tunnelled proxy requests require an initial CONNECT request to establish the connection, and then send regular requests. 
@@ -189,72 +193,77 @@ def _tunnel_request( origin = url_to_origin(url) connection = self._get_connection_from_pool(origin) - if connection is None: - scheme, host, port = origin - - # First, create a connection to the proxy server - proxy_connection = SyncHTTPConnection( - origin=self.proxy_origin, - http2=self._http2, - ssl_context=self._ssl_context, - ) - - # Issue a CONNECT request... - - # CONNECT www.example.org:80 HTTP/1.1 - # [proxy-headers] - target = b"%b:%d" % (host, port) - connect_url = self.proxy_origin + (target,) - connect_headers = [(b"Host", target), (b"Accept", b"*/*")] - connect_headers = merge_headers(connect_headers, self.proxy_headers) - (proxy_status_code, _, proxy_stream, _) = proxy_connection.request( - b"CONNECT", connect_url, headers=connect_headers, ext=ext - ) - - proxy_reason = get_reason_phrase(proxy_status_code) - logger.trace( - "tunnel_response proxy_status_code=%r proxy_reason=%r ", - proxy_status_code, - proxy_reason, - ) - # Read the response data without closing the socket - for _ in proxy_stream: - pass - - # See if the tunnel was successfully established. - if proxy_status_code < 200 or proxy_status_code > 299: - msg = "%d %s" % (proxy_status_code, proxy_reason) - raise ProxyError(msg) - - # Upgrade to TLS if required - # We assume the target speaks TLS on the specified port - if scheme == b"https": - proxy_connection.start_tls(host, timeout) - - # The CONNECT request is successful, so we have now SWITCHED PROTOCOLS. - # This means the proxy connection is now unusable, and we must create - # a new one for regular requests, making sure to use the same socket to - # retain the tunnel. - connection = SyncHTTPConnection( - origin=origin, - http2=self._http2, - ssl_context=self._ssl_context, - socket=proxy_connection.socket, + with ExitStack() as exit_stack: + if connection is None: + scheme, host, port = origin + + # First, create a connection to the proxy server + proxy_connection = SyncHTTPConnection( + origin=self.proxy_origin, + http2=self._http2, + ssl_context=self._ssl_context, + ) + + # Issue a CONNECT request... + + # CONNECT www.example.org:80 HTTP/1.1 + # [proxy-headers] + target = b"%b:%d" % (host, port) + connect_url = self.proxy_origin + (target,) + connect_headers = [(b"Host", target), (b"Accept", b"*/*")] + connect_headers = merge_headers(connect_headers, self.proxy_headers) + + proxy_response = exit_stack.enter_context( + proxy_connection.request( + b"CONNECT", connect_url, headers=connect_headers, ext=ext + ) + ) + proxy_status_code, _, proxy_stream, _ = proxy_response + proxy_reason = get_reason_phrase(proxy_status_code) + logger.trace( + "tunnel_response proxy_status_code=%r proxy_reason=%r ", + proxy_status_code, + proxy_reason, + ) + # Read the response data without closing the socket + for _ in proxy_stream: + pass + + # See if the tunnel was successfully established. + if proxy_status_code < 200 or proxy_status_code > 299: + msg = "%d %s" % (proxy_status_code, proxy_reason) + raise ProxyError(msg) + + # Upgrade to TLS if required + # We assume the target speaks TLS on the specified port + if scheme == b"https": + proxy_connection.start_tls(host, timeout) + + # The CONNECT request is successful, so we have now SWITCHED PROTOCOLS. + # This means the proxy connection is now unusable, and we must create + # a new one for regular requests, making sure to use the same socket to + # retain the tunnel. 
+ connection = SyncHTTPConnection( + origin=origin, + http2=self._http2, + ssl_context=self._ssl_context, + socket=proxy_connection.socket, + ) + self._add_to_pool(connection, timeout) + + # Once the connection has been established we can send requests on + # it as normal. + response = exit_stack.enter_context( + connection.request( + method, + url, + headers=headers, + stream=stream, + ext=ext, + ) ) - self._add_to_pool(connection, timeout) - - # Once the connection has been established we can send requests on - # it as normal. - (status_code, headers, stream, ext) = connection.request( - method, - url, - headers=headers, - stream=stream, - ext=ext, - ) - - wrapped_stream = ResponseByteStream( - stream, connection=connection, callback=self._response_closed - ) - return status_code, headers, wrapped_stream, ext + try: + yield response + finally: + self._response_closed(connection) diff --git a/setup.py b/setup.py index f91c4d95..c4857215 100644 --- a/setup.py +++ b/setup.py @@ -53,7 +53,13 @@ def get_packages(package): packages=get_packages("httpcore"), include_package_data=True, zip_safe=False, - install_requires=["h11>=0.8,<0.10", "sniffio==1.*"], + install_requires=[ + "h11>=0.8,<0.10", + "sniffio==1.*", + # Backports. + "async_generator; python_version<'3.7'", + "async-exit-stack; python_version<'3.7'", + ], extras_require={ "http2": ["h2==3.*"], }, diff --git a/tests/async_tests/test_connection_pool.py b/tests/async_tests/test_connection_pool.py index b52f0c5a..dd4039d5 100644 --- a/tests/async_tests/test_connection_pool.py +++ b/tests/async_tests/test_connection_pool.py @@ -1,3 +1,4 @@ +from contextlib import AsyncExitStack, asynccontextmanager from typing import AsyncIterator, Tuple import pytest @@ -7,7 +8,7 @@ from httpcore._types import URL, Headers -class MockConnection(object): +class MockConnection(httpcore.AsyncHTTPTransport): def __init__(self, http_version): self.origin = (b"http", b"example.org", 80) self.state = ConnectionState.PENDING @@ -15,6 +16,7 @@ def __init__(self, http_version): self.is_http2 = http_version == "HTTP/2" self.stream_count = 0 + @asynccontextmanager async def arequest( self, method: bytes, @@ -22,7 +24,7 @@ async def arequest( headers: Headers = None, stream: httpcore.AsyncByteStream = None, ext: dict = None, - ) -> Tuple[int, Headers, httpcore.AsyncByteStream, dict]: + ) -> AsyncIterator[Tuple[int, Headers, httpcore.AsyncByteStream, dict]]: self.state = ConnectionState.ACTIVE self.stream_count += 1 @@ -38,7 +40,10 @@ async def aiterator() -> AsyncIterator[bytes]: aiterator=aiterator(), aclose_func=on_close ) - return 200, [], stream, {} + try: + yield 200, [], stream, {} + finally: + await stream.aclose() async def aclose(self): pass @@ -64,13 +69,7 @@ def _create_connection(self, **kwargs): async def read_body(stream: httpcore.AsyncByteStream) -> bytes: - try: - body = [] - async for chunk in stream: - body.append(chunk) - return b"".join(body) - finally: - await stream.aclose() + return b"".join([chunk async for chunk in stream]) @pytest.mark.trio @@ -80,21 +79,25 @@ async def test_sequential_requests(http_version) -> None: info = await http.get_connection_info() assert info == {} - response = await http.arequest(b"GET", (b"http", b"example.org", None, b"/")) - status_code, headers, stream, ext = response - info = await http.get_connection_info() - assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + async with http.arequest( + b"GET", (b"http", b"example.org", None, b"/") + ) as response: + status_code, headers, stream, ext = 
response + info = await http.get_connection_info() + assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + await read_body(stream) - await read_body(stream) info = await http.get_connection_info() assert info == {"http://example.org": ["ConnectionState.IDLE"]} - response = await http.arequest(b"GET", (b"http", b"example.org", None, b"/")) - status_code, headers, stream, ext = response - info = await http.get_connection_info() - assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + async with http.arequest( + b"GET", (b"http", b"example.org", None, b"/") + ) as response: + status_code, headers, stream, ext = response + info = await http.get_connection_info() + assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + await read_body(stream) - await read_body(stream) info = await http.get_connection_info() assert info == {"http://example.org": ["ConnectionState.IDLE"]} @@ -105,25 +108,36 @@ async def test_concurrent_requests_h11() -> None: info = await http.get_connection_info() assert info == {} - response_1 = await http.arequest(b"GET", (b"http", b"example.org", None, b"/")) - status_code_1, headers_1, stream_1, ext_1 = response_1 - info = await http.get_connection_info() - assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + async with AsyncExitStack() as exit_stack2: + async with AsyncExitStack() as exit_stack1: + response_1 = await exit_stack1.enter_async_context( + http.arequest(b"GET", (b"http", b"example.org", None, b"/")) + ) + status_code_1, headers_1, stream_1, ext_1 = response_1 + info = await http.get_connection_info() + assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + + response_2 = await exit_stack2.enter_async_context( + http.arequest(b"GET", (b"http", b"example.org", None, b"/")) + ) + status_code_2, headers_2, stream_2, ext_2 = response_2 + info = await http.get_connection_info() + assert info == { + "http://example.org": [ + "ConnectionState.ACTIVE", + "ConnectionState.ACTIVE", + ] + } + + await read_body(stream_1) + + info = await http.get_connection_info() + assert info == { + "http://example.org": ["ConnectionState.ACTIVE", "ConnectionState.IDLE"] + } + + await read_body(stream_2) - response_2 = await http.arequest(b"GET", (b"http", b"example.org", None, b"/")) - status_code_2, headers_2, stream_2, ext_2 = response_2 - info = await http.get_connection_info() - assert info == { - "http://example.org": ["ConnectionState.ACTIVE", "ConnectionState.ACTIVE"] - } - - await read_body(stream_1) - info = await http.get_connection_info() - assert info == { - "http://example.org": ["ConnectionState.ACTIVE", "ConnectionState.IDLE"] - } - - await read_body(stream_2) info = await http.get_connection_info() assert info == { "http://example.org": ["ConnectionState.IDLE", "ConnectionState.IDLE"] @@ -136,20 +150,29 @@ async def test_concurrent_requests_h2() -> None: info = await http.get_connection_info() assert info == {} - response_1 = await http.arequest(b"GET", (b"http", b"example.org", None, b"/")) - status_code_1, headers_1, stream_1, ext_1 = response_1 - info = await http.get_connection_info() - assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + async with AsyncExitStack() as exit_stack2: + async with AsyncExitStack() as exit_stack1: + response_1 = await exit_stack1.enter_async_context( + http.arequest(b"GET", (b"http", b"example.org", None, b"/")) + ) + status_code_1, headers_1, stream_1, ext_1 = response_1 + info = await http.get_connection_info() + assert info == {"http://example.org": 
["ConnectionState.ACTIVE"]} - response_2 = await http.arequest(b"GET", (b"http", b"example.org", None, b"/")) - status_code_2, headers_2, stream_2, ext_2 = response_2 - info = await http.get_connection_info() - assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + response_2 = await exit_stack2.enter_async_context( + http.arequest(b"GET", (b"http", b"example.org", None, b"/")) + ) + status_code_2, headers_2, stream_2, ext_2 = response_2 - await read_body(stream_1) - info = await http.get_connection_info() - assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + info = await http.get_connection_info() + assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + + await read_body(stream_1) + + info = await http.get_connection_info() + assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + + await read_body(stream_2) - await read_body(stream_2) info = await http.get_connection_info() assert info == {"http://example.org": ["ConnectionState.IDLE"]} diff --git a/tests/async_tests/test_interfaces.py b/tests/async_tests/test_interfaces.py index e5921d73..037ea54a 100644 --- a/tests/async_tests/test_interfaces.py +++ b/tests/async_tests/test_interfaces.py @@ -1,5 +1,7 @@ import platform import ssl +from contextlib import AsyncExitStack +from functools import partial import pytest @@ -15,13 +17,7 @@ def backend(request): async def read_body(stream: httpcore.AsyncByteStream) -> bytes: - try: - body = [] - async for chunk in stream: - body.append(chunk) - return b"".join(body) - finally: - await stream.aclose() + return b"".join([chunk async for chunk in stream]) @pytest.mark.anyio @@ -30,8 +26,9 @@ async def test_http_request(backend: str, server: Server) -> None: method = b"GET" url = (b"http", *server.netloc, b"/") headers = [server.host_header] - status_code, headers, stream, ext = await http.arequest(method, url, headers) - await read_body(stream) + async with http.arequest(method, url, headers) as response: + status_code, headers, stream, ext = response + await read_body(stream) assert status_code == 200 assert ext == {"http_version": "HTTP/1.1", "reason": "OK"} @@ -44,8 +41,9 @@ async def test_https_request(backend: str, https_server: Server) -> None: method = b"GET" url = (b"https", *https_server.netloc, b"/") headers = [https_server.host_header] - status_code, headers, stream, ext = await http.arequest(method, url, headers) - await read_body(stream) + async with http.arequest(method, url, headers) as response: + status_code, headers, stream, ext = response + await read_body(stream) assert status_code == 200 assert ext == {"http_version": "HTTP/1.1", "reason": "OK"} @@ -59,7 +57,8 @@ async def test_request_unsupported_protocol(backend: str) -> None: url = (b"ftp", b"example.org", 443, b"/") headers = [(b"host", b"example.org")] with pytest.raises(httpcore.UnsupportedProtocol): - await http.arequest(method, url, headers) + async with http.arequest(method, url, headers): + pass # pragma: no cover @pytest.mark.anyio @@ -68,8 +67,9 @@ async def test_http2_request(backend: str, https_server: Server) -> None: method = b"GET" url = (b"https", *https_server.netloc, b"/") headers = [https_server.host_header] - status_code, headers, stream, ext = await http.arequest(method, url, headers) - await read_body(stream) + async with http.arequest(method, url, headers) as response: + status_code, headers, stream, ext = response + await read_body(stream) assert status_code == 200 assert ext == {"http_version": "HTTP/2"} @@ -82,8 +82,9 @@ async def 
test_closing_http_request(backend: str, server: Server) -> None: method = b"GET" url = (b"http", *server.netloc, b"/") headers = [server.host_header, (b"connection", b"close")] - status_code, headers, stream, ext = await http.arequest(method, url, headers) - await read_body(stream) + async with http.arequest(method, url, headers) as response: + status_code, headers, stream, ext = response + await read_body(stream) assert status_code == 200 assert ext == {"http_version": "HTTP/1.1", "reason": "OK"} @@ -96,8 +97,9 @@ async def test_http_request_reuse_connection(backend: str, server: Server) -> No method = b"GET" url = (b"http", *server.netloc, b"/") headers = [server.host_header] - status_code, headers, stream, ext = await http.arequest(method, url, headers) - await read_body(stream) + async with http.arequest(method, url, headers) as response: + status_code, headers, stream, ext = response + await read_body(stream) assert status_code == 200 assert ext == {"http_version": "HTTP/1.1", "reason": "OK"} @@ -106,8 +108,9 @@ async def test_http_request_reuse_connection(backend: str, server: Server) -> No method = b"GET" url = (b"http", *server.netloc, b"/") headers = [server.host_header] - status_code, headers, stream, ext = await http.arequest(method, url, headers) - await read_body(stream) + async with http.arequest(method, url, headers) as response: + status_code, headers, stream, ext = response + await read_body(stream) assert status_code == 200 assert ext == {"http_version": "HTTP/1.1", "reason": "OK"} @@ -122,8 +125,9 @@ async def test_https_request_reuse_connection( method = b"GET" url = (b"https", *https_server.netloc, b"/") headers = [https_server.host_header] - status_code, headers, stream, ext = await http.arequest(method, url, headers) - await read_body(stream) + async with http.arequest(method, url, headers) as response: + status_code, headers, stream, ext = response + await read_body(stream) assert status_code == 200 assert ext == {"http_version": "HTTP/1.1", "reason": "OK"} @@ -132,8 +136,9 @@ async def test_https_request_reuse_connection( method = b"GET" url = (b"https", *https_server.netloc, b"/") headers = [https_server.host_header] - status_code, headers, stream, ext = await http.arequest(method, url, headers) - await read_body(stream) + async with http.arequest(method, url, headers) as response: + status_code, headers, stream, ext = response + await read_body(stream) assert status_code == 200 assert ext == {"http_version": "HTTP/1.1", "reason": "OK"} @@ -148,8 +153,9 @@ async def test_http_request_cannot_reuse_dropped_connection( method = b"GET" url = (b"http", *server.netloc, b"/") headers = [server.host_header] - status_code, headers, stream, ext = await http.arequest(method, url, headers) - await read_body(stream) + async with http.arequest(method, url, headers) as response: + status_code, headers, stream, ext = response + await read_body(stream) assert status_code == 200 assert ext == {"http_version": "HTTP/1.1", "reason": "OK"} @@ -162,8 +168,9 @@ async def test_http_request_cannot_reuse_dropped_connection( method = b"GET" url = (b"http", *server.netloc, b"/") headers = [server.host_header] - status_code, headers, stream, ext = await http.arequest(method, url, headers) - await read_body(stream) + async with http.arequest(method, url, headers) as response: + status_code, headers, stream, ext = response + await read_body(stream) assert status_code == 200 assert ext == {"http_version": "HTTP/1.1", "reason": "OK"} @@ -185,8 +192,9 @@ async def test_http_proxy( 
max_connections=max_connections, backend=backend, ) as http: - status_code, headers, stream, ext = await http.arequest(method, url, headers) - await read_body(stream) + async with http.arequest(method, url, headers) as response: + status_code, headers, stream, ext = response + await read_body(stream) assert status_code == 200 assert ext == {"http_version": "HTTP/1.1", "reason": "OK"} @@ -203,8 +211,9 @@ async def test_http_request_local_address(backend: str, server: Server) -> None: method = b"GET" url = (b"http", *server.netloc, b"/") headers = [server.host_header] - status_code, headers, stream, ext = await http.arequest(method, url, headers) - await read_body(stream) + async with http.arequest(method, url, headers) as response: + status_code, headers, stream, ext = response + await read_body(stream) assert status_code == 200 assert ext == {"http_version": "HTTP/1.1", "reason": "OK"} @@ -233,8 +242,9 @@ async def test_proxy_https_requests( max_connections=max_connections, http2=http2, ) as http: - status_code, headers, stream, ext = await http.arequest(method, url, headers) - _ = await read_body(stream) + async with http.arequest(method, url, headers) as response: + status_code, headers, stream, ext = response + _ = await read_body(stream) assert status_code == 200 assert ext["http_version"] == "HTTP/2" if http2 else "HTTP/1.1" @@ -286,15 +296,20 @@ async def test_connection_pool_get_connection_info( url = (b"https", *https_server.netloc, b"/") headers = [https_server.host_header] - _, _, stream_1, _ = await http.arequest(method, url, headers) - _, _, stream_2, _ = await http.arequest(method, url, headers) - - try: - stats = await http.get_connection_info() - assert stats == expected_during_active - finally: - await read_body(stream_1) - await read_body(stream_2) + async with AsyncExitStack() as exit_stack: + _, _, stream_1, _ = await exit_stack.enter_async_context( + http.arequest(method, url, headers) + ) + _, _, stream_2, _ = await exit_stack.enter_async_context( + http.arequest(method, url, headers) + ) + + try: + stats = await http.get_connection_info() + assert stats == expected_during_active + finally: + await read_body(stream_1) + await read_body(stream_2) stats = await http.get_connection_info() assert stats == expected_during_idle @@ -317,11 +332,12 @@ async def test_http_request_unix_domain_socket( method = b"GET" url = (b"http", b"localhost", None, b"/") headers = [(b"host", b"localhost")] - status_code, headers, stream, ext = await http.arequest(method, url, headers) - assert status_code == 200 - assert ext == {"http_version": "HTTP/1.1", "reason": "OK"} - body = await read_body(stream) - assert body == b"Hello, world!" + async with http.arequest(method, url, headers) as response: + status_code, headers, stream, ext = response + assert status_code == 200 + assert ext == {"http_version": "HTTP/1.1", "reason": "OK"} + body = await read_body(stream) + assert body == b"Hello, world!" 
@pytest.mark.parametrize("max_keepalive", [1, 3, 5]) @@ -337,19 +353,17 @@ async def test_max_keepalive_connections_handled_correctly( url = (b"http", *server.netloc, b"/") headers = [server.host_header] - connections_streams = [] - for _ in range(connections_number): - _, _, stream, _ = await http.arequest(method, url, headers) - connections_streams.append(stream) + async with AsyncExitStack() as exit_stack: + for _ in range(connections_number): + _, _, stream, _ = await exit_stack.enter_async_context( + http.arequest(method, url, headers) + ) + exit_stack.push_async_callback(partial(read_body, stream)) - try: - for i in range(len(connections_streams)): - await read_body(connections_streams[i]) - finally: - stats = await http.get_connection_info() + stats = await http.get_connection_info() - connections_in_pool = next(iter(stats.values())) - assert len(connections_in_pool) == min(connections_number, max_keepalive) + connections_in_pool = next(iter(stats.values())) + assert len(connections_in_pool) == min(connections_number, max_keepalive) @pytest.mark.anyio @@ -358,8 +372,9 @@ async def test_explicit_backend_name(server: Server) -> None: method = b"GET" url = (b"http", *server.netloc, b"/") headers = [server.host_header] - status_code, headers, stream, ext = await http.arequest(method, url, headers) - await read_body(stream) + async with http.arequest(method, url, headers) as response: + status_code, headers, stream, ext = response + await read_body(stream) assert status_code == 200 assert ext == {"http_version": "HTTP/1.1", "reason": "OK"} diff --git a/tests/sync_tests/test_connection_pool.py b/tests/sync_tests/test_connection_pool.py index 312f96a0..0159ac4f 100644 --- a/tests/sync_tests/test_connection_pool.py +++ b/tests/sync_tests/test_connection_pool.py @@ -1,3 +1,4 @@ +from contextlib import ExitStack, contextmanager from typing import Iterator, Tuple import pytest @@ -7,7 +8,7 @@ from httpcore._types import URL, Headers -class MockConnection(object): +class MockConnection(httpcore.SyncHTTPTransport): def __init__(self, http_version): self.origin = (b"http", b"example.org", 80) self.state = ConnectionState.PENDING @@ -15,6 +16,7 @@ def __init__(self, http_version): self.is_http2 = http_version == "HTTP/2" self.stream_count = 0 + @contextmanager def request( self, method: bytes, @@ -22,7 +24,7 @@ def request( headers: Headers = None, stream: httpcore.SyncByteStream = None, ext: dict = None, - ) -> Tuple[int, Headers, httpcore.SyncByteStream, dict]: + ) -> Iterator[Tuple[int, Headers, httpcore.SyncByteStream, dict]]: self.state = ConnectionState.ACTIVE self.stream_count += 1 @@ -38,7 +40,10 @@ def iterator() -> Iterator[bytes]: iterator=iterator(), close_func=on_close ) - return 200, [], stream, {} + try: + yield 200, [], stream, {} + finally: + stream.close() def close(self): pass @@ -64,13 +69,7 @@ def _create_connection(self, **kwargs): def read_body(stream: httpcore.SyncByteStream) -> bytes: - try: - body = [] - for chunk in stream: - body.append(chunk) - return b"".join(body) - finally: - stream.close() + return b"".join([chunk for chunk in stream]) @@ -80,21 +79,25 @@ def test_sequential_requests(http_version) -> None: info = http.get_connection_info() assert info == {} - response = http.request(b"GET", (b"http", b"example.org", None, b"/")) - status_code, headers, stream, ext = response - info = http.get_connection_info() - assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + with http.request( + b"GET", (b"http", b"example.org", None, b"/") + ) as response: + 
status_code, headers, stream, ext = response + info = http.get_connection_info() + assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + read_body(stream) - read_body(stream) info = http.get_connection_info() assert info == {"http://example.org": ["ConnectionState.IDLE"]} - response = http.request(b"GET", (b"http", b"example.org", None, b"/")) - status_code, headers, stream, ext = response - info = http.get_connection_info() - assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + with http.request( + b"GET", (b"http", b"example.org", None, b"/") + ) as response: + status_code, headers, stream, ext = response + info = http.get_connection_info() + assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + read_body(stream) - read_body(stream) info = http.get_connection_info() assert info == {"http://example.org": ["ConnectionState.IDLE"]} @@ -105,25 +108,36 @@ def test_concurrent_requests_h11() -> None: info = http.get_connection_info() assert info == {} - response_1 = http.request(b"GET", (b"http", b"example.org", None, b"/")) - status_code_1, headers_1, stream_1, ext_1 = response_1 - info = http.get_connection_info() - assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + with ExitStack() as exit_stack2: + with ExitStack() as exit_stack1: + response_1 = exit_stack1.enter_context( + http.request(b"GET", (b"http", b"example.org", None, b"/")) + ) + status_code_1, headers_1, stream_1, ext_1 = response_1 + info = http.get_connection_info() + assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + + response_2 = exit_stack2.enter_context( + http.request(b"GET", (b"http", b"example.org", None, b"/")) + ) + status_code_2, headers_2, stream_2, ext_2 = response_2 + info = http.get_connection_info() + assert info == { + "http://example.org": [ + "ConnectionState.ACTIVE", + "ConnectionState.ACTIVE", + ] + } + + read_body(stream_1) + + info = http.get_connection_info() + assert info == { + "http://example.org": ["ConnectionState.ACTIVE", "ConnectionState.IDLE"] + } + + read_body(stream_2) - response_2 = http.request(b"GET", (b"http", b"example.org", None, b"/")) - status_code_2, headers_2, stream_2, ext_2 = response_2 - info = http.get_connection_info() - assert info == { - "http://example.org": ["ConnectionState.ACTIVE", "ConnectionState.ACTIVE"] - } - - read_body(stream_1) - info = http.get_connection_info() - assert info == { - "http://example.org": ["ConnectionState.ACTIVE", "ConnectionState.IDLE"] - } - - read_body(stream_2) info = http.get_connection_info() assert info == { "http://example.org": ["ConnectionState.IDLE", "ConnectionState.IDLE"] @@ -136,20 +150,29 @@ def test_concurrent_requests_h2() -> None: info = http.get_connection_info() assert info == {} - response_1 = http.request(b"GET", (b"http", b"example.org", None, b"/")) - status_code_1, headers_1, stream_1, ext_1 = response_1 - info = http.get_connection_info() - assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + with ExitStack() as exit_stack2: + with ExitStack() as exit_stack1: + response_1 = exit_stack1.enter_context( + http.request(b"GET", (b"http", b"example.org", None, b"/")) + ) + status_code_1, headers_1, stream_1, ext_1 = response_1 + info = http.get_connection_info() + assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} - response_2 = http.request(b"GET", (b"http", b"example.org", None, b"/")) - status_code_2, headers_2, stream_2, ext_2 = response_2 - info = http.get_connection_info() - assert info == {"http://example.org": 
["ConnectionState.ACTIVE"]} + response_2 = exit_stack2.enter_context( + http.request(b"GET", (b"http", b"example.org", None, b"/")) + ) + status_code_2, headers_2, stream_2, ext_2 = response_2 - read_body(stream_1) - info = http.get_connection_info() - assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + info = http.get_connection_info() + assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + + read_body(stream_1) + + info = http.get_connection_info() + assert info == {"http://example.org": ["ConnectionState.ACTIVE"]} + + read_body(stream_2) - read_body(stream_2) info = http.get_connection_info() assert info == {"http://example.org": ["ConnectionState.IDLE"]} diff --git a/tests/sync_tests/test_interfaces.py b/tests/sync_tests/test_interfaces.py index a175a460..db88e429 100644 --- a/tests/sync_tests/test_interfaces.py +++ b/tests/sync_tests/test_interfaces.py @@ -1,5 +1,7 @@ import platform import ssl +from contextlib import ExitStack +from functools import partial import pytest @@ -15,13 +17,7 @@ def backend(request): def read_body(stream: httpcore.SyncByteStream) -> bytes: - try: - body = [] - for chunk in stream: - body.append(chunk) - return b"".join(body) - finally: - stream.close() + return b"".join([chunk for chunk in stream]) @@ -30,8 +26,9 @@ def test_http_request(backend: str, server: Server) -> None: method = b"GET" url = (b"http", *server.netloc, b"/") headers = [server.host_header] - status_code, headers, stream, ext = http.request(method, url, headers) - read_body(stream) + with http.request(method, url, headers) as response: + status_code, headers, stream, ext = response + read_body(stream) assert status_code == 200 assert ext == {"http_version": "HTTP/1.1", "reason": "OK"} @@ -44,8 +41,9 @@ def test_https_request(backend: str, https_server: Server) -> None: method = b"GET" url = (b"https", *https_server.netloc, b"/") headers = [https_server.host_header] - status_code, headers, stream, ext = http.request(method, url, headers) - read_body(stream) + with http.request(method, url, headers) as response: + status_code, headers, stream, ext = response + read_body(stream) assert status_code == 200 assert ext == {"http_version": "HTTP/1.1", "reason": "OK"} @@ -59,7 +57,8 @@ def test_request_unsupported_protocol(backend: str) -> None: url = (b"ftp", b"example.org", 443, b"/") headers = [(b"host", b"example.org")] with pytest.raises(httpcore.UnsupportedProtocol): - http.request(method, url, headers) + with http.request(method, url, headers): + pass # pragma: no cover @@ -68,8 +67,9 @@ def test_http2_request(backend: str, https_server: Server) -> None: method = b"GET" url = (b"https", *https_server.netloc, b"/") headers = [https_server.host_header] - status_code, headers, stream, ext = http.request(method, url, headers) - read_body(stream) + with http.request(method, url, headers) as response: + status_code, headers, stream, ext = response + read_body(stream) assert status_code == 200 assert ext == {"http_version": "HTTP/2"} @@ -82,8 +82,9 @@ def test_closing_http_request(backend: str, server: Server) -> None: method = b"GET" url = (b"http", *server.netloc, b"/") headers = [server.host_header, (b"connection", b"close")] - status_code, headers, stream, ext = http.request(method, url, headers) - read_body(stream) + with http.request(method, url, headers) as response: + status_code, headers, stream, ext = response + read_body(stream) assert status_code == 200 assert ext == {"http_version": "HTTP/1.1", "reason": "OK"} @@ -96,8 +97,9 @@ def 
test_http_request_reuse_connection(backend: str, server: Server) -> None: method = b"GET" url = (b"http", *server.netloc, b"/") headers = [server.host_header] - status_code, headers, stream, ext = http.request(method, url, headers) - read_body(stream) + with http.request(method, url, headers) as response: + status_code, headers, stream, ext = response + read_body(stream) assert status_code == 200 assert ext == {"http_version": "HTTP/1.1", "reason": "OK"} @@ -106,8 +108,9 @@ def test_http_request_reuse_connection(backend: str, server: Server) -> None: method = b"GET" url = (b"http", *server.netloc, b"/") headers = [server.host_header] - status_code, headers, stream, ext = http.request(method, url, headers) - read_body(stream) + with http.request(method, url, headers) as response: + status_code, headers, stream, ext = response + read_body(stream) assert status_code == 200 assert ext == {"http_version": "HTTP/1.1", "reason": "OK"} @@ -122,8 +125,9 @@ def test_https_request_reuse_connection( method = b"GET" url = (b"https", *https_server.netloc, b"/") headers = [https_server.host_header] - status_code, headers, stream, ext = http.request(method, url, headers) - read_body(stream) + with http.request(method, url, headers) as response: + status_code, headers, stream, ext = response + read_body(stream) assert status_code == 200 assert ext == {"http_version": "HTTP/1.1", "reason": "OK"} @@ -132,8 +136,9 @@ def test_https_request_reuse_connection( method = b"GET" url = (b"https", *https_server.netloc, b"/") headers = [https_server.host_header] - status_code, headers, stream, ext = http.request(method, url, headers) - read_body(stream) + with http.request(method, url, headers) as response: + status_code, headers, stream, ext = response + read_body(stream) assert status_code == 200 assert ext == {"http_version": "HTTP/1.1", "reason": "OK"} @@ -148,8 +153,9 @@ def test_http_request_cannot_reuse_dropped_connection( method = b"GET" url = (b"http", *server.netloc, b"/") headers = [server.host_header] - status_code, headers, stream, ext = http.request(method, url, headers) - read_body(stream) + with http.request(method, url, headers) as response: + status_code, headers, stream, ext = response + read_body(stream) assert status_code == 200 assert ext == {"http_version": "HTTP/1.1", "reason": "OK"} @@ -162,8 +168,9 @@ def test_http_request_cannot_reuse_dropped_connection( method = b"GET" url = (b"http", *server.netloc, b"/") headers = [server.host_header] - status_code, headers, stream, ext = http.request(method, url, headers) - read_body(stream) + with http.request(method, url, headers) as response: + status_code, headers, stream, ext = response + read_body(stream) assert status_code == 200 assert ext == {"http_version": "HTTP/1.1", "reason": "OK"} @@ -185,8 +192,9 @@ def test_http_proxy( max_connections=max_connections, backend=backend, ) as http: - status_code, headers, stream, ext = http.request(method, url, headers) - read_body(stream) + with http.request(method, url, headers) as response: + status_code, headers, stream, ext = response + read_body(stream) assert status_code == 200 assert ext == {"http_version": "HTTP/1.1", "reason": "OK"} @@ -203,8 +211,9 @@ def test_http_request_local_address(backend: str, server: Server) -> None: method = b"GET" url = (b"http", *server.netloc, b"/") headers = [server.host_header] - status_code, headers, stream, ext = http.request(method, url, headers) - read_body(stream) + with http.request(method, url, headers) as response: + status_code, headers, stream, ext = 
response + read_body(stream) assert status_code == 200 assert ext == {"http_version": "HTTP/1.1", "reason": "OK"} @@ -233,8 +242,9 @@ def test_proxy_https_requests( max_connections=max_connections, http2=http2, ) as http: - status_code, headers, stream, ext = http.request(method, url, headers) - _ = read_body(stream) + with http.request(method, url, headers) as response: + status_code, headers, stream, ext = response + _ = read_body(stream) assert status_code == 200 assert ext["http_version"] == "HTTP/2" if http2 else "HTTP/1.1" @@ -286,15 +296,20 @@ def test_connection_pool_get_connection_info( url = (b"https", *https_server.netloc, b"/") headers = [https_server.host_header] - _, _, stream_1, _ = http.request(method, url, headers) - _, _, stream_2, _ = http.request(method, url, headers) - - try: - stats = http.get_connection_info() - assert stats == expected_during_active - finally: - read_body(stream_1) - read_body(stream_2) + with ExitStack() as exit_stack: + _, _, stream_1, _ = exit_stack.enter_context( + http.request(method, url, headers) + ) + _, _, stream_2, _ = exit_stack.enter_context( + http.request(method, url, headers) + ) + + try: + stats = http.get_connection_info() + assert stats == expected_during_active + finally: + read_body(stream_1) + read_body(stream_2) stats = http.get_connection_info() assert stats == expected_during_idle @@ -317,11 +332,12 @@ def test_http_request_unix_domain_socket( method = b"GET" url = (b"http", b"localhost", None, b"/") headers = [(b"host", b"localhost")] - status_code, headers, stream, ext = http.request(method, url, headers) - assert status_code == 200 - assert ext == {"http_version": "HTTP/1.1", "reason": "OK"} - body = read_body(stream) - assert body == b"Hello, world!" + with http.request(method, url, headers) as response: + status_code, headers, stream, ext = response + assert status_code == 200 + assert ext == {"http_version": "HTTP/1.1", "reason": "OK"} + body = read_body(stream) + assert body == b"Hello, world!" 
@pytest.mark.parametrize("max_keepalive", [1, 3, 5]) @@ -337,19 +353,17 @@ def test_max_keepalive_connections_handled_correctly( url = (b"http", *server.netloc, b"/") headers = [server.host_header] - connections_streams = [] - for _ in range(connections_number): - _, _, stream, _ = http.request(method, url, headers) - connections_streams.append(stream) + with ExitStack() as exit_stack: + for _ in range(connections_number): + _, _, stream, _ = exit_stack.enter_context( + http.request(method, url, headers) + ) + exit_stack.callback(partial(read_body, stream)) - try: - for i in range(len(connections_streams)): - read_body(connections_streams[i]) - finally: - stats = http.get_connection_info() + stats = http.get_connection_info() - connections_in_pool = next(iter(stats.values())) - assert len(connections_in_pool) == min(connections_number, max_keepalive) + connections_in_pool = next(iter(stats.values())) + assert len(connections_in_pool) == min(connections_number, max_keepalive) @@ -358,8 +372,9 @@ def test_explicit_backend_name(server: Server) -> None: method = b"GET" url = (b"http", *server.netloc, b"/") headers = [server.host_header] - status_code, headers, stream, ext = http.request(method, url, headers) - read_body(stream) + with http.request(method, url, headers) as response: + status_code, headers, stream, ext = response + read_body(stream) assert status_code == 200 assert ext == {"http_version": "HTTP/1.1", "reason": "OK"} diff --git a/unasync.py b/unasync.py index d3b36993..b6f19cc9 100755 --- a/unasync.py +++ b/unasync.py @@ -6,6 +6,11 @@ SUBS = [ ('AsyncIteratorByteStream', 'IteratorByteStream'), ('AsyncIterator', 'Iterator'), + ('asynccontextmanager', 'contextmanager'), + ('AsyncContextManager', 'ContextManager'), + ('AsyncExitStack', 'ExitStack'), + ('enter_async_context', 'enter_context'), + ('push_async_callback', 'callback'), ('AutoBackend', 'SyncBackend'), ('Async([A-Z][A-Za-z0-9_]*)', r'Sync\2'), ('async def', 'def'), From dcf015e3264817a8a24c2ecc543a86cbb658351d Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Sat, 3 Oct 2020 12:27:09 +0200 Subject: [PATCH 02/17] Fix Python 3.6 compatibility --- httpcore/_compat.py | 2 +- tests/async_tests/test_connection_pool.py | 2 +- tests/async_tests/test_interfaces.py | 2 +- tests/sync_tests/test_connection_pool.py | 2 +- tests/sync_tests/test_interfaces.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/httpcore/_compat.py b/httpcore/_compat.py index f62aa340..3191536f 100644 --- a/httpcore/_compat.py +++ b/httpcore/_compat.py @@ -1,5 +1,5 @@ try: - from contextlib import AsyncExitStack, asynccontextmanager + from contextlib import AsyncExitStack, asynccontextmanager # type: ignore # Py3.6 except ImportError: # pragma: no cover # Python 3.6 from async_exit_stack import AsyncExitStack # type: ignore # noqa: F401 diff --git a/tests/async_tests/test_connection_pool.py b/tests/async_tests/test_connection_pool.py index dd4039d5..6734a928 100644 --- a/tests/async_tests/test_connection_pool.py +++ b/tests/async_tests/test_connection_pool.py @@ -1,10 +1,10 @@ -from contextlib import AsyncExitStack, asynccontextmanager from typing import AsyncIterator, Tuple import pytest import httpcore from httpcore._async.base import ConnectionState +from httpcore._compat import AsyncExitStack, asynccontextmanager from httpcore._types import URL, Headers diff --git a/tests/async_tests/test_interfaces.py b/tests/async_tests/test_interfaces.py index 037ea54a..f2503b53 100644 --- a/tests/async_tests/test_interfaces.py +++ 
b/tests/async_tests/test_interfaces.py @@ -1,11 +1,11 @@ import platform import ssl -from contextlib import AsyncExitStack from functools import partial import pytest import httpcore +from httpcore._compat import AsyncExitStack from httpcore._types import URL from tests.conftest import HTTPS_SERVER_URL, UvicornServer from tests.utils import Server, lookup_async_backend diff --git a/tests/sync_tests/test_connection_pool.py b/tests/sync_tests/test_connection_pool.py index 0159ac4f..8fe31ea5 100644 --- a/tests/sync_tests/test_connection_pool.py +++ b/tests/sync_tests/test_connection_pool.py @@ -1,10 +1,10 @@ -from contextlib import ExitStack, contextmanager from typing import Iterator, Tuple import pytest import httpcore from httpcore._async.base import ConnectionState +from httpcore._compat import ExitStack, contextmanager from httpcore._types import URL, Headers diff --git a/tests/sync_tests/test_interfaces.py b/tests/sync_tests/test_interfaces.py index db88e429..cc4af283 100644 --- a/tests/sync_tests/test_interfaces.py +++ b/tests/sync_tests/test_interfaces.py @@ -1,11 +1,11 @@ import platform import ssl -from contextlib import ExitStack from functools import partial import pytest import httpcore +from httpcore._compat import ExitStack from httpcore._types import URL from tests.conftest import HTTPS_SERVER_URL, UvicornServer from tests.utils import Server, lookup_sync_backend From 4530efe855b4d3ad36c120691f1c8e2d0309cb1a Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Sat, 3 Oct 2020 12:27:56 +0200 Subject: [PATCH 03/17] Sync docs index snippets --- docs/index.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/index.md b/docs/index.md index 7eb222e7..3ced129f 100644 --- a/docs/index.md +++ b/docs/index.md @@ -47,7 +47,8 @@ with httpcore.SyncConnectionPool() as http: method=b'GET', url=(b'https', b'example.org', 443, b'/'), headers=[(b'host', b'example.org'), (b'user-agent', 'httpcore')] - ) as (status_code, headers, stream, ext): + ) as response: + status_code, headers, stream, ext = response body = b''.join([chunk for chunk in stream]) print(status_code, body) @@ -61,7 +62,8 @@ async with httpcore.AsyncConnectionPool() as http: method=b'GET', url=(b'https', b'example.org', 443, b'/'), headers=[(b'host', b'example.org'), (b'user-agent', 'httpcore')] - ) as (status_code, headers, stream, ext): + ) as response: + status_code, headers, stream, ext = response body = b''.join([chunk async for chunk in stream]) print(status_code, body) From 6be769867e43af505b14340cc9b014cd32b03f8e Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Sat, 3 Oct 2020 12:35:30 +0200 Subject: [PATCH 04/17] Tweak call order of _response_closed --- httpcore/_async/http_proxy.py | 23 ++++++++++------------- httpcore/_sync/http_proxy.py | 23 ++++++++++------------- 2 files changed, 20 insertions(+), 26 deletions(-) diff --git a/httpcore/_async/http_proxy.py b/httpcore/_async/http_proxy.py index f1ee8487..8dc92029 100644 --- a/httpcore/_async/http_proxy.py +++ b/httpcore/_async/http_proxy.py @@ -167,13 +167,13 @@ async def _forward_request( url = self.proxy_origin + (target,) headers = merge_headers(self.proxy_headers, headers) - async with connection.arequest( - method, url, headers=headers, stream=stream, ext=ext - ) as response: - try: + try: + async with connection.arequest( + method, url, headers=headers, stream=stream, ext=ext + ) as response: yield response - finally: - await self._response_closed(connection) + finally: + await self._response_closed(connection) @asynccontextmanager 
async def _tunnel_request( @@ -253,17 +253,14 @@ async def _tunnel_request( # Once the connection has been established we can send requests on # it as normal. - response = await exit_stack.enter_async_context( - connection.arequest( + try: + async with connection.arequest( method, url, headers=headers, stream=stream, ext=ext, - ) - ) - - try: - yield response + ) as response: + yield response finally: await self._response_closed(connection) diff --git a/httpcore/_sync/http_proxy.py b/httpcore/_sync/http_proxy.py index 01745ce7..0045d583 100644 --- a/httpcore/_sync/http_proxy.py +++ b/httpcore/_sync/http_proxy.py @@ -167,13 +167,13 @@ def _forward_request( url = self.proxy_origin + (target,) headers = merge_headers(self.proxy_headers, headers) - with connection.request( - method, url, headers=headers, stream=stream, ext=ext - ) as response: - try: + try: + with connection.request( + method, url, headers=headers, stream=stream, ext=ext + ) as response: yield response - finally: - self._response_closed(connection) + finally: + self._response_closed(connection) @contextmanager def _tunnel_request( @@ -253,17 +253,14 @@ def _tunnel_request( # Once the connection has been established we can send requests on # it as normal. - response = exit_stack.enter_context( - connection.request( + try: + with connection.request( method, url, headers=headers, stream=stream, ext=ext, - ) - ) - - try: - yield response + ) as response: + yield response finally: self._response_closed(connection) From 35e006ac9e3a3447965cca3ad6822cb5e6d01aac Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Thu, 8 Oct 2020 20:21:07 +0200 Subject: [PATCH 05/17] Lint --- tests/async_tests/test_interfaces.py | 1 - tests/sync_tests/test_interfaces.py | 1 - 2 files changed, 2 deletions(-) diff --git a/tests/async_tests/test_interfaces.py b/tests/async_tests/test_interfaces.py index 87dc47ce..db5e9e0e 100644 --- a/tests/async_tests/test_interfaces.py +++ b/tests/async_tests/test_interfaces.py @@ -1,5 +1,4 @@ import platform -import ssl from functools import partial import pytest diff --git a/tests/sync_tests/test_interfaces.py b/tests/sync_tests/test_interfaces.py index 9781db16..5d367cb3 100644 --- a/tests/sync_tests/test_interfaces.py +++ b/tests/sync_tests/test_interfaces.py @@ -1,5 +1,4 @@ import platform -import ssl from functools import partial import pytest From 793f2850bc4081b44b2e7e014a099857d9728b08 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 16 Nov 2020 11:26:25 +0000 Subject: [PATCH 06/17] Drop exitstack from connection_pool implementation --- httpcore/_async/connection_pool.py | 68 ++++++++++++------------------ httpcore/_sync/connection_pool.py | 68 ++++++++++++------------------ 2 files changed, 56 insertions(+), 80 deletions(-) diff --git a/httpcore/_async/connection_pool.py b/httpcore/_async/connection_pool.py index 82650a06..d0c4ee6f 100644 --- a/httpcore/_async/connection_pool.py +++ b/httpcore/_async/connection_pool.py @@ -1,11 +1,10 @@ import warnings -from functools import partial from ssl import SSLContext from typing import AsyncIterator, Dict, List, Optional, Set, Tuple, cast from .._backends.auto import AsyncLock, AsyncSemaphore from .._backends.base import lookup_async_backend -from .._compat import AsyncExitStack, asynccontextmanager +from .._compat import asynccontextmanager from .._exceptions import LocalProtocolError, PoolTimeout, UnsupportedProtocol from .._threadlock import ThreadLock from .._types import URL, Headers, Origin, TimeoutDict @@ -150,45 +149,34 @@ async def arequest( await 
self._keepalive_sweep() - async with AsyncExitStack() as exit_stack: - connection: Optional[AsyncHTTPConnection] = None - while connection is None: - async with self._connection_acquiry_lock: - # We get-or-create a connection as an atomic operation, to ensure - # that HTTP/2 requests issued in close concurrency will end up - # on the same connection. - logger.trace("get_connection_from_pool=%r", origin) - connection = await self._get_connection_from_pool(origin) - - if connection is None: - connection = self._create_connection(origin=origin) - logger.trace("created connection=%r", connection) - await self._add_to_pool(connection, timeout=timeout) - else: - logger.trace("reuse connection=%r", connection) - - try: - # Push this callback onto the stack *before* making the request, - # so that it's effectively executed *after* the response is closed. - exit_stack.push_async_callback( - partial(self._response_closed, connection) - ) - - response = await exit_stack.enter_async_context( - connection.arequest( - method, url, headers=headers, stream=stream, ext=ext - ) - ) - except NewConnectionRequired: - exit_stack.pop_all() # Drop any registered callbacks. - connection = None - except Exception: # noqa: PIE786 - logger.trace("remove from pool connection=%r", connection) - exit_stack.pop_all() # Drop any registered callbacks. - await self._remove_from_pool(connection) - raise + connection: Optional[AsyncHTTPConnection] = None + while connection is None: + async with self._connection_acquiry_lock: + # We get-or-create a connection as an atomic operation, to ensure + # that HTTP/2 requests issued in close concurrency will end up + # on the same connection. + logger.trace("get_connection_from_pool=%r", origin) + connection = await self._get_connection_from_pool(origin) + + if connection is None: + connection = self._create_connection(origin=origin) + logger.trace("created connection=%r", connection) + await self._add_to_pool(connection, timeout=timeout) + else: + logger.trace("reuse connection=%r", connection) + + try: + async with connection.arequest( + method, url, headers=headers, stream=stream, ext=ext + ) as response: + yield response + except NewConnectionRequired: + connection = None + except Exception: # noqa: PIE786 + logger.trace("remove from pool connection=%r", connection) + await self._remove_from_pool(connection) - yield response + await self._response_closed(connection) async def _get_connection_from_pool( self, origin: Origin diff --git a/httpcore/_sync/connection_pool.py b/httpcore/_sync/connection_pool.py index 3f5f566d..5c2298d2 100644 --- a/httpcore/_sync/connection_pool.py +++ b/httpcore/_sync/connection_pool.py @@ -1,11 +1,10 @@ import warnings -from functools import partial from ssl import SSLContext from typing import Iterator, Dict, List, Optional, Set, Tuple, cast from .._backends.sync import SyncLock, SyncSemaphore from .._backends.base import lookup_sync_backend -from .._compat import ExitStack, contextmanager +from .._compat import contextmanager from .._exceptions import LocalProtocolError, PoolTimeout, UnsupportedProtocol from .._threadlock import ThreadLock from .._types import URL, Headers, Origin, TimeoutDict @@ -150,45 +149,34 @@ def request( self._keepalive_sweep() - with ExitStack() as exit_stack: - connection: Optional[SyncHTTPConnection] = None - while connection is None: - with self._connection_acquiry_lock: - # We get-or-create a connection as an atomic operation, to ensure - # that HTTP/2 requests issued in close concurrency will end up - # on the same 
connection. - logger.trace("get_connection_from_pool=%r", origin) - connection = self._get_connection_from_pool(origin) - - if connection is None: - connection = self._create_connection(origin=origin) - logger.trace("created connection=%r", connection) - self._add_to_pool(connection, timeout=timeout) - else: - logger.trace("reuse connection=%r", connection) - - try: - # Push this callback onto the stack *before* making the request, - # so that it's effectively executed *after* the response is closed. - exit_stack.callback( - partial(self._response_closed, connection) - ) - - response = exit_stack.enter_context( - connection.request( - method, url, headers=headers, stream=stream, ext=ext - ) - ) - except NewConnectionRequired: - exit_stack.pop_all() # Drop any registered callbacks. - connection = None - except Exception: # noqa: PIE786 - logger.trace("remove from pool connection=%r", connection) - exit_stack.pop_all() # Drop any registered callbacks. - self._remove_from_pool(connection) - raise + connection: Optional[SyncHTTPConnection] = None + while connection is None: + with self._connection_acquiry_lock: + # We get-or-create a connection as an atomic operation, to ensure + # that HTTP/2 requests issued in close concurrency will end up + # on the same connection. + logger.trace("get_connection_from_pool=%r", origin) + connection = self._get_connection_from_pool(origin) + + if connection is None: + connection = self._create_connection(origin=origin) + logger.trace("created connection=%r", connection) + self._add_to_pool(connection, timeout=timeout) + else: + logger.trace("reuse connection=%r", connection) + + try: + with connection.request( + method, url, headers=headers, stream=stream, ext=ext + ) as response: + yield response + except NewConnectionRequired: + connection = None + except Exception: # noqa: PIE786 + logger.trace("remove from pool connection=%r", connection) + self._remove_from_pool(connection) - yield response + self._response_closed(connection) def _get_connection_from_pool( self, origin: Origin From aa99f4b67a442f33b46633322a33e45a9bf9731f Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 16 Nov 2020 11:41:27 +0000 Subject: [PATCH 07/17] Drop ResponseStream.close in http11 implementation --- httpcore/_async/http11.py | 7 ++----- httpcore/_sync/http11.py | 7 ++----- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/httpcore/_async/http11.py b/httpcore/_async/http11.py index f31d915d..455aa887 100644 --- a/httpcore/_async/http11.py +++ b/httpcore/_async/http11.py @@ -74,16 +74,13 @@ async def arequest( ) = await self._receive_response(timeout) response_stream = AsyncIteratorByteStream( aiterator=self._receive_response_data(timeout), - aclose_func=self._response_closed, ) ext = { "http_version": http_version.decode("ascii", errors="ignore"), "reason": reason_phrase.decode("ascii", errors="ignore"), } - try: - yield (status_code, headers, response_stream, ext) - finally: - await response_stream.aclose() + yield (status_code, headers, response_stream, ext) + await self._response_closed() async def start_tls( self, hostname: bytes, timeout: TimeoutDict = None diff --git a/httpcore/_sync/http11.py b/httpcore/_sync/http11.py index 45f170e3..10f7b55b 100644 --- a/httpcore/_sync/http11.py +++ b/httpcore/_sync/http11.py @@ -74,16 +74,13 @@ def request( ) = self._receive_response(timeout) response_stream = IteratorByteStream( iterator=self._receive_response_data(timeout), - close_func=self._response_closed, ) ext = { "http_version": http_version.decode("ascii", 
errors="ignore"), "reason": reason_phrase.decode("ascii", errors="ignore"), } - try: - yield (status_code, headers, response_stream, ext) - finally: - response_stream.close() + yield (status_code, headers, response_stream, ext) + self._response_closed() def start_tls( self, hostname: bytes, timeout: TimeoutDict = None From 02de08afb25b9443df3a9bd39a760fd5b7be5670 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 16 Nov 2020 11:44:01 +0000 Subject: [PATCH 08/17] Drop ResponseStream.close in http2 implementation --- httpcore/_async/http2.py | 10 +++------- httpcore/_sync/http2.py | 10 +++------- 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/httpcore/_async/http2.py b/httpcore/_async/http2.py index 42b9bad0..38a422b2 100644 --- a/httpcore/_async/http2.py +++ b/httpcore/_async/http2.py @@ -301,18 +301,14 @@ async def arequest( # Receive the response. status_code, headers = await self.receive_response(timeout) - response_stream = AsyncIteratorByteStream( - aiterator=self.body_iter(timeout), aclose_func=self._response_closed - ) + response_stream = AsyncIteratorByteStream(aiterator=self.body_iter(timeout)) ext = { "http_version": "HTTP/2", } - try: - yield (status_code, headers, response_stream, ext) - finally: - await response_stream.aclose() + yield (status_code, headers, response_stream, ext) + await self._response_closed() async def send_headers( self, diff --git a/httpcore/_sync/http2.py b/httpcore/_sync/http2.py index 06bb72d6..cec63d42 100644 --- a/httpcore/_sync/http2.py +++ b/httpcore/_sync/http2.py @@ -301,18 +301,14 @@ def request( # Receive the response. status_code, headers = self.receive_response(timeout) - response_stream = IteratorByteStream( - iterator=self.body_iter(timeout), close_func=self._response_closed - ) + response_stream = IteratorByteStream(iterator=self.body_iter(timeout)) ext = { "http_version": "HTTP/2", } - try: - yield (status_code, headers, response_stream, ext) - finally: - response_stream.close() + yield (status_code, headers, response_stream, ext) + self._response_closed() def send_headers( self, From fdccfafc0534c6ab456f2150ee8c78f30fd402b2 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 16 Nov 2020 13:19:16 +0000 Subject: [PATCH 09/17] Resolve typo in README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 342863c0..3ced129f 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ with httpcore.SyncConnectionPool() as http: url=(b'https', b'example.org', 443, b'/'), headers=[(b'host', b'example.org'), (b'user-agent', 'httpcore')] ) as response: - status_code, headers, stream, ext = respnose + status_code, headers, stream, ext = response body = b''.join([chunk for chunk in stream]) print(status_code, body) From 9d7885e9a1f9b0e93e1eced68940b3e01106736a Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 17 Nov 2020 12:11:34 +0000 Subject: [PATCH 10/17] Ensure response_closed is called --- httpcore/_async/http11.py | 6 ++++-- httpcore/_async/http2.py | 6 ++++-- httpcore/_sync/http11.py | 6 ++++-- httpcore/_sync/http2.py | 6 ++++-- 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/httpcore/_async/http11.py b/httpcore/_async/http11.py index 2a73a45f..b88e17ab 100644 --- a/httpcore/_async/http11.py +++ b/httpcore/_async/http11.py @@ -79,8 +79,10 @@ async def arequest( "http_version": http_version.decode("ascii", errors="ignore"), "reason": reason_phrase.decode("ascii", errors="ignore"), } - yield (status_code, headers, response_stream, ext) - await 
self._response_closed() + try: + yield (status_code, headers, response_stream, ext) + finally: + await self._response_closed() async def start_tls( self, hostname: bytes, timeout: TimeoutDict = None diff --git a/httpcore/_async/http2.py b/httpcore/_async/http2.py index 5b8f862a..43cf371a 100644 --- a/httpcore/_async/http2.py +++ b/httpcore/_async/http2.py @@ -307,8 +307,10 @@ async def arequest( "http_version": "HTTP/2", } - yield (status_code, headers, response_stream, ext) - await self._response_closed() + try: + yield (status_code, headers, response_stream, ext) + finally: + await self._response_closed() async def send_headers( self, diff --git a/httpcore/_sync/http11.py b/httpcore/_sync/http11.py index ad74df4c..e7afd381 100644 --- a/httpcore/_sync/http11.py +++ b/httpcore/_sync/http11.py @@ -79,8 +79,10 @@ def request( "http_version": http_version.decode("ascii", errors="ignore"), "reason": reason_phrase.decode("ascii", errors="ignore"), } - yield (status_code, headers, response_stream, ext) - self._response_closed() + try: + yield (status_code, headers, response_stream, ext) + finally: + self._response_closed() def start_tls( self, hostname: bytes, timeout: TimeoutDict = None diff --git a/httpcore/_sync/http2.py b/httpcore/_sync/http2.py index aeb917b4..2c528e45 100644 --- a/httpcore/_sync/http2.py +++ b/httpcore/_sync/http2.py @@ -307,8 +307,10 @@ def request( "http_version": "HTTP/2", } - yield (status_code, headers, response_stream, ext) - self._response_closed() + try: + yield (status_code, headers, response_stream, ext) + finally: + self._response_closed() def send_headers( self, From a92a82df852c57bb7e901e112edc9dce0c8cc30b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 18 Nov 2020 09:57:33 +0000 Subject: [PATCH 11/17] Drop close on bytestream interface --- httpcore/_async/base.py | 6 ------ httpcore/_bytestreams.py | 18 +++--------------- httpcore/_sync/base.py | 6 ------ tests/async_tests/test_connection_pool.py | 6 ++---- tests/sync_tests/test_connection_pool.py | 6 ++---- 5 files changed, 7 insertions(+), 35 deletions(-) diff --git a/httpcore/_async/base.py b/httpcore/_async/base.py index 3d6be7a2..a8a65261 100644 --- a/httpcore/_async/base.py +++ b/httpcore/_async/base.py @@ -46,12 +46,6 @@ async def __aiter__(self) -> AsyncIterator[bytes]: """ yield b"" # pragma: nocover - async def aclose(self) -> None: - """ - Must be called by the client to indicate that the stream has been closed. 
- """ - pass # pragma: nocover - class AsyncHTTPTransport: """ diff --git a/httpcore/_bytestreams.py b/httpcore/_bytestreams.py index e938aaf9..aded338d 100644 --- a/httpcore/_bytestreams.py +++ b/httpcore/_bytestreams.py @@ -1,4 +1,4 @@ -from typing import AsyncIterator, Callable, Iterator +from typing import AsyncIterator, Iterator from ._async.base import AsyncByteStream from ._sync.base import SyncByteStream @@ -37,18 +37,13 @@ def generate_content(): ``` """ - def __init__(self, iterator: Iterator[bytes], close_func: Callable = None) -> None: + def __init__(self, iterator: Iterator[bytes]) -> None: self._iterator = iterator - self._close_func = close_func def __iter__(self) -> Iterator[bytes]: for chunk in self._iterator: yield chunk - def close(self) -> None: - if self._close_func is not None: - self._close_func() - class AsyncIteratorByteStream(AsyncByteStream): """ @@ -63,16 +58,9 @@ async def generate_content(): ``` """ - def __init__( - self, aiterator: AsyncIterator[bytes], aclose_func: Callable = None - ) -> None: + def __init__(self, aiterator: AsyncIterator[bytes]) -> None: self._aiterator = aiterator - self._aclose_func = aclose_func async def __aiter__(self) -> AsyncIterator[bytes]: async for chunk in self._aiterator: yield chunk - - async def aclose(self) -> None: - if self._aclose_func is not None: - await self._aclose_func() diff --git a/httpcore/_sync/base.py b/httpcore/_sync/base.py index 519735eb..7e9097fd 100644 --- a/httpcore/_sync/base.py +++ b/httpcore/_sync/base.py @@ -46,12 +46,6 @@ def __iter__(self) -> Iterator[bytes]: """ yield b"" # pragma: nocover - def close(self) -> None: - """ - Must be called by the client to indicate that the stream has been closed. - """ - pass # pragma: nocover - class SyncHTTPTransport: """ diff --git a/tests/async_tests/test_connection_pool.py b/tests/async_tests/test_connection_pool.py index d6ce39b7..0fa003f3 100644 --- a/tests/async_tests/test_connection_pool.py +++ b/tests/async_tests/test_connection_pool.py @@ -36,14 +36,12 @@ async def on_close(): async def aiterator() -> AsyncIterator[bytes]: yield b"" - stream = httpcore.AsyncIteratorByteStream( - aiterator=aiterator(), aclose_func=on_close - ) + stream = httpcore.AsyncIteratorByteStream(aiterator=aiterator()) try: yield 200, [], stream, {} finally: - await stream.aclose() + await on_close() async def aclose(self): pass diff --git a/tests/sync_tests/test_connection_pool.py b/tests/sync_tests/test_connection_pool.py index 757702c1..0c4937e4 100644 --- a/tests/sync_tests/test_connection_pool.py +++ b/tests/sync_tests/test_connection_pool.py @@ -36,14 +36,12 @@ def on_close(): def iterator() -> Iterator[bytes]: yield b"" - stream = httpcore.IteratorByteStream( - iterator=iterator(), close_func=on_close - ) + stream = httpcore.IteratorByteStream(iterator=iterator()) try: yield 200, [], stream, {} finally: - stream.close() + on_close() def close(self): pass From 3699f46a8898bc27712d301c268dee80007a8834 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 18 Nov 2020 10:00:35 +0000 Subject: [PATCH 12/17] Drop bytestream.close in docs API reference --- docs/api.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/api.md b/docs/api.md index 3bbde423..e9505af7 100644 --- a/docs/api.md +++ b/docs/api.md @@ -11,7 +11,7 @@ interface which transport classes need to implement. ::: httpcore.AsyncByteStream :docstring: - :members: __aiter__ aclose + :members: __aiter__ The `AsyncConnectionPool` class is a concrete implementation of `AsyncHTTPTransport`. 
@@ -40,7 +40,7 @@ interface which transport classes need to implement. ::: httpcore.SyncByteStream :docstring: - :members: __iter__ close + :members: __iter__ The `SyncConnectionPool` class is a concrete implementation of `SyncHTTPTransport`. From 262464c8bb3bc0b27003861f09f09a5ef1e9298c Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 18 Nov 2020 11:08:40 +0000 Subject: [PATCH 13/17] ByteStream -> Iterable[bytes] --- httpcore/__init__.py | 4 +- httpcore/_async/base.py | 10 ++--- httpcore/_async/connection.py | 13 ++----- httpcore/_async/connection_pool.py | 23 ++++++----- httpcore/_async/http11.py | 16 ++++---- httpcore/_async/http2.py | 20 +++++----- httpcore/_async/http_proxy.py | 15 ++++---- httpcore/_bytestreams.py | 47 +---------------------- httpcore/_sync/base.py | 6 +-- httpcore/_sync/connection.py | 13 ++----- httpcore/_sync/connection_pool.py | 23 ++++++----- httpcore/_sync/http11.py | 16 ++++---- httpcore/_sync/http2.py | 20 +++++----- httpcore/_sync/http_proxy.py | 15 ++++---- tests/async_tests/test_connection_pool.py | 10 ++--- tests/async_tests/test_interfaces.py | 3 +- tests/async_tests/test_retries.py | 4 +- tests/sync_tests/test_connection_pool.py | 10 ++--- tests/sync_tests/test_interfaces.py | 3 +- tests/sync_tests/test_retries.py | 4 +- unasync.py | 1 + 21 files changed, 115 insertions(+), 161 deletions(-) diff --git a/httpcore/__init__.py b/httpcore/__init__.py index 3aedf7d4..f5f2d735 100644 --- a/httpcore/__init__.py +++ b/httpcore/__init__.py @@ -1,7 +1,7 @@ from ._async.base import AsyncByteStream, AsyncHTTPTransport from ._async.connection_pool import AsyncConnectionPool from ._async.http_proxy import AsyncHTTPProxy -from ._bytestreams import AsyncIteratorByteStream, IteratorByteStream, PlainByteStream +from ._bytestreams import PlainByteStream from ._exceptions import ( CloseError, ConnectError, @@ -28,11 +28,9 @@ "AsyncConnectionPool", "AsyncHTTPProxy", "AsyncHTTPTransport", - "AsyncIteratorByteStream", "CloseError", "ConnectError", "ConnectTimeout", - "IteratorByteStream", "LocalProtocolError", "NetworkError", "PlainByteStream", diff --git a/httpcore/_async/base.py b/httpcore/_async/base.py index a8a65261..a448ea3c 100644 --- a/httpcore/_async/base.py +++ b/httpcore/_async/base.py @@ -1,6 +1,6 @@ import enum from types import TracebackType -from typing import AsyncContextManager, AsyncIterator, Tuple, Type +from typing import AsyncContextManager, AsyncIterable, AsyncIterator, Tuple, Type from .._types import URL, Headers, T @@ -60,9 +60,9 @@ def arequest( method: bytes, url: URL, headers: Headers = None, - stream: AsyncByteStream = None, + stream: AsyncIterable[bytes] = None, ext: dict = None, - ) -> AsyncContextManager[Tuple[int, Headers, AsyncByteStream, dict]]: + ) -> AsyncContextManager[Tuple[int, Headers, AsyncIterable[bytes], dict]]: """ The interface for sending a single HTTP request, and returning a response. @@ -73,7 +73,7 @@ def arequest( of (scheme, host, port, path). * **headers** - `Optional[List[Tuple[bytes, bytes]]]` - Any HTTP headers to send with the request. - * **stream** - `Optional[AsyncByteStream]` - The body of the HTTP request. + * **stream** - `Optional[AsyncIterable[bytes]]` - The body of the HTTP request. * **ext** - `Optional[dict]` - A dictionary of optional extensions. ** Returns:** @@ -83,7 +83,7 @@ def arequest( * **status_code** - `int` - The HTTP status code, such as `200`. * **headers** - `List[Tuple[bytes, bytes]]` - Any HTTP headers included on the response. 
- * **stream** - `AsyncByteStream` - The body of the HTTP response. + * **stream** - `AsyncIterable[bytes]` - The body of the HTTP response. * **ext** - `dict` - A dictionary of optional extensions. """ raise NotImplementedError() # pragma: nocover diff --git a/httpcore/_async/connection.py b/httpcore/_async/connection.py index a015b588..32158aea 100644 --- a/httpcore/_async/connection.py +++ b/httpcore/_async/connection.py @@ -1,17 +1,12 @@ from ssl import SSLContext -from typing import AsyncIterator, Optional, Tuple, cast +from typing import AsyncIterable, AsyncIterator, Optional, Tuple, cast from .._backends.auto import AsyncBackend, AsyncLock, AsyncSocketStream, AutoBackend from .._compat import asynccontextmanager from .._exceptions import ConnectError, ConnectTimeout from .._types import URL, Headers, Origin, TimeoutDict from .._utils import exponential_backoff, get_logger, url_to_origin -from .base import ( - AsyncByteStream, - AsyncHTTPTransport, - ConnectionState, - NewConnectionRequired, -) +from .base import AsyncHTTPTransport, ConnectionState, NewConnectionRequired from .http import AsyncBaseHTTPConnection logger = get_logger(__name__) @@ -78,9 +73,9 @@ async def arequest( method: bytes, url: URL, headers: Headers = None, - stream: AsyncByteStream = None, + stream: AsyncIterable[bytes] = None, ext: dict = None, - ) -> AsyncIterator[Tuple[int, Headers, AsyncByteStream, dict]]: + ) -> AsyncIterator[Tuple[int, Headers, AsyncIterable[bytes], dict]]: assert url_to_origin(url) == self.origin ext = {} if ext is None else ext timeout = cast(TimeoutDict, ext.get("timeout", {})) diff --git a/httpcore/_async/connection_pool.py b/httpcore/_async/connection_pool.py index c8a70153..7747a35e 100644 --- a/httpcore/_async/connection_pool.py +++ b/httpcore/_async/connection_pool.py @@ -1,6 +1,16 @@ import warnings from ssl import SSLContext -from typing import AsyncIterator, Dict, List, Optional, Set, Tuple, Union, cast +from typing import ( + AsyncIterable, + AsyncIterator, + Dict, + List, + Optional, + Set, + Tuple, + Union, + cast, +) from .._backends.auto import AsyncBackend, AsyncLock, AsyncSemaphore from .._backends.base import lookup_async_backend @@ -9,12 +19,7 @@ from .._threadlock import ThreadLock from .._types import URL, Headers, Origin, TimeoutDict from .._utils import get_logger, origin_to_url_string, url_to_origin -from .base import ( - AsyncByteStream, - AsyncHTTPTransport, - ConnectionState, - NewConnectionRequired, -) +from .base import AsyncHTTPTransport, ConnectionState, NewConnectionRequired from .connection import AsyncHTTPConnection logger = get_logger(__name__) @@ -142,9 +147,9 @@ async def arequest( method: bytes, url: URL, headers: Headers = None, - stream: AsyncByteStream = None, + stream: AsyncIterable[bytes] = None, ext: dict = None, - ) -> AsyncIterator[Tuple[int, Headers, AsyncByteStream, dict]]: + ) -> AsyncIterator[Tuple[int, Headers, AsyncIterable[bytes], dict]]: if url[0] not in (b"http", b"https"): scheme = url[0].decode("latin-1") raise UnsupportedProtocol(f"Unsupported URL protocol {scheme!r}") diff --git a/httpcore/_async/http11.py b/httpcore/_async/http11.py index b88e17ab..25aeb006 100644 --- a/httpcore/_async/http11.py +++ b/httpcore/_async/http11.py @@ -1,15 +1,15 @@ from ssl import SSLContext -from typing import AsyncIterator, List, Tuple, Union, cast +from typing import AsyncIterable, AsyncIterator, List, Tuple, Union, cast import h11 from .._backends.auto import AsyncSocketStream -from .._bytestreams import AsyncIteratorByteStream, PlainByteStream 
+from .._bytestreams import PlainByteStream from .._compat import asynccontextmanager from .._exceptions import LocalProtocolError, RemoteProtocolError, map_exceptions from .._types import URL, Headers, TimeoutDict from .._utils import get_logger -from .base import AsyncByteStream, ConnectionState +from .base import ConnectionState from .http import AsyncBaseHTTPConnection H11Event = Union[ @@ -54,9 +54,9 @@ async def arequest( method: bytes, url: URL, headers: Headers = None, - stream: AsyncByteStream = None, + stream: AsyncIterable[bytes] = None, ext: dict = None, - ) -> AsyncIterator[Tuple[int, Headers, AsyncByteStream, dict]]: + ) -> AsyncIterator[Tuple[int, Headers, AsyncIterable[bytes], dict]]: headers = [] if headers is None else headers stream = PlainByteStream(b"") if stream is None else stream ext = {} if ext is None else ext @@ -72,9 +72,7 @@ async def arequest( reason_phrase, headers, ) = await self._receive_response(timeout) - response_stream = AsyncIteratorByteStream( - aiterator=self._receive_response_data(timeout), - ) + response_stream = self._receive_response_data(timeout) ext = { "http_version": http_version.decode("ascii", errors="ignore"), "reason": reason_phrase.decode("ascii", errors="ignore"), @@ -104,7 +102,7 @@ async def _send_request( await self._send_event(event, timeout) async def _send_request_body( - self, stream: AsyncByteStream, timeout: TimeoutDict + self, stream: AsyncIterable[bytes], timeout: TimeoutDict ) -> None: """ Send the request body. diff --git a/httpcore/_async/http2.py b/httpcore/_async/http2.py index 43cf371a..be477190 100644 --- a/httpcore/_async/http2.py +++ b/httpcore/_async/http2.py @@ -1,5 +1,5 @@ from ssl import SSLContext -from typing import AsyncIterator, Dict, List, Tuple, cast +from typing import AsyncIterable, AsyncIterator, Dict, List, Tuple, cast import h2.connection import h2.events @@ -8,12 +8,12 @@ from h2.settings import SettingCodes, Settings from .._backends.auto import AsyncBackend, AsyncLock, AsyncSemaphore, AsyncSocketStream -from .._bytestreams import AsyncIteratorByteStream, PlainByteStream +from .._bytestreams import PlainByteStream from .._compat import asynccontextmanager from .._exceptions import PoolTimeout, RemoteProtocolError from .._types import URL, Headers, TimeoutDict from .._utils import get_logger -from .base import AsyncByteStream, ConnectionState, NewConnectionRequired +from .base import ConnectionState, NewConnectionRequired from .http import AsyncBaseHTTPConnection logger = get_logger(__name__) @@ -92,9 +92,9 @@ async def arequest( method: bytes, url: URL, headers: Headers = None, - stream: AsyncByteStream = None, + stream: AsyncIterable[bytes] = None, ext: dict = None, - ) -> AsyncIterator[Tuple[int, Headers, AsyncByteStream, dict]]: + ) -> AsyncIterator[Tuple[int, Headers, AsyncIterable[bytes], dict]]: ext = {} if ext is None else ext timeout = cast(TimeoutDict, ext.get("timeout", {})) @@ -281,9 +281,9 @@ async def arequest( method: bytes, url: URL, headers: Headers = None, - stream: AsyncByteStream = None, + stream: AsyncIterable[bytes] = None, ext: dict = None, - ) -> AsyncIterator[Tuple[int, Headers, AsyncByteStream, dict]]: + ) -> AsyncIterator[Tuple[int, Headers, AsyncIterable[bytes], dict]]: headers = [] if headers is None else [(k.lower(), v) for (k, v) in headers] stream = PlainByteStream(b"") if stream is None else stream ext = {} if ext is None else ext @@ -301,7 +301,7 @@ async def arequest( # Receive the response. 
status_code, headers = await self.receive_response(timeout) - response_stream = AsyncIteratorByteStream(aiterator=self.body_iter(timeout)) + response_stream = self.body_iter(timeout) ext = { "http_version": "HTTP/2", @@ -337,7 +337,9 @@ async def send_headers( await self.connection.send_headers(self.stream_id, headers, end_stream, timeout) - async def send_body(self, stream: AsyncByteStream, timeout: TimeoutDict) -> None: + async def send_body( + self, stream: AsyncIterable[bytes], timeout: TimeoutDict + ) -> None: async for data in stream: while data: max_flow = await self.connection.wait_for_outgoing_flow( diff --git a/httpcore/_async/http_proxy.py b/httpcore/_async/http_proxy.py index fc482880..031eb1fe 100644 --- a/httpcore/_async/http_proxy.py +++ b/httpcore/_async/http_proxy.py @@ -1,12 +1,11 @@ from http import HTTPStatus from ssl import SSLContext -from typing import AsyncIterator, Tuple, cast +from typing import AsyncIterable, AsyncIterator, Tuple, cast from .._compat import AsyncExitStack, asynccontextmanager from .._exceptions import ProxyError from .._types import URL, Headers, TimeoutDict from .._utils import get_logger, url_to_origin -from .base import AsyncByteStream from .connection import AsyncHTTPConnection from .connection_pool import AsyncConnectionPool @@ -94,9 +93,9 @@ async def arequest( method: bytes, url: URL, headers: Headers = None, - stream: AsyncByteStream = None, + stream: AsyncIterable[bytes] = None, ext: dict = None, - ) -> AsyncIterator[Tuple[int, Headers, AsyncByteStream, dict]]: + ) -> AsyncIterator[Tuple[int, Headers, AsyncIterable[bytes], dict]]: if self._keepalive_expiry is not None: await self._keepalive_sweep() @@ -135,9 +134,9 @@ async def _forward_request( method: bytes, url: URL, headers: Headers = None, - stream: AsyncByteStream = None, + stream: AsyncIterable[bytes] = None, ext: dict = None, - ) -> AsyncIterator[Tuple[int, Headers, AsyncByteStream, dict]]: + ) -> AsyncIterator[Tuple[int, Headers, AsyncIterable[bytes], dict]]: """ Forwarded proxy requests include the entire URL as the HTTP target, rather than just the path. @@ -181,9 +180,9 @@ async def _tunnel_request( method: bytes, url: URL, headers: Headers = None, - stream: AsyncByteStream = None, + stream: AsyncIterable[bytes] = None, ext: dict = None, - ) -> AsyncIterator[Tuple[int, Headers, AsyncByteStream, dict]]: + ) -> AsyncIterator[Tuple[int, Headers, AsyncIterable[bytes], dict]]: """ Tunnelled proxy requests require an initial CONNECT request to establish the connection, and then send regular requests. diff --git a/httpcore/_bytestreams.py b/httpcore/_bytestreams.py index aded338d..5eeba2ee 100644 --- a/httpcore/_bytestreams.py +++ b/httpcore/_bytestreams.py @@ -1,10 +1,7 @@ from typing import AsyncIterator, Iterator -from ._async.base import AsyncByteStream -from ._sync.base import SyncByteStream - -class PlainByteStream(AsyncByteStream, SyncByteStream): +class PlainByteStream: """ A concrete implementation for either sync or async byte streams. Just handles a plain byte string as the content of the stream. @@ -22,45 +19,3 @@ def __iter__(self) -> Iterator[bytes]: async def __aiter__(self) -> AsyncIterator[bytes]: yield self._content - - -class IteratorByteStream(SyncByteStream): - """ - A concrete implementation for sync byte streams. - Handles a byte iterator as the content of the stream. - - ``` - def generate_content(): - ... 
- - stream = httpcore.IteratorByteStream(generate_content()) - ``` - """ - - def __init__(self, iterator: Iterator[bytes]) -> None: - self._iterator = iterator - - def __iter__(self) -> Iterator[bytes]: - for chunk in self._iterator: - yield chunk - - -class AsyncIteratorByteStream(AsyncByteStream): - """ - A concrete implementation for async byte streams. - Handles an async byte iterator as the content of the stream. - - ``` - async def generate_content(): - ... - - stream = httpcore.AsyncIteratorByteStream(generate_content()) - ``` - """ - - def __init__(self, aiterator: AsyncIterator[bytes]) -> None: - self._aiterator = aiterator - - async def __aiter__(self) -> AsyncIterator[bytes]: - async for chunk in self._aiterator: - yield chunk diff --git a/httpcore/_sync/base.py b/httpcore/_sync/base.py index 7e9097fd..1a1b8ae8 100644 --- a/httpcore/_sync/base.py +++ b/httpcore/_sync/base.py @@ -1,6 +1,6 @@ import enum from types import TracebackType -from typing import ContextManager, Iterator, Tuple, Type +from typing import ContextManager, Iterable, Iterator, Tuple, Type from .._types import URL, Headers, T @@ -60,9 +60,9 @@ def request( method: bytes, url: URL, headers: Headers = None, - stream: SyncByteStream = None, + stream: Iterable[bytes] = None, ext: dict = None, - ) -> ContextManager[Tuple[int, Headers, SyncByteStream, dict]]: + ) -> ContextManager[Tuple[int, Headers, Iterable[bytes], dict]]: """ The interface for sending a single HTTP request, and returning a response. diff --git a/httpcore/_sync/connection.py b/httpcore/_sync/connection.py index 91cc9a80..7b8aa984 100644 --- a/httpcore/_sync/connection.py +++ b/httpcore/_sync/connection.py @@ -1,17 +1,12 @@ from ssl import SSLContext -from typing import Iterator, Optional, Tuple, cast +from typing import Iterable, Iterator, Optional, Tuple, cast from .._backends.sync import SyncBackend, SyncLock, SyncSocketStream, SyncBackend from .._compat import contextmanager from .._exceptions import ConnectError, ConnectTimeout from .._types import URL, Headers, Origin, TimeoutDict from .._utils import exponential_backoff, get_logger, url_to_origin -from .base import ( - SyncByteStream, - SyncHTTPTransport, - ConnectionState, - NewConnectionRequired, -) +from .base import SyncHTTPTransport, ConnectionState, NewConnectionRequired from .http import SyncBaseHTTPConnection logger = get_logger(__name__) @@ -78,9 +73,9 @@ def request( method: bytes, url: URL, headers: Headers = None, - stream: SyncByteStream = None, + stream: Iterable[bytes] = None, ext: dict = None, - ) -> Iterator[Tuple[int, Headers, SyncByteStream, dict]]: + ) -> Iterator[Tuple[int, Headers, Iterable[bytes], dict]]: assert url_to_origin(url) == self.origin ext = {} if ext is None else ext timeout = cast(TimeoutDict, ext.get("timeout", {})) diff --git a/httpcore/_sync/connection_pool.py b/httpcore/_sync/connection_pool.py index 45b8effe..129ca0e5 100644 --- a/httpcore/_sync/connection_pool.py +++ b/httpcore/_sync/connection_pool.py @@ -1,6 +1,16 @@ import warnings from ssl import SSLContext -from typing import Iterator, Dict, List, Optional, Set, Tuple, Union, cast +from typing import ( + Iterable, + Iterator, + Dict, + List, + Optional, + Set, + Tuple, + Union, + cast, +) from .._backends.sync import SyncBackend, SyncLock, SyncSemaphore from .._backends.base import lookup_sync_backend @@ -9,12 +19,7 @@ from .._threadlock import ThreadLock from .._types import URL, Headers, Origin, TimeoutDict from .._utils import get_logger, origin_to_url_string, url_to_origin -from .base import ( 
- SyncByteStream, - SyncHTTPTransport, - ConnectionState, - NewConnectionRequired, -) +from .base import SyncHTTPTransport, ConnectionState, NewConnectionRequired from .connection import SyncHTTPConnection logger = get_logger(__name__) @@ -142,9 +147,9 @@ def request( method: bytes, url: URL, headers: Headers = None, - stream: SyncByteStream = None, + stream: Iterable[bytes] = None, ext: dict = None, - ) -> Iterator[Tuple[int, Headers, SyncByteStream, dict]]: + ) -> Iterator[Tuple[int, Headers, Iterable[bytes], dict]]: if url[0] not in (b"http", b"https"): scheme = url[0].decode("latin-1") raise UnsupportedProtocol(f"Unsupported URL protocol {scheme!r}") diff --git a/httpcore/_sync/http11.py b/httpcore/_sync/http11.py index e7afd381..d940e070 100644 --- a/httpcore/_sync/http11.py +++ b/httpcore/_sync/http11.py @@ -1,15 +1,15 @@ from ssl import SSLContext -from typing import Iterator, List, Tuple, Union, cast +from typing import Iterable, Iterator, List, Tuple, Union, cast import h11 from .._backends.sync import SyncSocketStream -from .._bytestreams import IteratorByteStream, PlainByteStream +from .._bytestreams import PlainByteStream from .._compat import contextmanager from .._exceptions import LocalProtocolError, RemoteProtocolError, map_exceptions from .._types import URL, Headers, TimeoutDict from .._utils import get_logger -from .base import SyncByteStream, ConnectionState +from .base import ConnectionState from .http import SyncBaseHTTPConnection H11Event = Union[ @@ -54,9 +54,9 @@ def request( method: bytes, url: URL, headers: Headers = None, - stream: SyncByteStream = None, + stream: Iterable[bytes] = None, ext: dict = None, - ) -> Iterator[Tuple[int, Headers, SyncByteStream, dict]]: + ) -> Iterator[Tuple[int, Headers, Iterable[bytes], dict]]: headers = [] if headers is None else headers stream = PlainByteStream(b"") if stream is None else stream ext = {} if ext is None else ext @@ -72,9 +72,7 @@ def request( reason_phrase, headers, ) = self._receive_response(timeout) - response_stream = IteratorByteStream( - iterator=self._receive_response_data(timeout), - ) + response_stream = self._receive_response_data(timeout) ext = { "http_version": http_version.decode("ascii", errors="ignore"), "reason": reason_phrase.decode("ascii", errors="ignore"), @@ -104,7 +102,7 @@ def _send_request( self._send_event(event, timeout) def _send_request_body( - self, stream: SyncByteStream, timeout: TimeoutDict + self, stream: Iterable[bytes], timeout: TimeoutDict ) -> None: """ Send the request body. 
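Editor's note (not part of the patch): by this point in the series, `SyncHTTP11Connection.request()` is a `@contextmanager`-decorated generator, and the hunks above switch its request and response bodies to plain `Iterable[bytes]`. The sketch below shows that combined shape in isolation, using a hypothetical `MockTransport` rather than anything defined in this diff.

```python
from contextlib import contextmanager
from typing import Iterable, Iterator, List, Tuple

Headers = List[Tuple[bytes, bytes]]
URL = Tuple[bytes, bytes, int, bytes]


class MockTransport:
    """Hypothetical transport illustrating the context-managed request() shape."""

    @contextmanager
    def request(
        self,
        method: bytes,
        url: URL,
        headers: Headers = None,
        stream: Iterable[bytes] = None,
        ext: dict = None,
    ) -> Iterator[Tuple[int, Headers, Iterable[bytes], dict]]:
        # The response body is just a plain iterator of bytes chunks.
        response_stream = iter([b"Hello, ", b"world!"])
        try:
            # Yield the four-tuple; the caller consumes it inside `with`.
            yield 200, [(b"content-type", b"text/plain")], response_stream, {}
        finally:
            # Per-request cleanup (releasing the connection, closing the
            # response) runs here once the caller leaves the `with` block.
            pass
```

The `try`/`finally` around the `yield` is what takes the place of the earlier explicit `close()` calls and close callbacks.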
diff --git a/httpcore/_sync/http2.py b/httpcore/_sync/http2.py index 2c528e45..7e73a27d 100644 --- a/httpcore/_sync/http2.py +++ b/httpcore/_sync/http2.py @@ -1,5 +1,5 @@ from ssl import SSLContext -from typing import Iterator, Dict, List, Tuple, cast +from typing import Iterable, Iterator, Dict, List, Tuple, cast import h2.connection import h2.events @@ -8,12 +8,12 @@ from h2.settings import SettingCodes, Settings from .._backends.sync import SyncBackend, SyncLock, SyncSemaphore, SyncSocketStream -from .._bytestreams import IteratorByteStream, PlainByteStream +from .._bytestreams import PlainByteStream from .._compat import contextmanager from .._exceptions import PoolTimeout, RemoteProtocolError from .._types import URL, Headers, TimeoutDict from .._utils import get_logger -from .base import SyncByteStream, ConnectionState, NewConnectionRequired +from .base import ConnectionState, NewConnectionRequired from .http import SyncBaseHTTPConnection logger = get_logger(__name__) @@ -92,9 +92,9 @@ def request( method: bytes, url: URL, headers: Headers = None, - stream: SyncByteStream = None, + stream: Iterable[bytes] = None, ext: dict = None, - ) -> Iterator[Tuple[int, Headers, SyncByteStream, dict]]: + ) -> Iterator[Tuple[int, Headers, Iterable[bytes], dict]]: ext = {} if ext is None else ext timeout = cast(TimeoutDict, ext.get("timeout", {})) @@ -281,9 +281,9 @@ def request( method: bytes, url: URL, headers: Headers = None, - stream: SyncByteStream = None, + stream: Iterable[bytes] = None, ext: dict = None, - ) -> Iterator[Tuple[int, Headers, SyncByteStream, dict]]: + ) -> Iterator[Tuple[int, Headers, Iterable[bytes], dict]]: headers = [] if headers is None else [(k.lower(), v) for (k, v) in headers] stream = PlainByteStream(b"") if stream is None else stream ext = {} if ext is None else ext @@ -301,7 +301,7 @@ def request( # Receive the response. 
status_code, headers = self.receive_response(timeout) - response_stream = IteratorByteStream(iterator=self.body_iter(timeout)) + response_stream = self.body_iter(timeout) ext = { "http_version": "HTTP/2", @@ -337,7 +337,9 @@ def send_headers( self.connection.send_headers(self.stream_id, headers, end_stream, timeout) - def send_body(self, stream: SyncByteStream, timeout: TimeoutDict) -> None: + def send_body( + self, stream: Iterable[bytes], timeout: TimeoutDict + ) -> None: for data in stream: while data: max_flow = self.connection.wait_for_outgoing_flow( diff --git a/httpcore/_sync/http_proxy.py b/httpcore/_sync/http_proxy.py index 76be9647..b3823955 100644 --- a/httpcore/_sync/http_proxy.py +++ b/httpcore/_sync/http_proxy.py @@ -1,12 +1,11 @@ from http import HTTPStatus from ssl import SSLContext -from typing import Iterator, Tuple, cast +from typing import Iterable, Iterator, Tuple, cast from .._compat import ExitStack, contextmanager from .._exceptions import ProxyError from .._types import URL, Headers, TimeoutDict from .._utils import get_logger, url_to_origin -from .base import SyncByteStream from .connection import SyncHTTPConnection from .connection_pool import SyncConnectionPool @@ -94,9 +93,9 @@ def request( method: bytes, url: URL, headers: Headers = None, - stream: SyncByteStream = None, + stream: Iterable[bytes] = None, ext: dict = None, - ) -> Iterator[Tuple[int, Headers, SyncByteStream, dict]]: + ) -> Iterator[Tuple[int, Headers, Iterable[bytes], dict]]: if self._keepalive_expiry is not None: self._keepalive_sweep() @@ -135,9 +134,9 @@ def _forward_request( method: bytes, url: URL, headers: Headers = None, - stream: SyncByteStream = None, + stream: Iterable[bytes] = None, ext: dict = None, - ) -> Iterator[Tuple[int, Headers, SyncByteStream, dict]]: + ) -> Iterator[Tuple[int, Headers, Iterable[bytes], dict]]: """ Forwarded proxy requests include the entire URL as the HTTP target, rather than just the path. @@ -181,9 +180,9 @@ def _tunnel_request( method: bytes, url: URL, headers: Headers = None, - stream: SyncByteStream = None, + stream: Iterable[bytes] = None, ext: dict = None, - ) -> Iterator[Tuple[int, Headers, SyncByteStream, dict]]: + ) -> Iterator[Tuple[int, Headers, Iterable[bytes], dict]]: """ Tunnelled proxy requests require an initial CONNECT request to establish the connection, and then send regular requests. 
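Editor's note (not part of the patch): with `SyncByteStream` dropped from the signatures above, request bodies are accepted and response bodies returned as plain iterables of bytes. A minimal usage sketch under that assumption follows; the host, headers and payload are illustrative only, not taken from this diff.

```python
from typing import Iterator

import httpcore


def upload_body() -> Iterator[bytes]:
    # A plain generator works as the request body: no byte-stream wrapper needed.
    yield b"field=value"


with httpcore.SyncConnectionPool() as http:
    with http.request(
        method=b"POST",
        url=(b"https", b"httpbin.org", 443, b"/post"),
        headers=[(b"host", b"httpbin.org"), (b"content-length", b"11")],
        stream=upload_body(),
    ) as (status_code, headers, stream, ext):
        # The response stream is likewise consumed as a plain iterable of bytes.
        body = b"".join(stream)

print(status_code, len(body))
```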
diff --git a/tests/async_tests/test_connection_pool.py b/tests/async_tests/test_connection_pool.py index 0fa003f3..4052d985 100644 --- a/tests/async_tests/test_connection_pool.py +++ b/tests/async_tests/test_connection_pool.py @@ -1,4 +1,4 @@ -from typing import AsyncIterator, Tuple +from typing import AsyncIterable, AsyncIterator, Tuple import pytest @@ -22,9 +22,9 @@ async def arequest( method: bytes, url: URL, headers: Headers = None, - stream: httpcore.AsyncByteStream = None, + stream: AsyncIterable[bytes] = None, ext: dict = None, - ) -> AsyncIterator[Tuple[int, Headers, httpcore.AsyncByteStream, dict]]: + ) -> AsyncIterator[Tuple[int, Headers, AsyncIterable[bytes], dict]]: self.state = ConnectionState.ACTIVE self.stream_count += 1 @@ -36,7 +36,7 @@ async def on_close(): async def aiterator() -> AsyncIterator[bytes]: yield b"" - stream = httpcore.AsyncIteratorByteStream(aiterator=aiterator()) + stream = aiterator() try: yield 200, [], stream, {} @@ -66,7 +66,7 @@ def _create_connection(self, **kwargs): return MockConnection(self.http_version) -async def read_body(stream: httpcore.AsyncByteStream) -> bytes: +async def read_body(stream: AsyncIterable[bytes]) -> bytes: return b"".join([chunk async for chunk in stream]) diff --git a/tests/async_tests/test_interfaces.py b/tests/async_tests/test_interfaces.py index 368a275d..adb135f3 100644 --- a/tests/async_tests/test_interfaces.py +++ b/tests/async_tests/test_interfaces.py @@ -1,5 +1,6 @@ import platform from functools import partial +from typing import AsyncIterable import pytest @@ -15,7 +16,7 @@ def backend(request): return request.param -async def read_body(stream: httpcore.AsyncByteStream) -> bytes: +async def read_body(stream: AsyncIterable[bytes]) -> bytes: return b"".join([chunk async for chunk in stream]) diff --git a/tests/async_tests/test_retries.py b/tests/async_tests/test_retries.py index 022e441d..6478ce19 100644 --- a/tests/async_tests/test_retries.py +++ b/tests/async_tests/test_retries.py @@ -1,6 +1,6 @@ import queue import time -from typing import Any, List, Optional +from typing import Any, AsyncIterable, List, Optional import pytest @@ -32,7 +32,7 @@ async def open_tcp_stream(self, *args: Any, **kwargs: Any) -> AsyncSocketStream: return await super().open_tcp_stream(*args, **kwargs) -async def read_body(stream: httpcore.AsyncByteStream) -> bytes: +async def read_body(stream: AsyncIterable[bytes]) -> bytes: return b"".join([chunk async for chunk in stream]) diff --git a/tests/sync_tests/test_connection_pool.py b/tests/sync_tests/test_connection_pool.py index 0c4937e4..d3d605fa 100644 --- a/tests/sync_tests/test_connection_pool.py +++ b/tests/sync_tests/test_connection_pool.py @@ -1,4 +1,4 @@ -from typing import Iterator, Tuple +from typing import Iterable, Iterator, Tuple import pytest @@ -22,9 +22,9 @@ def request( method: bytes, url: URL, headers: Headers = None, - stream: httpcore.SyncByteStream = None, + stream: Iterable[bytes] = None, ext: dict = None, - ) -> Iterator[Tuple[int, Headers, httpcore.SyncByteStream, dict]]: + ) -> Iterator[Tuple[int, Headers, Iterable[bytes], dict]]: self.state = ConnectionState.ACTIVE self.stream_count += 1 @@ -36,7 +36,7 @@ def on_close(): def iterator() -> Iterator[bytes]: yield b"" - stream = httpcore.IteratorByteStream(iterator=iterator()) + stream = iterator() try: yield 200, [], stream, {} @@ -66,7 +66,7 @@ def _create_connection(self, **kwargs): return MockConnection(self.http_version) -def read_body(stream: httpcore.SyncByteStream) -> bytes: +def read_body(stream: 
Iterable[bytes]) -> bytes: return b"".join([chunk for chunk in stream]) diff --git a/tests/sync_tests/test_interfaces.py b/tests/sync_tests/test_interfaces.py index 20538137..05aa4462 100644 --- a/tests/sync_tests/test_interfaces.py +++ b/tests/sync_tests/test_interfaces.py @@ -1,5 +1,6 @@ import platform from functools import partial +from typing import Iterable import pytest @@ -15,7 +16,7 @@ def backend(request): return request.param -def read_body(stream: httpcore.SyncByteStream) -> bytes: +def read_body(stream: Iterable[bytes]) -> bytes: return b"".join([chunk for chunk in stream]) diff --git a/tests/sync_tests/test_retries.py b/tests/sync_tests/test_retries.py index c5ea5730..961d4b66 100644 --- a/tests/sync_tests/test_retries.py +++ b/tests/sync_tests/test_retries.py @@ -1,6 +1,6 @@ import queue import time -from typing import Any, List, Optional +from typing import Any, Iterable, List, Optional import pytest @@ -32,7 +32,7 @@ def open_tcp_stream(self, *args: Any, **kwargs: Any) -> SyncSocketStream: return super().open_tcp_stream(*args, **kwargs) -def read_body(stream: httpcore.SyncByteStream) -> bytes: +def read_body(stream: Iterable[bytes]) -> bytes: return b"".join([chunk for chunk in stream]) diff --git a/unasync.py b/unasync.py index b6f19cc9..03983233 100755 --- a/unasync.py +++ b/unasync.py @@ -6,6 +6,7 @@ SUBS = [ ('AsyncIteratorByteStream', 'IteratorByteStream'), ('AsyncIterator', 'Iterator'), + ('AsyncIterable', 'Iterable'), ('asynccontextmanager', 'contextmanager'), ('AsyncContextManager', 'ContextManager'), ('AsyncExitStack', 'ExitStack'), From 9b0cdb06c598df61f4a6f8d95919ad62c7eb1481 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 18 Nov 2020 11:14:53 +0000 Subject: [PATCH 14/17] Drop SyncByteStream, AsyncByteStream in favour of Iterable[bytes] --- docs/api.md | 33 ++++++++------------------------- httpcore/__init__.py | 6 ++---- httpcore/_async/base.py | 17 +---------------- httpcore/_sync/base.py | 21 +++------------------ unasync.py | 1 - 5 files changed, 14 insertions(+), 64 deletions(-) diff --git a/docs/api.md b/docs/api.md index e9505af7..40b075f2 100644 --- a/docs/api.md +++ b/docs/api.md @@ -2,55 +2,38 @@ ## Async API Overview -The `AsyncHTTPTransport` and `AsyncByteStream` classes provide the base -interface which transport classes need to implement. +The `AsyncHTTPTransport` class provides the base interface which transport classes need to implement. ::: httpcore.AsyncHTTPTransport :docstring: :members: arequest aclose -::: httpcore.AsyncByteStream - :docstring: - :members: __aiter__ - The `AsyncConnectionPool` class is a concrete implementation of `AsyncHTTPTransport`. ::: httpcore.AsyncConnectionPool :docstring: - -The `PlainByteStream` and `AsyncIteratorByteStream` classes are concrete implementations of `AsyncByteStream`. - -::: httpcore.PlainByteStream - :docstring: - -::: httpcore.AsyncIteratorByteStream - :docstring: - --- ## Sync API Overview -The `SyncHTTPTransport` and `SyncByteStream` classes provide the base -interface which transport classes need to implement. +The `SyncHTTPTransport` class provides the base interface which transport classes need to implement. ::: httpcore.SyncHTTPTransport :docstring: :members: request close -::: httpcore.SyncByteStream - :docstring: - :members: __iter__ - The `SyncConnectionPool` class is a concrete implementation of `SyncHTTPTransport`. ::: httpcore.SyncConnectionPool :docstring: -The `PlainByteStream` and `IteratorByteStream` classes are concrete implementations of `SyncByteStream`. 
+--- + +## Utilities -::: httpcore.PlainByteStream - :docstring: +The `PlainByteStream` can be used to return a bytestring with both bytes iterable +and async bytes iterable iterfaces. -::: httpcore.IteratorByteStream +::: httpcore.PlainByteStream :docstring: diff --git a/httpcore/__init__.py b/httpcore/__init__.py index f5f2d735..52b7124c 100644 --- a/httpcore/__init__.py +++ b/httpcore/__init__.py @@ -1,4 +1,4 @@ -from ._async.base import AsyncByteStream, AsyncHTTPTransport +from ._async.base import AsyncHTTPTransport from ._async.connection_pool import AsyncConnectionPool from ._async.http_proxy import AsyncHTTPProxy from ._bytestreams import PlainByteStream @@ -19,12 +19,11 @@ WriteError, WriteTimeout, ) -from ._sync.base import SyncByteStream, SyncHTTPTransport +from ._sync.base import SyncHTTPTransport from ._sync.connection_pool import SyncConnectionPool from ._sync.http_proxy import SyncHTTPProxy __all__ = [ - "AsyncByteStream", "AsyncConnectionPool", "AsyncHTTPProxy", "AsyncHTTPTransport", @@ -40,7 +39,6 @@ "ReadError", "ReadTimeout", "RemoteProtocolError", - "SyncByteStream", "SyncConnectionPool", "SyncHTTPProxy", "SyncHTTPTransport", diff --git a/httpcore/_async/base.py b/httpcore/_async/base.py index a448ea3c..024af43d 100644 --- a/httpcore/_async/base.py +++ b/httpcore/_async/base.py @@ -1,6 +1,6 @@ import enum from types import TracebackType -from typing import AsyncContextManager, AsyncIterable, AsyncIterator, Tuple, Type +from typing import AsyncContextManager, AsyncIterable, Tuple, Type from .._types import URL, Headers, T @@ -32,21 +32,6 @@ class ConnectionState(enum.IntEnum): CLOSED = 5 # Connection closed. -class AsyncByteStream: - """ - The base interface for request and response bodies. - - Concrete implementations should subclass this class, and implement - the `\\__aiter__` method, and optionally the `aclose` method. - """ - - async def __aiter__(self) -> AsyncIterator[bytes]: - """ - Yield bytes representing the request or response body. - """ - yield b"" # pragma: nocover - - class AsyncHTTPTransport: """ The base interface for sending HTTP requests. diff --git a/httpcore/_sync/base.py b/httpcore/_sync/base.py index 1a1b8ae8..5e6e4fca 100644 --- a/httpcore/_sync/base.py +++ b/httpcore/_sync/base.py @@ -1,6 +1,6 @@ import enum from types import TracebackType -from typing import ContextManager, Iterable, Iterator, Tuple, Type +from typing import ContextManager, Iterable, Tuple, Type from .._types import URL, Headers, T @@ -32,21 +32,6 @@ class ConnectionState(enum.IntEnum): CLOSED = 5 # Connection closed. -class SyncByteStream: - """ - The base interface for request and response bodies. - - Concrete implementations should subclass this class, and implement - the `\\__iter__` method, and optionally the `close` method. - """ - - def __iter__(self) -> Iterator[bytes]: - """ - Yield bytes representing the request or response body. - """ - yield b"" # pragma: nocover - - class SyncHTTPTransport: """ The base interface for sending HTTP requests. @@ -73,7 +58,7 @@ def request( of (scheme, host, port, path). * **headers** - `Optional[List[Tuple[bytes, bytes]]]` - Any HTTP headers to send with the request. - * **stream** - `Optional[SyncByteStream]` - The body of the HTTP request. + * **stream** - `Optional[Iterable[bytes]]` - The body of the HTTP request. * **ext** - `Optional[dict]` - A dictionary of optional extensions. ** Returns:** @@ -83,7 +68,7 @@ def request( * **status_code** - `int` - The HTTP status code, such as `200`. 
* **headers** - `List[Tuple[bytes, bytes]]` - Any HTTP headers included on the response. - * **stream** - `SyncByteStream` - The body of the HTTP response. + * **stream** - `Iterable[bytes]` - The body of the HTTP response. * **ext** - `dict` - A dictionary of optional extensions. """ raise NotImplementedError() # pragma: nocover diff --git a/unasync.py b/unasync.py index 03983233..6617bb22 100755 --- a/unasync.py +++ b/unasync.py @@ -4,7 +4,6 @@ import sys SUBS = [ - ('AsyncIteratorByteStream', 'IteratorByteStream'), ('AsyncIterator', 'Iterator'), ('AsyncIterable', 'Iterable'), ('asynccontextmanager', 'contextmanager'), From f9563cf75d65325f7ab6904c5481a05577b8cc3e Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 18 Nov 2020 11:28:45 +0000 Subject: [PATCH 15/17] Neater max_streams_semahore now that we have context-managed flow, rather than close callbacks --- httpcore/_async/http2.py | 27 ++++++++++----------------- httpcore/_backends/base.py | 11 +++++++++++ httpcore/_backends/sync.py | 11 +++++++++++ httpcore/_sync/http2.py | 27 ++++++++++----------------- 4 files changed, 42 insertions(+), 34 deletions(-) diff --git a/httpcore/_async/http2.py b/httpcore/_async/http2.py index be477190..786ff9a0 100644 --- a/httpcore/_async/http2.py +++ b/httpcore/_async/http2.py @@ -105,8 +105,7 @@ async def arequest( await self.send_connection_init(timeout) self.sent_connection_init = True - await self.max_streams_semaphore.acquire() - try: + async with self.max_streams_semaphore: try: stream_id = self.h2_state.get_next_available_stream_id() except NoAvailableStreamIDError: @@ -122,9 +121,6 @@ async def arequest( method, url, headers, stream, ext ) as response: yield response - except Exception: # noqa: PIE786 - await self.max_streams_semaphore.release() - raise async def send_connection_init(self, timeout: TimeoutDict) -> None: """ @@ -256,18 +252,15 @@ async def acknowledge_received_data( await self.socket.write(data_to_send, timeout) async def close_stream(self, stream_id: int) -> None: - try: - logger.trace("close_stream stream_id=%r", stream_id) - del self.streams[stream_id] - del self.events[stream_id] - - if not self.streams: - if self.state == ConnectionState.ACTIVE: - self.state = ConnectionState.IDLE - elif self.state == ConnectionState.FULL: - await self.aclose() - finally: - await self.max_streams_semaphore.release() + logger.trace("close_stream stream_id=%r", stream_id) + del self.streams[stream_id] + del self.events[stream_id] + + if not self.streams: + if self.state == ConnectionState.ACTIVE: + self.state = ConnectionState.IDLE + elif self.state == ConnectionState.FULL: + await self.aclose() class AsyncHTTP2Stream: diff --git a/httpcore/_backends/base.py b/httpcore/_backends/base.py index 1ca6e31b..a3027f07 100644 --- a/httpcore/_backends/base.py +++ b/httpcore/_backends/base.py @@ -96,6 +96,17 @@ class AsyncSemaphore: Abstracts away any asyncio-specific interfaces. 
""" + async def __aenter__(self) -> None: + await self.acquire() + + async def __aexit__( + self, + exc_type: Type[BaseException] = None, + exc_value: BaseException = None, + traceback: TracebackType = None, + ) -> None: + await self.release() + async def acquire(self, timeout: float = None) -> None: raise NotImplementedError() # pragma: no cover diff --git a/httpcore/_backends/sync.py b/httpcore/_backends/sync.py index 25e38ed0..92fde403 100644 --- a/httpcore/_backends/sync.py +++ b/httpcore/_backends/sync.py @@ -109,6 +109,17 @@ def __init__(self, max_value: int, exc_class: type) -> None: self.exc_class = exc_class self._semaphore = threading.Semaphore(max_value) + def __enter__(self) -> None: + self.acquire() + + def __exit__( + self, + exc_type: Type[BaseException] = None, + exc_value: BaseException = None, + traceback: TracebackType = None, + ) -> None: + self.release() + def acquire(self, timeout: float = None) -> None: if not self._semaphore.acquire(timeout=timeout): # type: ignore raise self.exc_class() diff --git a/httpcore/_sync/http2.py b/httpcore/_sync/http2.py index 7e73a27d..795aa89d 100644 --- a/httpcore/_sync/http2.py +++ b/httpcore/_sync/http2.py @@ -105,8 +105,7 @@ def request( self.send_connection_init(timeout) self.sent_connection_init = True - self.max_streams_semaphore.acquire() - try: + with self.max_streams_semaphore: try: stream_id = self.h2_state.get_next_available_stream_id() except NoAvailableStreamIDError: @@ -122,9 +121,6 @@ def request( method, url, headers, stream, ext ) as response: yield response - except Exception: # noqa: PIE786 - self.max_streams_semaphore.release() - raise def send_connection_init(self, timeout: TimeoutDict) -> None: """ @@ -256,18 +252,15 @@ def acknowledge_received_data( self.socket.write(data_to_send, timeout) def close_stream(self, stream_id: int) -> None: - try: - logger.trace("close_stream stream_id=%r", stream_id) - del self.streams[stream_id] - del self.events[stream_id] - - if not self.streams: - if self.state == ConnectionState.ACTIVE: - self.state = ConnectionState.IDLE - elif self.state == ConnectionState.FULL: - self.close() - finally: - self.max_streams_semaphore.release() + logger.trace("close_stream stream_id=%r", stream_id) + del self.streams[stream_id] + del self.events[stream_id] + + if not self.streams: + if self.state == ConnectionState.ACTIVE: + self.state = ConnectionState.IDLE + elif self.state == ConnectionState.FULL: + self.close() class SyncHTTP2Stream: From d7e9b443e74b9fe6e2d9342729a0f620adb3d33d Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Mon, 30 Nov 2020 14:29:09 +0100 Subject: [PATCH 16/17] Update new tests from master --- tests/async_tests/test_interfaces.py | 9 ++++++--- tests/sync_tests/test_interfaces.py | 9 ++++++--- tests/test_threadsafety.py | 13 ++++++------- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/tests/async_tests/test_interfaces.py b/tests/async_tests/test_interfaces.py index bb4b5b90..76155095 100644 --- a/tests/async_tests/test_interfaces.py +++ b/tests/async_tests/test_interfaces.py @@ -483,7 +483,8 @@ async def test_cannot_connect_uds(backend: str) -> None: url = (b"http", b"localhost", None, b"/") async with httpcore.AsyncConnectionPool(backend=backend, uds=uds) as http: with pytest.raises(httpcore.ConnectError): - await http.arequest(method, url) + async with http.arequest(method, url): + pass # pragma: no cover @pytest.mark.skipif( @@ -501,7 +502,8 @@ async def test_connection_timeout_tcp(backend: str, server: Server) -> None: async with 
httpcore.AsyncConnectionPool(backend=backend) as http: with pytest.raises(httpcore.ConnectTimeout): - await http.arequest(method, url, headers, ext=ext) + async with http.arequest(method, url, headers, ext=ext): + pass # pragma: no cover @pytest.mark.skipif( @@ -521,4 +523,5 @@ async def test_connection_timeout_uds( async with httpcore.AsyncConnectionPool(uds=uds, backend=backend) as http: with pytest.raises(httpcore.ConnectTimeout): - await http.arequest(method, url, headers, ext=ext) + async with http.arequest(method, url, headers, ext=ext): + pass # pragma: no cover diff --git a/tests/sync_tests/test_interfaces.py b/tests/sync_tests/test_interfaces.py index 58b677ef..9865108f 100644 --- a/tests/sync_tests/test_interfaces.py +++ b/tests/sync_tests/test_interfaces.py @@ -483,7 +483,8 @@ def test_cannot_connect_uds(backend: str) -> None: url = (b"http", b"localhost", None, b"/") with httpcore.SyncConnectionPool(backend=backend, uds=uds) as http: with pytest.raises(httpcore.ConnectError): - http.request(method, url) + with http.request(method, url): + pass # pragma: no cover @pytest.mark.skipif( @@ -501,7 +502,8 @@ def test_connection_timeout_tcp(backend: str, server: Server) -> None: with httpcore.SyncConnectionPool(backend=backend) as http: with pytest.raises(httpcore.ConnectTimeout): - http.request(method, url, headers, ext=ext) + with http.request(method, url, headers, ext=ext): + pass # pragma: no cover @pytest.mark.skipif( @@ -521,4 +523,5 @@ def test_connection_timeout_uds( with httpcore.SyncConnectionPool(uds=uds, backend=backend) as http: with pytest.raises(httpcore.ConnectTimeout): - http.request(method, url, headers, ext=ext) + with http.request(method, url, headers, ext=ext): + pass # pragma: no cover diff --git a/tests/test_threadsafety.py b/tests/test_threadsafety.py index 81cdd95f..d491833f 100644 --- a/tests/test_threadsafety.py +++ b/tests/test_threadsafety.py @@ -1,4 +1,5 @@ import concurrent.futures +from typing import Iterable import pytest @@ -7,11 +8,8 @@ from .utils import Server -def read_body(stream: httpcore.SyncByteStream) -> bytes: - try: - return b"".join(chunk for chunk in stream) - finally: - stream.close() +def read_body(stream: Iterable[bytes]) -> bytes: + return b"".join(chunk for chunk in stream) @pytest.mark.parametrize( @@ -30,8 +28,9 @@ def request(http: httpcore.SyncHTTPTransport) -> int: method = b"GET" url = (b"http", *server.netloc, b"/") headers = [server.host_header] - status_code, headers, stream, ext = http.request(method, url, headers) - read_body(stream) + with http.request(method, url, headers) as response: + status_code, _, stream, _ = response + read_body(stream) return status_code with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: From 824d4caa5629091ed4524da47cfeec751db4f773 Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Mon, 30 Nov 2020 14:46:31 +0100 Subject: [PATCH 17/17] Coverage --- tests/async_tests/test_interfaces.py | 2 +- tests/async_tests/test_retries.py | 10 +++++----- tests/sync_tests/test_interfaces.py | 2 +- tests/sync_tests/test_retries.py | 10 +++++----- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/async_tests/test_interfaces.py b/tests/async_tests/test_interfaces.py index 76155095..cea70266 100644 --- a/tests/async_tests/test_interfaces.py +++ b/tests/async_tests/test_interfaces.py @@ -470,7 +470,7 @@ async def test_cannot_connect_tcp(backend: str, url) -> None: method = b"GET" with pytest.raises(httpcore.ConnectError): async with http.arequest(method, url) as _: - pass + 
pass # pragma: no cover @pytest.mark.anyio diff --git a/tests/async_tests/test_retries.py b/tests/async_tests/test_retries.py index 6478ce19..35380493 100644 --- a/tests/async_tests/test_retries.py +++ b/tests/async_tests/test_retries.py @@ -58,11 +58,11 @@ async def test_no_retries(server: Server) -> None: with pytest.raises(httpcore.ConnectTimeout): async with http.arequest(method, url, headers) as response: - pass + pass # pragma: no cover with pytest.raises(httpcore.ConnectError): async with http.arequest(method, url, headers) as response: - pass + pass # pragma: no cover @pytest.mark.anyio @@ -118,11 +118,11 @@ async def test_retries_enabled(server: Server) -> None: backend.push(httpcore.ReadTimeout(), httpcore.NetworkError()) with pytest.raises(httpcore.ReadTimeout): async with http.arequest(method, url, headers) as response: - pass + pass # pragma: no cover with pytest.raises(httpcore.NetworkError): async with http.arequest(method, url, headers) as response: - pass + pass # pragma: no cover @pytest.mark.anyio @@ -149,4 +149,4 @@ async def test_retries_exceeded(server: Server) -> None: backend.push(httpcore.ConnectError(), httpcore.ConnectTimeout()) with pytest.raises(httpcore.ConnectTimeout): async with http.arequest(method, url, headers) as response: - pass + pass # pragma: no cover diff --git a/tests/sync_tests/test_interfaces.py b/tests/sync_tests/test_interfaces.py index 9865108f..89b2de77 100644 --- a/tests/sync_tests/test_interfaces.py +++ b/tests/sync_tests/test_interfaces.py @@ -470,7 +470,7 @@ def test_cannot_connect_tcp(backend: str, url) -> None: method = b"GET" with pytest.raises(httpcore.ConnectError): with http.request(method, url) as _: - pass + pass # pragma: no cover diff --git a/tests/sync_tests/test_retries.py b/tests/sync_tests/test_retries.py index 961d4b66..c1deaa93 100644 --- a/tests/sync_tests/test_retries.py +++ b/tests/sync_tests/test_retries.py @@ -58,11 +58,11 @@ def test_no_retries(server: Server) -> None: with pytest.raises(httpcore.ConnectTimeout): with http.request(method, url, headers) as response: - pass + pass # pragma: no cover with pytest.raises(httpcore.ConnectError): with http.request(method, url, headers) as response: - pass + pass # pragma: no cover @@ -118,11 +118,11 @@ def test_retries_enabled(server: Server) -> None: backend.push(httpcore.ReadTimeout(), httpcore.NetworkError()) with pytest.raises(httpcore.ReadTimeout): with http.request(method, url, headers) as response: - pass + pass # pragma: no cover with pytest.raises(httpcore.NetworkError): with http.request(method, url, headers) as response: - pass + pass # pragma: no cover @@ -149,4 +149,4 @@ def test_retries_exceeded(server: Server) -> None: backend.push(httpcore.ConnectError(), httpcore.ConnectTimeout()) with pytest.raises(httpcore.ConnectTimeout): with http.request(method, url, headers) as response: - pass + pass # pragma: no cover
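
Editor's note (not part of the patch series): the test updates above all follow the same shape. Once `request()` / `arequest()` is a context manager, connection failures surface while entering the context, so the tests nest the `with` block inside `pytest.raises(...)` and mark the unreachable body with `pragma: no cover`. A standalone sketch of that pattern follows; the non-routable address and timeout value are assumptions chosen for illustration, not taken from the test suite.

```python
import pytest

import httpcore


def test_connect_failure_raised_on_enter() -> None:
    # 192.0.2.1 (TEST-NET-1) is non-routable, so the connect attempt fails.
    method = b"GET"
    url = (b"http", b"192.0.2.1", 80, b"/")
    ext = {"timeout": {"connect": 0.01}}

    with httpcore.SyncConnectionPool() as http:
        with pytest.raises((httpcore.ConnectError, httpcore.ConnectTimeout)):
            with http.request(method, url, ext=ext):
                pass  # pragma: no cover - never reached on connect failure
```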