Skip to content

Commit

Permalink
Refactor, run tests on trio
Browse files Browse the repository at this point in the history
  • Loading branch information
florimondmanca committed Jun 13, 2020
1 parent 47384a6 commit e6d6c6c
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 100 deletions.
183 changes: 92 additions & 91 deletions httpx/_transports/asgi.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
import typing
from typing import Callable, Dict, List, Optional, Tuple
import contextlib
from typing import (
TYPE_CHECKING,
AsyncIterator,
Awaitable,
Callable,
Dict,
List,
Optional,
Tuple,
Union,
)

import httpcore
import sniffio

from .._content_streams import AsyncIteratorStream, ByteStream
from .._utils import warn_deprecated

if typing.TYPE_CHECKING: # pragma: no cover
if TYPE_CHECKING:
import asyncio
import trio

Event = typing.Union[asyncio.Event, trio.Event]
Event = Union[asyncio.Event, trio.Event]


def create_event() -> "Event":
Expand All @@ -25,19 +35,18 @@ def create_event() -> "Event":
return asyncio.Event()


async def create_background_task(async_fn: typing.Callable) -> typing.Callable:
async def create_background_task(
async_fn: Callable[[], Awaitable[None]]
) -> Callable[[], Awaitable[None]]:
if sniffio.current_async_library() == "trio":
import trio

nursery_manager = trio.open_nursery()
nursery = await nursery_manager.__aenter__()
nursery.start_soon(async_fn)

async def aclose(exc: Exception = None) -> None:
if exc is not None:
await nursery_manager.__aexit__(type(exc), exc, exc.__traceback__)
else:
await nursery_manager.__aexit__(None, None, None)
async def aclose() -> None:
await nursery_manager.__aexit__(None, None, None)

return aclose

Expand All @@ -47,52 +56,63 @@ async def aclose(exc: Exception = None) -> None:
loop = asyncio.get_event_loop()
task = loop.create_task(async_fn())

async def aclose(exc: Exception = None) -> None:
if not task.done():
task.cancel()
async def aclose() -> None:
task.cancel()
# Task must be awaited in all cases to avoid debug warnings.
with contextlib.suppress(asyncio.CancelledError):
await task

return aclose


def create_channel(
capacity: int,
) -> typing.Tuple[
typing.Callable[[], typing.Awaitable[bytes]],
typing.Callable[[bytes], typing.Awaitable[None]],
) -> Tuple[
Callable[[bytes], Awaitable[None]],
Callable[[], Awaitable[None]],
Callable[[], AsyncIterator[bytes]],
]:
"""
Create an in-memory channel to pass data chunks between tasks.
* `produce()`: send data through the channel, blocking if necessary.
* `consume()`: iterate over data in the channel.
* `aclose_produce()`: mark that no more data will be produced, causing
`consume()` to flush remaining data chunks then stop.
"""
if sniffio.current_async_library() == "trio":
import trio

send_channel, receive_channel = trio.open_memory_channel[bytes](capacity)
return receive_channel.receive, send_channel.send

async def consume() -> AsyncIterator[bytes]:
async for chunk in receive_channel:
yield chunk

return send_channel.send, send_channel.aclose, consume

else:
import asyncio

queue: asyncio.Queue[bytes] = asyncio.Queue(capacity)
return queue.get, queue.put
produce_closed = False

async def produce(chunk: bytes) -> None:
assert not produce_closed
await queue.put(chunk)

async def run_until_first_complete(*async_fns: typing.Callable) -> None:
if sniffio.current_async_library() == "trio":
import trio

async with trio.open_nursery() as nursery:
async def aclose_produce() -> None:
nonlocal produce_closed
await queue.put(b"") # Make sure (*) doesn't block forever.
produce_closed = True

async def run(async_fn: typing.Callable) -> None:
await async_fn()
nursery.cancel_scope.cancel()

for async_fn in async_fns:
nursery.start_soon(run, async_fn)
async def consume() -> AsyncIterator[bytes]:
while True:
if produce_closed and queue.empty():
break
yield await queue.get() # (*)

else:
import asyncio

coros = [async_fn() for async_fn in async_fns]
done, pending = await asyncio.wait(coros, return_when=asyncio.FIRST_COMPLETED)
for task in pending:
task.cancel()
return produce, aclose_produce, consume


class ASGITransport(httpcore.AsyncHTTPTransport):
Expand Down Expand Up @@ -148,6 +168,9 @@ async def request(
stream: httpcore.AsyncByteStream = None,
timeout: Dict[str, Optional[float]] = None,
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], httpcore.AsyncByteStream]:
headers = [] if headers is None else headers
stream = ByteStream(b"") if stream is None else stream

