Implement httpx 0.18 changes
johtso committed Mar 24, 2021
1 parent e1bfeaf commit 26e0819
Showing 11 changed files with 80 additions and 59 deletions.
6 changes: 4 additions & 2 deletions gen_sync.py
@@ -18,19 +18,21 @@
fromdir="/_async/",
todir="/_sync/",
additional_replacements={
"AsyncBaseTransport": "BaseTransport",
"async_client": "client",
"AsyncClient": "Client",
"make_async_client": "make_client",
"asyncio": "sync",
"aclose": "close",
'"aclose"': '"close"',
"aread": "read",
"arun": "run",
"aio_handler": "io_handler",
"arequest": "request",
"handle_async_request": "handle_request",
'"handle_async_request"': '"handle_request"',
"aget": "get",
"aset": "set",
"adelete": "delete",
'"arequest"': '"request"',
}
),
],
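Note: the table above drives the project's async-to-sync code generation: each identifier on the left is rewritten to the one on the right when httpx_caching/_sync/ is regenerated from httpx_caching/_async/. A rough sketch of the idea, for orientation only (the literal loop below is illustrative, not the project's actual generator, which also strips async/await keywords and uses the full table above):

replacements = {
    "AsyncBaseTransport": "BaseTransport",
    "handle_async_request": "handle_request",
    "AsyncClient": "Client",
    "aclose": "close",
}

line = "raw_response = await self.transport.handle_async_request(*args)"
for old, new in replacements.items():
    line = line.replace(old, new)

print(line)  # raw_response = await self.transport.handle_request(*args)
             # (the real generator also removes the await/async keywords)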
37 changes: 20 additions & 17 deletions httpx_caching/_async/_transport.py
@@ -1,4 +1,4 @@
from typing import Iterable, Optional, Tuple
from typing import AsyncIterable, Iterable, Optional, Tuple

import httpcore
import httpx
@@ -12,12 +12,12 @@
from httpx_caching._utils import ByteStreamWrapper, request_to_raw


class AsyncCachingTransport(httpcore.AsyncHTTPTransport):
class AsyncCachingTransport(httpx.AsyncBaseTransport):
invalidating_methods = {"PUT", "PATCH", "DELETE"}

def __init__(
self,
transport: httpcore.AsyncHTTPTransport,
transport: httpx.AsyncBaseTransport,
cache: AsyncDictCache = None,
cache_etags: bool = True,
heuristic: BaseHeuristic = None,
@@ -38,14 +38,14 @@ def __init__(
self.cacheable_status_codes = cacheable_status_codes
self.cache_etags = cache_etags

async def arequest(
async def handle_async_request(
self,
method: bytes,
url: RawURL,
headers: RawHeaders = None,
stream: httpcore.AsyncByteStream = None,
ext: dict = None,
) -> Tuple[int, RawHeaders, httpcore.AsyncByteStream, dict]:
headers: RawHeaders,
stream: AsyncIterable[bytes],
extensions: dict,
) -> Tuple[int, RawHeaders, AsyncIterable[bytes], dict]:

request = httpx.Request(
method=method,
@@ -64,7 +64,7 @@ async def arequest(

response, source = await caching_protocol.arun(self.aio_handler)

response.ext["from_cache"] = source == Source.CACHE
response.extensions["from_cache"] = source == Source.CACHE
return response.to_raw()

@multimethod
@@ -84,28 +84,27 @@ async def _io_cache_delete(self, action: protocol.CacheDelete) -> None:

@aio_handler.register
async def _io_cache_set(self, action: protocol.CacheSet) -> Optional[Response]:
stream = action.response.stream
# TODO: we can probably just get rid of deferred?
if action.deferred and not isinstance(stream, httpcore.PlainByteStream):
if action.deferred:
# This is a response with a body, so we need to wait for it to be read before we can cache it
return self.wrap_response_stream(
action.key, action.response, action.vary_header_values
)
else:
stream = action.response.stream
assert isinstance(stream, httpcore.PlainByteStream)
response_body = stream._content
# TODO: Are we needlessly recaching the body here? Is this just a header change?
await self.cache.aset(
action.key,
action.response,
action.vary_header_values,
response_body,
b"".join(stream), # type: ignore
)
return None

@aio_handler.register
async def _io_make_request(self, action: protocol.MakeRequest) -> Response:
args = request_to_raw(action.request)
raw_response = await self.transport.arequest(*args) # type: ignore
raw_response = await self.transport.handle_async_request(*args) # type: ignore
return Response.from_raw(raw_response)

@aio_handler.register
@@ -114,13 +113,17 @@ async def _io_close_response_stream(
) -> None:
async for _chunk in action.response.stream: # type: ignore
pass
await action.response.stream.aclose() # type: ignore
aclose = action.response.extensions.get("aclose")
if aclose:
await aclose() # type: ignore
return None

def wrap_response_stream(
self, key: str, response: Response, vary_header_values: dict
) -> Response:
wrapped_stream = ByteStreamWrapper(response.stream)
wrapped_stream = ByteStreamWrapper(
response.stream, response.extensions.get("aclose")
)
response.stream = wrapped_stream

async def callback(response_body: bytes):
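For orientation, a minimal transport sketch using the httpx 0.18-style interface this file now targets (the EchoTransport below is illustrative and not part of this commit; the raw tuple types mirror the signature shown above):

from typing import AsyncIterable, List, Optional, Tuple

import httpx

RawURL = Tuple[bytes, bytes, Optional[int], bytes]  # (scheme, host, port, path)
RawHeaders = List[Tuple[bytes, bytes]]


class EchoTransport(httpx.AsyncBaseTransport):
    """Illustrative transport: returns the request body back as the response body."""

    async def handle_async_request(
        self,
        method: bytes,
        url: RawURL,
        headers: RawHeaders,
        stream: AsyncIterable[bytes],
        extensions: dict,
    ) -> Tuple[int, RawHeaders, AsyncIterable[bytes], dict]:
        # Drain the request stream, then hand the bytes back as the response stream.
        body = b""
        async for chunk in stream:
            body += chunk

        async def response_stream() -> AsyncIterable[bytes]:
            yield body

        return 200, [(b"content-type", b"application/octet-stream")], response_stream(), {}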
9 changes: 4 additions & 5 deletions httpx_caching/_models.py
@@ -1,7 +1,6 @@
import dataclasses
from typing import Union
from typing import AsyncIterable, Iterable, Union

from httpcore import AsyncByteStream, PlainByteStream, SyncByteStream
from httpx import Headers


@@ -13,15 +12,15 @@ class Response:

status_code: int
headers: Headers
stream: Union[SyncByteStream, AsyncByteStream]
ext: dict = dataclasses.field(default_factory=dict)
stream: Union[Iterable[bytes], AsyncIterable[bytes]]
extensions: dict = dataclasses.field(default_factory=dict)

@classmethod
def from_raw(cls, raw_response):
values = list(raw_response)
values[1] = Headers(values[1])
if isinstance(values[2], bytes):
values[2] = PlainByteStream(values[2])
values[2] = [values[2]]
return cls(*values)

def to_raw(self):
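With the byte-stream classes gone from this model, a raw bytes body is simply wrapped in a one-element list, which already satisfies Iterable[bytes]. A quick illustration (not part of the diff):

raw_body = b"cached payload"
stream = [raw_body]                  # what from_raw now produces for a bytes body
assert b"".join(stream) == raw_body  # matches the b"".join(stream) calls in the transports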
15 changes: 8 additions & 7 deletions httpx_caching/_serializer.py
@@ -11,17 +11,18 @@

class Serializer(object):
def dumps(self, response: Response, vary_header_data: dict, response_body: bytes):
# TODO: kludge while we put unserializable requests in ext
ext = response.ext.copy()
ext.pop("real_request", None)
extensions = response.extensions.copy()
extensions.pop("real_request", None)
extensions.pop("close", None)
extensions.pop("aclose", None)

data = {
"response": {
"body": response_body,
"headers": response.headers.raw,
"status_code": response.status_code,
# TODO: Make sure we don't explode if there's something naughty in ext
"ext": ext,
# TODO: Make sure we don't explode if there's something naughty in extensions
"extensions": extensions,
},
"vary": vary_header_data,
}
@@ -66,9 +67,9 @@ def prepare_response(self, cached_data: dict):
status_code = cached_response["status_code"]
headers = cached_response["headers"]
stream = httpcore.PlainByteStream(cached_response["body"])
ext = cached_response["ext"]
extensions = cached_response["extensions"]

response = Response.from_raw((status_code, headers, stream, ext))
response = Response.from_raw((status_code, headers, stream, extensions))

if response.headers.get("transfer-encoding", "") == "chunked":
response.headers.pop("transfer-encoding")
35 changes: 19 additions & 16 deletions httpx_caching/_sync/_transport.py
@@ -12,12 +12,12 @@
from httpx_caching._utils import ByteStreamWrapper, request_to_raw


class SyncCachingTransport(httpcore.SyncHTTPTransport):
class SyncCachingTransport(httpx.BaseTransport):
invalidating_methods = {"PUT", "PATCH", "DELETE"}

def __init__(
self,
transport: httpcore.SyncHTTPTransport,
transport: httpx.BaseTransport,
cache: SyncDictCache = None,
cache_etags: bool = True,
heuristic: BaseHeuristic = None,
@@ -38,14 +38,14 @@ def __init__(
self.cacheable_status_codes = cacheable_status_codes
self.cache_etags = cache_etags

def request(
def handle_request(
self,
method: bytes,
url: RawURL,
headers: RawHeaders = None,
stream: httpcore.SyncByteStream = None,
ext: dict = None,
) -> Tuple[int, RawHeaders, httpcore.SyncByteStream, dict]:
headers: RawHeaders,
stream: Iterable[bytes],
extensions: dict,
) -> Tuple[int, RawHeaders, Iterable[bytes], dict]:

request = httpx.Request(
method=method,
@@ -64,7 +64,7 @@ def request(

response, source = caching_protocol.run(self.io_handler)

response.ext["from_cache"] = source == Source.CACHE
response.extensions["from_cache"] = source == Source.CACHE
return response.to_raw()

@multimethod
@@ -84,41 +84,44 @@ def _io_cache_delete(self, action: protocol.CacheDelete) -> None:

@io_handler.register
def _io_cache_set(self, action: protocol.CacheSet) -> Optional[Response]:
stream = action.response.stream
# TODO: we can probably just get rid of deferred?
if action.deferred and not isinstance(stream, httpcore.PlainByteStream):
if action.deferred:
# This is a response with a body, so we need to wait for it to be read before we can cache it
return self.wrap_response_stream(
action.key, action.response, action.vary_header_values
)
else:
stream = action.response.stream
assert isinstance(stream, httpcore.PlainByteStream)
response_body = stream._content
# TODO: Are we needlessly recaching the body here? Is this just a header change?
self.cache.set(
action.key,
action.response,
action.vary_header_values,
response_body,
b"".join(stream), # type: ignore
)
return None

@io_handler.register
def _io_make_request(self, action: protocol.MakeRequest) -> Response:
args = request_to_raw(action.request)
raw_response = self.transport.request(*args) # type: ignore
raw_response = self.transport.handle_request(*args) # type: ignore
return Response.from_raw(raw_response)

@io_handler.register
def _io_close_response_stream(self, action: protocol.CloseResponseStream) -> None:
for _chunk in action.response.stream: # type: ignore
pass
action.response.stream.close() # type: ignore
close = action.response.extensions.get("close")
if close:
close() # type: ignore
return None

def wrap_response_stream(
self, key: str, response: Response, vary_header_values: dict
) -> Response:
wrapped_stream = ByteStreamWrapper(response.stream)
wrapped_stream = ByteStreamWrapper(
response.stream, response.extensions.get("close")
)
response.stream = wrapped_stream

def callback(response_body: bytes):
15 changes: 10 additions & 5 deletions httpx_caching/_utils.py
@@ -1,9 +1,11 @@
import threading
from typing import (
AsyncIterable,
AsyncIterator,
Awaitable,
Callable,
Generator,
Iterable,
Iterator,
Optional,
Tuple,
@@ -13,16 +15,16 @@

import anyio
import httpx
from httpcore import AsyncByteStream, SyncByteStream

AsyncLock = anyio.create_lock
SyncLock = threading.Lock


class ByteStreamWrapper(SyncByteStream, AsyncByteStream):
class ByteStreamWrapper:
def __init__(
self,
stream: Union[SyncByteStream, AsyncByteStream],
stream: Union[Iterable[bytes], AsyncIterable[bytes]],
stream_close: Optional[Callable],
callback: Optional[Callable] = None,
) -> None:
"""
@@ -32,6 +34,7 @@ def __init__(
print("wrapping", stream)
self.stream = stream
self.callback = callback or (lambda *args, **kwargs: None)
self.stream_close = stream_close

self.buffer = bytearray()
self.callback_called = False
@@ -59,10 +62,12 @@ async def __aiter__(self) -> AsyncIterator[bytes]:
await self.a_on_read_finish()

def close(self) -> None:
self.stream.close() # type: ignore
if self.stream_close:
self.stream_close() # type: ignore

async def aclose(self) -> None:
await self.stream.aclose() # type: ignore
if self.stream_close:
await self.stream_close() # type: ignore


YieldType = TypeVar("YieldType")
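A sketch of how the decoupled close hook is intended to be used after this change (the fake stream, closer, and callback below are illustrative stand-ins; the exact buffering behaviour depends on the wrapper's read-finish hook, which is not shown in this hunk):

import asyncio

from httpx_caching._utils import ByteStreamWrapper


async def demo() -> None:
    closed = False
    seen = []

    async def fake_stream():
        # Stands in for the raw response stream handed back by the transport.
        yield b"hello "
        yield b"world"

    async def fake_aclose() -> None:
        # Stands in for response.extensions["aclose"] under httpx 0.18.
        nonlocal closed
        closed = True

    async def on_body(body: bytes) -> None:
        # Mirrors the async callback set up in wrap_response_stream.
        seen.append(body)

    wrapped = ByteStreamWrapper(fake_stream(), fake_aclose, on_body)
    async for _chunk in wrapped:
        pass
    await wrapped.aclose()

    assert closed
    print(seen)  # expected: [b"hello world"], the buffered body


asyncio.run(demo())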
4 changes: 2 additions & 2 deletions httpx_caching/_wrapper.py
@@ -1,4 +1,4 @@
import httpcore
import httpx

from httpx_caching._async._transport import AsyncCachingTransport
from httpx_caching._sync._transport import SyncCachingTransport
@@ -11,7 +11,7 @@ def CachingClient(client: AnyClient, *args, **kwargs) -> AnyClient:
if "transport" not in kwargs:
kwargs["transport"] = current_transport

is_async = isinstance(current_transport, httpcore.AsyncHTTPTransport)
is_async = isinstance(current_transport, httpx.AsyncBaseTransport)
client._transport = (AsyncCachingTransport if is_async else SyncCachingTransport)(
*args, **kwargs
)
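A usage sketch of the wrapper with the new extensions-based cache flag (assuming CachingClient is re-exported from the package root and the default in-memory cache is used; the URL is a placeholder):

import asyncio

import httpx

from httpx_caching import CachingClient


async def main() -> None:
    # Wraps a stock client and swaps in the matching caching transport.
    client = CachingClient(httpx.AsyncClient())

    r1 = await client.get("https://example.org/")  # goes to the network
    r2 = await client.get("https://example.org/")  # may be served from the cache
    print(r1.extensions["from_cache"], r2.extensions["from_cache"])

    await client.aclose()


asyncio.run(main())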
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,4 +1,4 @@
httpx==0.16.*
git+git://github.com/encode/httpx@master#egg=httpx
msgpack
anyio
multimethod
12 changes: 10 additions & 2 deletions tests/_async/test_etag.py
@@ -20,7 +20,7 @@ def get_last_request(client):
headers,
stream,
_ext,
) = client._transport.transport.arequest.call_args[0]
) = client._transport.transport.handle_async_request.call_args[0]
return Request(
method=method,
url=url,
@@ -39,7 +39,9 @@
async_client._transport = transport

mocker.patch.object(
transport.transport, "arequest", wraps=transport.transport.arequest
transport.transport,
"handle_async_request",
wraps=transport.transport.handle_async_request,
)

yield async_client
@@ -89,6 +91,12 @@ async def test_etags_get_example(self, async_client, url):
assert cache_hit(r2)
assert raw_resp(r2) == raw_resp(r1)

# make the same request a 3rd time to make sure we don't mess anything up
# after a cache hit
r3 = await async_client.get(url + "etag")
assert cache_hit(r3)
assert raw_resp(r3) == raw_resp(r1)

# tell the server to change the etags of the response
await async_client.get(url + "update_etag")
