From 26e081962995a1640ac02c7602c20bfe6e84ada8 Mon Sep 17 00:00:00 2001
From: Johannes
Date: Wed, 24 Mar 2021 20:33:22 +0000
Subject: [PATCH] Implement httpx 0.18 changes

---
 gen_sync.py                        |  6 +++--
 httpx_caching/_async/_transport.py | 37 ++++++++++++++++--------------
 httpx_caching/_models.py           |  9 ++++----
 httpx_caching/_serializer.py       | 15 ++++++------
 httpx_caching/_sync/_transport.py  | 35 +++++++++++++++-------------
 httpx_caching/_utils.py            | 15 ++++++++----
 httpx_caching/_wrapper.py          |  4 ++--
 requirements.txt                   |  2 +-
 tests/_async/test_etag.py          | 12 ++++++++--
 tests/conftest.py                  |  2 +-
 tests/test_serialization.py        |  2 +-
 11 files changed, 80 insertions(+), 59 deletions(-)

diff --git a/gen_sync.py b/gen_sync.py
index 0182b43..727df87 100644
--- a/gen_sync.py
+++ b/gen_sync.py
@@ -18,19 +18,21 @@
         fromdir="/_async/",
         todir="/_sync/",
         additional_replacements={
+            "AsyncBaseTransport": "BaseTransport",
             "async_client": "client",
             "AsyncClient": "Client",
             "make_async_client": "make_client",
             "asyncio": "sync",
             "aclose": "close",
+            '"aclose"': '"close"',
             "aread": "read",
             "arun": "run",
             "aio_handler": "io_handler",
-            "arequest": "request",
+            "handle_async_request": "handle_request",
+            '"handle_async_request"': '"handle_request"',
             "aget": "get",
             "aset": "set",
             "adelete": "delete",
-            '"arequest"': '"request"',
         }
     ),
 ],
diff --git a/httpx_caching/_async/_transport.py b/httpx_caching/_async/_transport.py
index d81205d..559b1fd 100644
--- a/httpx_caching/_async/_transport.py
+++ b/httpx_caching/_async/_transport.py
@@ -1,4 +1,4 @@
-from typing import Iterable, Optional, Tuple
+from typing import AsyncIterable, Iterable, Optional, Tuple
 
 import httpcore
 import httpx
@@ -12,12 +12,12 @@
 from httpx_caching._utils import ByteStreamWrapper, request_to_raw
 
 
-class AsyncCachingTransport(httpcore.AsyncHTTPTransport):
+class AsyncCachingTransport(httpx.AsyncBaseTransport):
     invalidating_methods = {"PUT", "PATCH", "DELETE"}
 
     def __init__(
         self,
-        transport: httpcore.AsyncHTTPTransport,
+        transport: httpx.AsyncBaseTransport,
         cache: AsyncDictCache = None,
         cache_etags: bool = True,
         heuristic: BaseHeuristic = None,
@@ -38,14 +38,14 @@ def __init__(
         self.cacheable_status_codes = cacheable_status_codes
         self.cache_etags = cache_etags
 
-    async def arequest(
+    async def handle_async_request(
         self,
         method: bytes,
         url: RawURL,
-        headers: RawHeaders = None,
-        stream: httpcore.AsyncByteStream = None,
-        ext: dict = None,
-    ) -> Tuple[int, RawHeaders, httpcore.AsyncByteStream, dict]:
+        headers: RawHeaders,
+        stream: AsyncIterable[bytes],
+        extensions: dict,
+    ) -> Tuple[int, RawHeaders, AsyncIterable[bytes], dict]:
 
         request = httpx.Request(
             method=method,
@@ -64,7 +64,7 @@ async def arequest(
 
         response, source = await caching_protocol.arun(self.aio_handler)
 
-        response.ext["from_cache"] = source == Source.CACHE
+        response.extensions["from_cache"] = source == Source.CACHE
         return response.to_raw()
 
     @multimethod
@@ -84,28 +84,27 @@ async def _io_cache_delete(self, action: protocol.CacheDelete) -> None:
 
     @aio_handler.register
     async def _io_cache_set(self, action: protocol.CacheSet) -> Optional[Response]:
-        stream = action.response.stream
-        # TODO: we can probably just get rid of deferred?
-        if action.deferred and not isinstance(stream, httpcore.PlainByteStream):
+        if action.deferred:
+            # This is a response with a body, so we need to wait for it to be read before we can cache it
             return self.wrap_response_stream(
                 action.key, action.response, action.vary_header_values
             )
         else:
             stream = action.response.stream
             assert isinstance(stream, httpcore.PlainByteStream)
-            response_body = stream._content
+            # TODO: Are we needlessly recaching the body here? Is this just a header change?
             await self.cache.aset(
                 action.key,
                 action.response,
                 action.vary_header_values,
-                response_body,
+                b"".join(stream),  # type: ignore
             )
         return None
 
     @aio_handler.register
     async def _io_make_request(self, action: protocol.MakeRequest) -> Response:
         args = request_to_raw(action.request)
-        raw_response = await self.transport.arequest(*args)  # type: ignore
+        raw_response = await self.transport.handle_async_request(*args)  # type: ignore
         return Response.from_raw(raw_response)
 
     @aio_handler.register
@@ -114,13 +113,17 @@ async def _io_close_response_stream(
     ) -> None:
         async for _chunk in action.response.stream:  # type: ignore
             pass
-        await action.response.stream.aclose()  # type: ignore
+        aclose = action.response.extensions.get("aclose")
+        if aclose:
+            await aclose()  # type: ignore
         return None
 
     def wrap_response_stream(
         self, key: str, response: Response, vary_header_values: dict
     ) -> Response:
-        wrapped_stream = ByteStreamWrapper(response.stream)
+        wrapped_stream = ByteStreamWrapper(
+            response.stream, response.extensions.get("aclose")
+        )
         response.stream = wrapped_stream
 
         async def callback(response_body: bytes):
diff --git a/httpx_caching/_models.py b/httpx_caching/_models.py
index 1502ba9..ad60da1 100644
--- a/httpx_caching/_models.py
+++ b/httpx_caching/_models.py
@@ -1,7 +1,6 @@
 import dataclasses
-from typing import Union
+from typing import AsyncIterable, Iterable, Union
 
-from httpcore import AsyncByteStream, PlainByteStream, SyncByteStream
 from httpx import Headers
 
 
@@ -13,15 +12,15 @@ class Response:
 
     status_code: int
     headers: Headers
-    stream: Union[SyncByteStream, AsyncByteStream]
-    ext: dict = dataclasses.field(default_factory=dict)
+    stream: Union[Iterable[bytes], AsyncIterable[bytes]]
+    extensions: dict = dataclasses.field(default_factory=dict)
 
     @classmethod
     def from_raw(cls, raw_response):
         values = list(raw_response)
         values[1] = Headers(values[1])
         if isinstance(values[2], bytes):
-            values[2] = PlainByteStream(values[2])
+            values[2] = [values[2]]
         return cls(*values)
 
     def to_raw(self):
diff --git a/httpx_caching/_serializer.py b/httpx_caching/_serializer.py
index 4630e70..cf50283 100644
--- a/httpx_caching/_serializer.py
+++ b/httpx_caching/_serializer.py
@@ -11,17 +11,18 @@ class Serializer(object):
     def dumps(self, response: Response, vary_header_data: dict, response_body: bytes):
-        # TODO: kludge while we put unserializable requests in ext
-        ext = response.ext.copy()
-        ext.pop("real_request", None)
+        extensions = response.extensions.copy()
+        extensions.pop("real_request", None)
+        extensions.pop("close", None)
+        extensions.pop("aclose", None)
 
         data = {
             "response": {
                 "body": response_body,
                 "headers": response.headers.raw,
                 "status_code": response.status_code,
-                # TODO: Make sure we don't explode if there's something naughty in ext
-                "ext": ext,
+                # TODO: Make sure we don't explode if there's something naughty in extensions
+                "extensions": extensions,
             },
             "vary": vary_header_data,
         }
@@ -66,9 +67,9 @@ def prepare_response(self, cached_data: dict):
         status_code = cached_response["status_code"]
         headers = cached_response["headers"]
         stream = httpcore.PlainByteStream(cached_response["body"])
-        ext = cached_response["ext"]
+        extensions = cached_response["extensions"]
 
-        response = Response.from_raw((status_code, headers, stream, ext))
+        response = Response.from_raw((status_code, headers, stream, extensions))
 
         if response.headers.get("transfer-encoding", "") == "chunked":
             response.headers.pop("transfer-encoding")
diff --git a/httpx_caching/_sync/_transport.py b/httpx_caching/_sync/_transport.py
index b0a42f4..93452a5 100644
--- a/httpx_caching/_sync/_transport.py
+++ b/httpx_caching/_sync/_transport.py
@@ -12,12 +12,12 @@
 from httpx_caching._utils import ByteStreamWrapper, request_to_raw
 
 
-class SyncCachingTransport(httpcore.SyncHTTPTransport):
+class SyncCachingTransport(httpx.BaseTransport):
     invalidating_methods = {"PUT", "PATCH", "DELETE"}
 
     def __init__(
         self,
-        transport: httpcore.SyncHTTPTransport,
+        transport: httpx.BaseTransport,
         cache: SyncDictCache = None,
         cache_etags: bool = True,
         heuristic: BaseHeuristic = None,
@@ -38,14 +38,14 @@ def __init__(
         self.cacheable_status_codes = cacheable_status_codes
         self.cache_etags = cache_etags
 
-    def request(
+    def handle_request(
         self,
         method: bytes,
         url: RawURL,
-        headers: RawHeaders = None,
-        stream: httpcore.SyncByteStream = None,
-        ext: dict = None,
-    ) -> Tuple[int, RawHeaders, httpcore.SyncByteStream, dict]:
+        headers: RawHeaders,
+        stream: Iterable[bytes],
+        extensions: dict,
+    ) -> Tuple[int, RawHeaders, Iterable[bytes], dict]:
 
         request = httpx.Request(
             method=method,
@@ -64,7 +64,7 @@ def request(
 
         response, source = caching_protocol.run(self.io_handler)
 
-        response.ext["from_cache"] = source == Source.CACHE
+        response.extensions["from_cache"] = source == Source.CACHE
         return response.to_raw()
 
     @multimethod
@@ -84,41 +84,44 @@ def _io_cache_delete(self, action: protocol.CacheDelete) -> None:
 
     @io_handler.register
     def _io_cache_set(self, action: protocol.CacheSet) -> Optional[Response]:
-        stream = action.response.stream
-        # TODO: we can probably just get rid of deferred?
-        if action.deferred and not isinstance(stream, httpcore.PlainByteStream):
+        if action.deferred:
+            # This is a response with a body, so we need to wait for it to be read before we can cache it
             return self.wrap_response_stream(
                 action.key, action.response, action.vary_header_values
             )
         else:
             stream = action.response.stream
             assert isinstance(stream, httpcore.PlainByteStream)
-            response_body = stream._content
+            # TODO: Are we needlessly recaching the body here? Is this just a header change?
             self.cache.set(
                 action.key,
                 action.response,
                 action.vary_header_values,
-                response_body,
+                b"".join(stream),  # type: ignore
             )
         return None
 
     @io_handler.register
     def _io_make_request(self, action: protocol.MakeRequest) -> Response:
         args = request_to_raw(action.request)
-        raw_response = self.transport.request(*args)  # type: ignore
+        raw_response = self.transport.handle_request(*args)  # type: ignore
         return Response.from_raw(raw_response)
 
     @io_handler.register
     def _io_close_response_stream(self, action: protocol.CloseResponseStream) -> None:
         for _chunk in action.response.stream:  # type: ignore
             pass
-        action.response.stream.close()  # type: ignore
+        close = action.response.extensions.get("close")
+        if close:
+            close()  # type: ignore
         return None
 
     def wrap_response_stream(
         self, key: str, response: Response, vary_header_values: dict
     ) -> Response:
-        wrapped_stream = ByteStreamWrapper(response.stream)
+        wrapped_stream = ByteStreamWrapper(
+            response.stream, response.extensions.get("close")
+        )
         response.stream = wrapped_stream
 
         def callback(response_body: bytes):
diff --git a/httpx_caching/_utils.py b/httpx_caching/_utils.py
index 62dd421..a3d5c3d 100644
--- a/httpx_caching/_utils.py
+++ b/httpx_caching/_utils.py
@@ -1,9 +1,11 @@
 import threading
 from typing import (
+    AsyncIterable,
     AsyncIterator,
     Awaitable,
     Callable,
     Generator,
+    Iterable,
     Iterator,
     Optional,
     Tuple,
@@ -13,16 +15,16 @@
 
 import anyio
 import httpx
-from httpcore import AsyncByteStream, SyncByteStream
 
 AsyncLock = anyio.create_lock
 SyncLock = threading.Lock
 
 
-class ByteStreamWrapper(SyncByteStream, AsyncByteStream):
+class ByteStreamWrapper:
     def __init__(
         self,
-        stream: Union[SyncByteStream, AsyncByteStream],
+        stream: Union[Iterable[bytes], AsyncIterable[bytes]],
+        stream_close: Optional[Callable],
         callback: Optional[Callable] = None,
     ) -> None:
         """
@@ -32,6 +34,7 @@ def __init__(
         print("wrapping", stream)
         self.stream = stream
         self.callback = callback or (lambda *args, **kwargs: None)
+        self.stream_close = stream_close
         self.buffer = bytearray()
         self.callback_called = False
 
@@ -59,10 +62,12 @@ async def __aiter__(self) -> AsyncIterator[bytes]:
             await self.a_on_read_finish()
 
     def close(self) -> None:
-        self.stream.close()  # type: ignore
+        if self.stream_close:
+            self.stream_close()  # type: ignore
 
     async def aclose(self) -> None:
-        await self.stream.aclose()  # type: ignore
+        if self.stream_close:
+            await self.stream_close()  # type: ignore
 
 
 YieldType = TypeVar("YieldType")
diff --git a/httpx_caching/_wrapper.py b/httpx_caching/_wrapper.py
index bc1aaeb..9ccd699 100644
--- a/httpx_caching/_wrapper.py
+++ b/httpx_caching/_wrapper.py
@@ -1,4 +1,4 @@
-import httpcore
+import httpx
 
 from httpx_caching._async._transport import AsyncCachingTransport
 from httpx_caching._sync._transport import SyncCachingTransport
@@ -11,7 +11,7 @@ def CachingClient(client: AnyClient, *args, **kwargs) -> AnyClient:
     if "transport" not in kwargs:
         kwargs["transport"] = current_transport
 
-    is_async = isinstance(current_transport, httpcore.AsyncHTTPTransport)
+    is_async = isinstance(current_transport, httpx.AsyncBaseTransport)
     client._transport = (AsyncCachingTransport if is_async else SyncCachingTransport)(
         *args, **kwargs
     )
diff --git a/requirements.txt b/requirements.txt
index 6203c27..561b0c3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-httpx==0.16.*
+git+git://github.com/encode/httpx@master#egg=httpx
 msgpack
 anyio
 multimethod
diff --git a/tests/_async/test_etag.py b/tests/_async/test_etag.py
index dee7a93..42c1389 100644
--- a/tests/_async/test_etag.py
+++ b/tests/_async/test_etag.py
@@ -20,7 +20,7 @@ def get_last_request(client):
         headers,
         stream,
         _ext,
-    ) = client._transport.transport.arequest.call_args[0]
+    ) = client._transport.transport.handle_async_request.call_args[0]
     return Request(
         method=method,
         url=url,
@@ -39,7 +39,9 @@ async def async_client(mocker):
     async_client._transport = transport
 
     mocker.patch.object(
-        transport.transport, "arequest", wraps=transport.transport.arequest
+        transport.transport,
+        "handle_async_request",
+        wraps=transport.transport.handle_async_request,
     )
 
     yield async_client
@@ -89,6 +91,12 @@ async def test_etags_get_example(self, async_client, url):
         assert cache_hit(r2)
         assert raw_resp(r2) == raw_resp(r1)
 
+        # make the same request a 3rd time to make sure we don't mess anything up
+        # after a cache hit
+        r3 = await async_client.get(url + "etag")
+        assert cache_hit(r3)
+        assert raw_resp(r3) == raw_resp(r1)
+
         # tell the server to change the etags of the response
         await async_client.get(url + "update_etag")
 
diff --git a/tests/conftest.py b/tests/conftest.py
index 8a3429d..5c7b03a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -116,7 +116,7 @@ def __call__(self, env, start_response):
 
 
 def cache_hit(resp):
-    return resp.ext["from_cache"]
+    return resp.extensions["from_cache"]
 
 
 @pytest.fixture(scope="session")
diff --git a/tests/test_serialization.py b/tests/test_serialization.py
index ca6d05d..e2a9f9b 100644
--- a/tests/test_serialization.py
+++ b/tests/test_serialization.py
@@ -20,7 +20,7 @@ def setup(self):
                 "Cache-Control": "public",
             },
             "status_code": 200,
-            "ext": {},
+            "extensions": {},
        },
        "vary": {},
    }
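
Usage sketch (illustrative only, not part of the patch): with the transports now implementing httpx 0.18's handle_request / handle_async_request interface, cache hits are reported through response.extensions rather than the old response.ext. The snippet below assumes CachingClient is exported from the package root and that the target URL returns a cacheable response; adjust both to your setup.

    # Illustrative sketch, not part of the patch above.
    # Assumes CachingClient is importable from the package root.
    import asyncio

    import httpx

    from httpx_caching import CachingClient


    async def main() -> None:
        # Wrap a stock httpx client; the wrapper installs AsyncCachingTransport,
        # which now subclasses httpx.AsyncBaseTransport (httpx 0.18 transport API).
        client = CachingClient(httpx.AsyncClient())

        first = await client.get("https://example.org/cacheable")
        second = await client.get("https://example.org/cacheable")

        # Cache hits are surfaced via response.extensions, not response.ext.
        print(first.extensions["from_cache"])   # False on the first fetch
        print(second.extensions["from_cache"])  # True if the response was cacheable

        await client.aclose()


    asyncio.run(main())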