scheme, host, port, full_path = url
path, _, query = full_path.partition(b"?")
scope = {
Expand All @@ -163,19 +186,20 @@ async def request(
"client": self.client,
"root_path": self.root_path,
}
status_code = None
response_headers = None
consume_response_body_chunk, produce_response_body_chunk = create_channel(1)

# Request.
request_body = stream.__aiter__()
request_complete = False
response_started = create_event()
response_complete = create_event()
app_crashed = create_event()
app_exception: typing.Optional[Exception] = None

headers = [] if headers is None else headers
stream = ByteStream(b"") if stream is None else stream
# Response.
response_headers: Optional[List[Tuple[bytes, bytes]]] = None
status_code: Optional[int] = None
response_started_or_app_crashed = create_event()
produce_body, aclose_body, consume_body = create_channel(1)
response_complete = create_event()

request_body_chunks = stream.__aiter__()
# Error handling.
app_exception: Optional[Exception] = None

async def receive() -> dict:
nonlocal request_complete
Expand All @@ -185,31 +209,31 @@ async def receive() -> dict:
return {"type": "http.disconnect"}

try:
body = await request_body_chunks.__anext__()
body = await request_body.__anext__()
except StopAsyncIteration:
request_complete = True
return {"type": "http.request", "body": b"", "more_body": False}
return {"type": "http.request", "body": body, "more_body": True}

async def send(message: dict) -> None:
nonlocal status_code, response_headers

if message["type"] == "http.response.start":
assert not response_started.is_set()

# App is sending the response headers.
assert not response_started_or_app_crashed.is_set()
status_code = message["status"]
response_headers = message.get("headers", [])
response_started.set()
response_started_or_app_crashed.set()

elif message["type"] == "http.response.body":
assert not response_complete.is_set()
# App is sending a chunk of the response body.
body = message.get("body", b"")
more_body = message.get("more_body", False)

if body and method != b"HEAD":
await produce_response_body_chunk(body)
await produce_body(body)

if not more_body:
await aclose_body()
response_complete.set()

async def run_app() -> None:
Expand All @@ -218,51 +242,28 @@ async def run_app() -> None:
await self.app(scope, receive, send)
except Exception as exc:
app_exception = exc
app_crashed.set()
response_started_or_app_crashed.set()
await aclose_body() # Stop response body consumer once flushed (*).

aclose_app = await create_background_task(run_app)

await run_until_first_complete(app_crashed.wait, response_started.wait)
async def aiter_response_body() -> AsyncIterator[bytes]:
async for chunk in consume_body(): # (*)
yield chunk

if app_crashed.is_set():
assert app_exception is not None
await aclose_app(app_exception)
if self.raise_app_exceptions or not response_started.is_set():
if app_exception is not None and self.raise_app_exceptions:
raise app_exception

assert response_started.is_set()
assert status_code is not None
assert response_headers is not None

async def aiter_response_body_chunks() -> typing.AsyncIterator[bytes]:
chunk = b""
aclose = await create_background_task(run_app)

async def consume_chunk() -> None:
nonlocal chunk
chunk = await consume_response_body_chunk()
await response_started_or_app_crashed.wait()

while True:
await run_until_first_complete(
app_crashed.wait, consume_chunk, response_complete.wait
)

if app_crashed.is_set():
assert app_exception is not None
if self.raise_app_exceptions:
raise app_exception
else:
break

yield chunk

if response_complete.is_set():
break

async def aclose() -> None:
await aclose_app(app_exception)

stream = AsyncIteratorStream(aiter_response_body_chunks(), close_func=aclose)
if app_exception is not None:
await aclose()
if self.raise_app_exceptions:
raise app_exception

assert status_code is not None
assert response_headers is not None
stream = AsyncIteratorStream(aiter_response_body(), close_func=aclose)
return (b"HTTP/1.1", status_code, b"", response_headers, stream)


Expand Down
19 changes: 10 additions & 9 deletions tests/test_asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,44 +44,45 @@ async def raise_exc_after_response(scope, receive, send):
raise ValueError()


@pytest.mark.asyncio
@pytest.mark.usefixtures("async_environment")
async def test_asgi():
client = httpx.AsyncClient(app=hello_world)
response = await client.get("http://www.example.org/")
assert response.status_code == 200
assert response.text == "Hello, World!"


@pytest.mark.asyncio
@pytest.mark.usefixtures("async_environment")
async def test_asgi_upload():
client = httpx.AsyncClient(app=echo_body)
response = await client.post("http://www.example.org/", data=b"example")
assert response.status_code == 200
assert response.text == "example"


@pytest.mark.asyncio
@pytest.mark.usefixtures("async_environment")
async def test_asgi_exc():
client = httpx.AsyncClient(app=raise_exc)
with pytest.raises(ValueError):
await client.get("http://www.example.org/")


@pytest.mark.asyncio
@pytest.mark.usefixtures("async_environment")
async def test_asgi_http_error():
client = httpx.AsyncClient(app=partial(raise_exc, exc=httpx.HTTPError))
with pytest.raises(httpx.HTTPError):
await client.get("http://www.example.org/")


@pytest.mark.asyncio
@pytest.mark.usefixtures("async_environment")
async def test_asgi_exc_after_response():
client = httpx.AsyncClient(app=raise_exc_after_response)
with pytest.raises(ValueError):
await client.get("http://www.example.org/")


async def test_asgi_disconnect_after_response_complete(async_environment):
@pytest.mark.usefixtures("async_environment")
async def test_asgi_disconnect_after_response_complete():
disconnect = False

async def read_body(scope, receive, send):
Expand Down Expand Up @@ -113,7 +114,7 @@ async def read_body(scope, receive, send):
assert disconnect


@pytest.mark.asyncio
@pytest.mark.usefixtures("async_environment")
async def test_asgi_streaming():
client = httpx.AsyncClient(app=hello_world)
async with client.stream("GET", "http://www.example.org/") as response:
Expand All @@ -122,15 +123,15 @@ async def test_asgi_streaming():
assert text == "Hello, World!"


@pytest.mark.asyncio
@pytest.mark.usefixtures("async_environment")
async def test_asgi_streaming_exc():
client = httpx.AsyncClient(app=raise_exc)
with pytest.raises(ValueError):
async with client.stream("GET", "http://www.example.org/"):
pass # pragma: no cover


@pytest.mark.asyncio
@pytest.mark.usefixtures("async_environment")
async def test_asgi_streaming_exc_after_response():
client = httpx.AsyncClient(app=raise_exc_after_response)
async with client.stream("GET", "http://www.example.org/") as response:
Expand Down

0 comments on commit e6d6c6c

Please sign in to comment.