Skip to content

Commit

Permalink
Implement HTTPX Transport API
Browse files Browse the repository at this point in the history
  • Loading branch information
lundberg committed Apr 8, 2021
1 parent b27ad8e commit 91abc18
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 66 deletions.
11 changes: 11 additions & 0 deletions respx/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from typing import TYPE_CHECKING, ClassVar, Dict, List, Type
from unittest import mock

from httpcore import AsyncIteratorByteStream, IteratorByteStream

from .models import PassThrough, decode_request, encode_response
from .transports import MockTransport, TryTransport

Expand Down Expand Up @@ -150,6 +152,8 @@ def mock(self, *args, **kwargs):
request = cls.to_httpx_request(**kwargs)
request, kwargs = cls.prepare(request, **kwargs)
response = cls._send(request, instance=self, target_spec=spec, **kwargs)
status_code, headers, stream, extensions = response
response = (status_code, headers, IteratorByteStream(stream), extensions)
return response

async def amock(self, *args, **kwargs):
Expand All @@ -159,6 +163,13 @@ async def amock(self, *args, **kwargs):
response = cls._send(request, instance=self, target_spec=spec, **kwargs)
if inspect.isawaitable(response):
response = await response
status_code, headers, stream, extensions = response
response = (
status_code,
headers,
AsyncIteratorByteStream(stream),
extensions,
)
return response

return amock if inspect.iscoroutinefunction(spec) else mock
Expand Down
8 changes: 4 additions & 4 deletions respx/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def encode_response(response: httpx.Response) -> Response:
response.status_code,
response.headers.raw,
response.stream,
response.ext,
response.extensions,
)


