Skip to content

Commit

Permalink
Implement HTTPX Transport API (#142)
Browse files Browse the repository at this point in the history
* Implement HTTPX Transport API
* Require httpx 0.18.0
  • Loading branch information
lundberg authored Apr 27, 2021
1 parent b27ad8e commit 20cad3d
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 75 deletions.
26 changes: 23 additions & 3 deletions respx/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from typing import TYPE_CHECKING, ClassVar, Dict, List, Type
from unittest import mock

from httpcore import AsyncIteratorByteStream, IteratorByteStream

from .models import PassThrough, decode_request, encode_response
from .transports import MockTransport, TryTransport

Expand Down Expand Up @@ -150,6 +152,8 @@ def mock(self, *args, **kwargs):
request = cls.to_httpx_request(**kwargs)
request, kwargs = cls.prepare(request, **kwargs)
response = cls._send(request, instance=self, target_spec=spec, **kwargs)
status_code, headers, stream, extensions = response
response = (status_code, headers, IteratorByteStream(stream), extensions)
return response

async def amock(self, *args, **kwargs):
Expand All @@ -159,14 +163,25 @@ async def amock(self, *args, **kwargs):
response = cls._send(request, instance=self, target_spec=spec, **kwargs)
if inspect.isawaitable(response):
response = await response
status_code, headers, stream, extensions = response
response = (
status_code,
headers,
AsyncIteratorByteStream(stream),
extensions,
)
return response

return amock if inspect.iscoroutinefunction(spec) else mock

@classmethod
def _merge_args_and_kwargs(cls, argspec, args, kwargs):
arg_names = argspec.args[1:] # Skip self
new_kwargs = dict(zip(arg_names[-len(argspec.defaults) :], argspec.defaults))
new_kwargs = (
dict(zip(arg_names[-len(argspec.defaults) :], argspec.defaults))
if argspec.defaults
else dict()
)
new_kwargs.update(zip(arg_names, args))
new_kwargs.update(kwargs)
return new_kwargs
Expand Down Expand Up @@ -216,7 +231,7 @@ class HTTPCoreMocker(AbstractRequestMocker):
"httpcore._async.connection_pool.AsyncConnectionPool",
"httpcore._async.http_proxy.AsyncHTTPProxy",
]
target_methods = ["request", "arequest"]
target_methods = ["handle_request", "handle_async_request"]

@classmethod
def prepare(cls, httpx_request, **kwargs):
Expand All @@ -241,7 +256,12 @@ def to_httpx_request(cls, **kwargs):
"""
Create a `HTTPX` request from transport request args.
"""
request = (kwargs["method"], kwargs["url"], kwargs["headers"], kwargs["stream"])
request = (
kwargs["method"],
kwargs["url"],
kwargs.get("headers"),
kwargs.get("stream"),
)
httpx_request = decode_request(request)
return httpx_request

Expand Down
12 changes: 7 additions & 5 deletions respx/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def encode_response(response: httpx.Response) -> Response:
response.status_code,
response.headers.raw,
response.stream,
response.ext,
response.extensions,
)