Expand All @@ -58,7 +58,7 @@ def clone_response(response: httpx.Response, request: httpx.Request) -> httpx.Re
headers=response.headers,
stream=response.stream,
request=request,
ext=dict(response.ext),
extensions=dict(response.extensions),
)
if isinstance(response.stream, Iterable):
response.read() # Pre-read stream for easier call stats usage
Expand Down Expand Up @@ -112,7 +112,7 @@ def __init__(
if content_type:
self.headers["Content-Type"] = content_type
if http_version:
self.ext["http_version"] = http_version
self.extensions["http_version"] = http_version


class Route:
Expand Down Expand Up @@ -340,7 +340,7 @@ def _resolve_side_effect(
self,
origin=(
Error("Mock Error", request=request)
if issubclass(Error, httpx.HTTPError)
if issubclass(Error, httpx.RequestError)
else Error()
),
)
Expand Down
68 changes: 38 additions & 30 deletions respx/transports.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
from types import TracebackType
from typing import TYPE_CHECKING, Any, List, Optional, Type, Union

from httpcore import (
AsyncByteStream,
AsyncHTTPTransport,
SyncByteStream,
SyncHTTPTransport,
from typing import (
TYPE_CHECKING,
Any,
AsyncIterable,
Iterable,
List,
Optional,
Type,
Union,
)

from httpx import AsyncBaseTransport, BaseTransport

from .models import PassThrough, decode_request, encode_response
from .types import URL, AsyncResponse, Headers, RequestHandler, SyncResponse

if TYPE_CHECKING:
from .router import Router # pragma: nocover


class MockTransport(SyncHTTPTransport, AsyncHTTPTransport):
class MockTransport(BaseTransport, AsyncBaseTransport):
_handler: Optional[RequestHandler]
_router: Optional["Router"]

Expand All @@ -40,13 +44,13 @@ def __init__(
def handler(self) -> RequestHandler:
return self._handler or self._router.handler

def request(
def handle_request(
self,
method: bytes,
url: URL,
headers: Headers = None,
stream: SyncByteStream = None,
ext: dict = None,
headers: Headers,
stream: Iterable[bytes],
extensions: dict,
) -> SyncResponse:
raw_request = (method, url, headers, stream)
request = decode_request(raw_request)
Expand All @@ -60,13 +64,13 @@ def request(
raw_response = encode_response(response)
return raw_response # type: ignore

async def arequest(
async def handle_async_request(
self,
method: bytes,
url: URL,
headers: Headers = None,
stream: AsyncByteStream = None,
ext: dict = None,
headers: Headers,
stream: AsyncIterable[bytes],
extensions: dict,
) -> AsyncResponse:
raw_request = (method, url, headers, stream)
request = decode_request(raw_request)
Expand All @@ -93,25 +97,27 @@ async def __aexit__(self, *args: Any) -> None:
self.__exit__(*args)


class TryTransport(SyncHTTPTransport, AsyncHTTPTransport):
class TryTransport(BaseTransport, AsyncBaseTransport):
def __init__(
self, transports: List[Union[SyncHTTPTransport, AsyncHTTPTransport]]
self, transports: List[Union[BaseTransport, AsyncBaseTransport]]
) -> None:
self.transports = transports

def request(
def handle_request(
self,
method: bytes,
url: URL,
headers: Headers = None,
stream: SyncByteStream = None,
ext: dict = None,
headers: Headers,
stream: Iterable[bytes],
extensions: dict,
) -> SyncResponse:
error: Exception = None
for transport in self.transports:
try:
assert isinstance(transport, SyncHTTPTransport)
return transport.request(method, url, headers, stream, ext)
assert isinstance(transport, BaseTransport)
return transport.handle_request(
method, url, headers, stream, extensions
)
except PassThrough as pass_through:
stream = pass_through.request.stream # type: ignore
except AssertionError:
Expand All @@ -120,19 +126,21 @@ def request(
error = e
raise error

async def arequest(
async def handle_async_request(
self,
method: bytes,
url: URL,
headers: Headers = None,
stream: AsyncByteStream = None,
ext: dict = None,
headers: Headers,
stream: AsyncIterable[bytes],
extensions: dict,
) -> AsyncResponse:
error: Exception = None
for transport in self.transports:
try:
assert isinstance(transport, AsyncHTTPTransport)
return await transport.arequest(method, url, headers, stream, ext)
assert isinstance(transport, AsyncBaseTransport)
return await transport.handle_async_request(
method, url, headers, stream, extensions
)
except PassThrough as pass_through:
stream = pass_through.request.stream # type: ignore
except AssertionError:
Expand Down
5 changes: 2 additions & 3 deletions respx/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
)

import httpx
from httpcore import AsyncByteStream, SyncByteStream

URL = Tuple[
bytes, # scheme
Expand All @@ -35,13 +34,13 @@
SyncResponse = Tuple[
int, # status code
Headers,
SyncByteStream, # body
Iterable[bytes], # body
dict, # ext
]
AsyncResponse = Tuple[
int, # status code
Headers,
AsyncByteStream, # body
AsyncIterable[bytes], # body
dict, # ext
]
Response = Tuple[
Expand Down
5 changes: 4 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,8 @@
include_package_data=True,
zip_safe=False,
python_requires=">=3.6",
install_requires=["httpx>=0.15"],
# install_requires=["httpx>=0.15"],
install_requires=[
"httpx @ https://github.com/encode/httpx/archive/refs/heads/master.zip"
],
)
12 changes: 6 additions & 6 deletions tests/test_mock.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from contextlib import ExitStack as does_not_raise

import httpcore
import httpx
import pytest

Expand Down Expand Up @@ -456,8 +455,9 @@ async def test_assert_all_mocked(client, assert_all_mocked, raises):
assert respx_mock.calls.call_count == 0


@pytest.mark.xfail(strict=True)
@pytest.mark.asyncio
async def test_asgi():
async def test_asgi(): # pragma: nocover
from respx.mocks import HTTPCoreMocker

try:
Expand Down Expand Up @@ -590,8 +590,8 @@ class Hamspam(Mocker):


def test_sync_httpx_mocker():
class TestTransport(httpcore.SyncHTTPTransport):
def request(self, *args, **kwargs):
class TestTransport(httpx.BaseTransport):
def handle_request(self, *args, **kwargs):
raise RuntimeError("would pass through")

client = httpx.Client(transport=TestTransport())
Expand Down Expand Up @@ -619,8 +619,8 @@ def test(respx_mock):

@pytest.mark.asyncio
async def test_async_httpx_mocker():
class TestTransport(httpcore.AsyncHTTPTransport):
async def arequest(self, *args, **kwargs):
class TestTransport(httpx.AsyncBaseTransport):
async def handle_async_request(self, *args, **kwargs):
raise RuntimeError("would pass through")

client = httpx.AsyncClient(transport=TestTransport())
Expand Down
26 changes: 4 additions & 22 deletions tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,8 @@ async def backend_test(backend):
assert _request.url == url
assert _response.status_code == get_response.status_code == 202
assert _response.content == get_response.content == b"get"
assert {
_response.status_code,
tuple(_response.headers.raw),
_response.stream,
tuple(_response.ext.items()),
} == {
get_response.status_code,
tuple(get_response.headers.raw),
get_response.stream,
tuple(get_response.ext.items()),
}
assert tuple(_response.headers.raw) == tuple(get_response.headers.raw)
assert _response.extensions == get_response.extensions
assert id(_response) != id(get_response)

_request, _response = foobar2.calls[-1]
Expand All @@ -72,17 +63,8 @@ async def backend_test(backend):
assert _request.url == url
assert _response.status_code == del_response.status_code == 200
assert _response.content == del_response.content == b"del"
assert {
_response.status_code,
tuple(_response.headers.raw),
_response.stream,
tuple(_response.ext.items()),
} == {
del_response.status_code,
tuple(del_response.headers.raw),
del_response.stream,
tuple(del_response.ext.items()),
}
assert tuple(_response.headers.raw) == tuple(del_response.headers.raw)
assert _response.extensions == del_response.extensions
assert id(_response) != id(del_response)

assert respx.calls.call_count == 2
Expand Down

0 comments on commit 91abc18

Please sign in to comment.