Expand All @@ -58,7 +58,7 @@ def clone_response(response: httpx.Response, request: httpx.Request) -> httpx.Re
headers=response.headers,
stream=response.stream,
request=request,
ext=dict(response.ext),
extensions=dict(response.extensions),
)
if isinstance(response.stream, Iterable):
response.read() # Pre-read stream for easier call stats usage
Expand Down Expand Up @@ -101,7 +101,9 @@ def __init__(
http_version: Optional[str] = None,
**kwargs: Any,
) -> None:
if callable(content) or isinstance(content, (dict, Exception)):
if not isinstance(content, (str, bytes)) and (
callable(content) or isinstance(content, (dict, Exception))
):
raise TypeError(
f"MockResponse content can only be str, bytes or byte stream"
f"got {content!r}. Please use json=... or side effects."
Expand All @@ -112,7 +114,7 @@ def __init__(
if content_type:
self.headers["Content-Type"] = content_type
if http_version:
self.ext["http_version"] = http_version
self.extensions["http_version"] = http_version.encode("ascii")


class Route:
Expand Down Expand Up @@ -340,7 +342,7 @@ def _resolve_side_effect(
self,
origin=(
Error("Mock Error", request=request)
if issubclass(Error, httpx.HTTPError)
if issubclass(Error, httpx.RequestError)
else Error()
),
)
Expand Down
57 changes: 28 additions & 29 deletions respx/transports.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
from types import TracebackType
from typing import TYPE_CHECKING, Any, List, Optional, Type, Union

from httpcore import (
AsyncByteStream,
AsyncHTTPTransport,
SyncByteStream,
SyncHTTPTransport,
)
from httpx import AsyncBaseTransport, AsyncByteStream, BaseTransport, SyncByteStream

from .models import PassThrough, decode_request, encode_response
from .types import URL, AsyncResponse, Headers, RequestHandler, SyncResponse
Expand All @@ -15,7 +10,7 @@
from .router import Router # pragma: nocover


class MockTransport(SyncHTTPTransport, AsyncHTTPTransport):
class MockTransport(BaseTransport, AsyncBaseTransport):
_handler: Optional[RequestHandler]
_router: Optional["Router"]

Expand All @@ -40,13 +35,13 @@ def __init__(
def handler(self) -> RequestHandler:
return self._handler or self._router.handler

def request(
def handle_request(
self,
method: bytes,
url: URL,
headers: Headers = None,
stream: SyncByteStream = None,
ext: dict = None,
headers: Headers,
stream: SyncByteStream,
extensions: dict,
) -> SyncResponse:
raw_request = (method, url, headers, stream)
request = decode_request(raw_request)
Expand All @@ -60,13 +55,13 @@ def request(
raw_response = encode_response(response)
return raw_response # type: ignore

async def arequest(
async def handle_async_request(
self,
method: bytes,
url: URL,
headers: Headers = None,
stream: AsyncByteStream = None,
ext: dict = None,
headers: Headers,
stream: AsyncByteStream,
extensions: dict,
) -> AsyncResponse:
raw_request = (method, url, headers, stream)
request = decode_request(raw_request)
Expand All @@ -93,25 +88,27 @@ async def __aexit__(self, *args: Any) -> None:
self.__exit__(*args)


class TryTransport(SyncHTTPTransport, AsyncHTTPTransport):
class TryTransport(BaseTransport, AsyncBaseTransport):
def __init__(
self, transports: List[Union[SyncHTTPTransport, AsyncHTTPTransport]]
self, transports: List[Union[BaseTransport, AsyncBaseTransport]]
) -> None:
self.transports = transports

def request(
def handle_request(
self,
method: bytes,
url: URL,
headers: Headers = None,
stream: SyncByteStream = None,
ext: dict = None,
headers: Headers,
stream: SyncByteStream,
extensions: dict,
) -> SyncResponse:
error: Exception = None
for transport in self.transports:
try:
assert isinstance(transport, SyncHTTPTransport)
return transport.request(method, url, headers, stream, ext)
assert isinstance(transport, BaseTransport)
return transport.handle_request(
method, url, headers, stream, extensions
)
except PassThrough as pass_through:
stream = pass_through.request.stream # type: ignore
except AssertionError:
Expand All @@ -120,19 +117,21 @@ def request(
error = e
raise error

async def arequest(
async def handle_async_request(
self,
method: bytes,
url: URL,
headers: Headers = None,
stream: AsyncByteStream = None,
ext: dict = None,
headers: Headers,
stream: AsyncByteStream,
extensions: dict,
) -> AsyncResponse:
error: Exception = None
for transport in self.transports:
try:
assert isinstance(transport, AsyncHTTPTransport)
return await transport.arequest(method, url, headers, stream, ext)
assert isinstance(transport, AsyncBaseTransport)
return await transport.handle_async_request(
method, url, headers, stream, extensions
)
except PassThrough as pass_through:
stream = pass_through.request.stream # type: ignore
except AssertionError:
Expand Down
9 changes: 3 additions & 6 deletions respx/types.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from typing import (
Any,
AsyncIterable,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Expand All @@ -16,7 +14,6 @@
)

import httpx
from httpcore import AsyncByteStream, SyncByteStream

URL = Tuple[
bytes, # scheme
Expand All @@ -25,7 +22,7 @@
bytes, # path
]
Headers = List[Tuple[bytes, bytes]]
ByteStream = Union[Iterable[bytes], AsyncIterable[bytes]]
ByteStream = Union[httpx.SyncByteStream, httpx.AsyncByteStream]
Request = Tuple[
bytes, # http method
URL,
Expand All @@ -35,13 +32,13 @@
SyncResponse = Tuple[
int, # status code
Headers,
SyncByteStream, # body
httpx.SyncByteStream, # body
dict, # ext
]
AsyncResponse = Tuple[
int, # status code
Headers,
AsyncByteStream, # body
httpx.AsyncByteStream, # body
dict, # ext
]
Response = Tuple[
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,5 @@
include_package_data=True,
zip_safe=False,
python_requires=">=3.6",
install_requires=["httpx>=0.15"],
install_requires=["httpx>=0.18.0"],
)
15 changes: 8 additions & 7 deletions tests/test_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,9 @@ async def raw_stream():
yield b"foo"
yield b"bar"

stream = httpcore.AsyncIteratorByteStream(raw_stream())
request = respx_mock.get("https://foo.bar/").mock(
return_value=httpx.Response(202, stream=raw_stream())
return_value=httpx.Response(202, stream=stream)
)

response = await client.get("https://foo.bar/")
Expand All @@ -139,7 +140,6 @@ async def raw_stream():
assert response.content == b"foobar"
assert respx.calls.call_count == 0
assert respx_mock.calls.call_count == 1
assert await respx_mock.calls.last.response.aread() == b"" # TODO: Exhausted!

with pytest.raises(AssertionError, match="not mocked"):
httpx.post("https://foo.bar/")
Expand Down Expand Up @@ -456,8 +456,9 @@ async def test_assert_all_mocked(client, assert_all_mocked, raises):
assert respx_mock.calls.call_count == 0


@pytest.mark.xfail(strict=False)
@pytest.mark.asyncio
async def test_asgi():
async def test_asgi(): # pragma: nocover
from respx.mocks import HTTPCoreMocker

try:
Expand Down Expand Up @@ -590,8 +591,8 @@ class Hamspam(Mocker):


def test_sync_httpx_mocker():
class TestTransport(httpcore.SyncHTTPTransport):
def request(self, *args, **kwargs):
class TestTransport(httpx.BaseTransport):
def handle_request(self, *args, **kwargs):
raise RuntimeError("would pass through")

client = httpx.Client(transport=TestTransport())
Expand Down Expand Up @@ -619,8 +620,8 @@ def test(respx_mock):

@pytest.mark.asyncio
async def test_async_httpx_mocker():
class TestTransport(httpcore.AsyncHTTPTransport):
async def arequest(self, *args, **kwargs):
class TestTransport(httpx.AsyncBaseTransport):
async def handle_async_request(self, *args, **kwargs):
raise RuntimeError("would pass through")

client = httpx.AsyncClient(transport=TestTransport())
Expand Down
26 changes: 4 additions & 22 deletions tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,8 @@ async def backend_test(backend):
assert _request.url == url
assert _response.status_code == get_response.status_code == 202
assert _response.content == get_response.content == b"get"
assert {
_response.status_code,
tuple(_response.headers.raw),
_response.stream,
tuple(_response.ext.items()),
} == {
get_response.status_code,
tuple(get_response.headers.raw),
get_response.stream,
tuple(get_response.ext.items()),
}
assert tuple(_response.headers.raw) == tuple(get_response.headers.raw)
assert _response.extensions == get_response.extensions
assert id(_response) != id(get_response)

_request, _response = foobar2.calls[-1]
Expand All @@ -72,17 +63,8 @@ async def backend_test(backend):
assert _request.url == url
assert _response.status_code == del_response.status_code == 200
assert _response.content == del_response.content == b"del"
assert {
_response.status_code,
tuple(_response.headers.raw),
_response.stream,
tuple(_response.ext.items()),
} == {
del_response.status_code,
tuple(del_response.headers.raw),
del_response.stream,
tuple(del_response.ext.items()),
}
assert tuple(_response.headers.raw) == tuple(del_response.headers.raw)
assert _response.extensions == del_response.extensions
assert id(_response) != id(del_response)

assert respx.calls.call_count == 2
Expand Down
4 changes: 2 additions & 2 deletions tests/test_transports.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,15 @@ async def test_httpcore_request(url, port):
router.get(url) % dict(text="foobar")

with httpcore.SyncConnectionPool() as http:
(status_code, headers, stream, ext) = http.request(
(status_code, headers, stream, ext) = http.handle_request(
method=b"GET", url=(b"https", b"foo.bar", port, b"/")
)

body = b"".join([chunk for chunk in stream])
assert body == b"foobar"

async with httpcore.AsyncConnectionPool() as http:
(status_code, headers, stream, ext) = await http.arequest(
(status_code, headers, stream, ext) = await http.handle_async_request(
method=b"GET", url=(b"https", b"foo.bar", port, b"/")
)

Expand Down

0 comments on commit 20cad3d

Please sign in to comment.