Skip to content

Commit

Permalink
Add support for streaming responses to ASGITransport
Browse files Browse the repository at this point in the history
  • Loading branch information
florimondmanca committed May 25, 2020
1 parent 66a4537 commit c70d345
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 17 deletions.
141 changes: 125 additions & 16 deletions httpx/_transports/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import httpcore
import sniffio

from .._content_streams import ByteStream
from .._content_streams import AsyncIteratorStream, ByteStream
from .._utils import warn_deprecated

if typing.TYPE_CHECKING: # pragma: no cover
Expand All @@ -25,6 +25,75 @@ def create_event() -> "Event":
return asyncio.Event()


async def create_background_task(async_fn: typing.Callable) -> typing.Callable:
if sniffio.current_async_library() == "trio":
import trio

nursery_manager = trio.open_nursery()
nursery = await nursery_manager.__aenter__()
nursery.start_soon(async_fn)

async def aclose(exc: Exception = None) -> None:
if exc is not None:
await nursery_manager.__aexit__(type(exc), exc, exc.__traceback__)
else:
await nursery_manager.__aexit__(None, None, None)

return aclose

else:
import asyncio

task = asyncio.create_task(async_fn())

async def aclose(exc: Exception = None) -> None:
if not task.done():
task.cancel()

return aclose


def create_channel(
capacity: int,
) -> typing.Tuple[
typing.Callable[[], typing.Awaitable[bytes]],
typing.Callable[[bytes], typing.Awaitable[None]],
]:
if sniffio.current_async_library() == "trio":
import trio

send_channel, receive_channel = trio.open_memory_channel[bytes](capacity)
return receive_channel.receive, send_channel.send

else:
import asyncio

queue: asyncio.Queue[bytes] = asyncio.Queue(capacity)
return queue.get, queue.put


async def run_until_first_complete(*async_fns: typing.Callable) -> None:
if sniffio.current_async_library() == "trio":
import trio

async with trio.open_nursery() as nursery:

async def run(async_fn: typing.Callable) -> None:
await async_fn()
nursery.cancel_scope.cancel()

for async_fn in async_fns:
nursery.start_soon(run, async_fn)

else:
import asyncio

coros = [async_fn() for async_fn in async_fns]
done, pending = await asyncio.wait(coros, return_when=asyncio.FIRST_COMPLETED)
for task in pending:
task.cancel()


class ASGITransport(httpcore.AsyncHTTPTransport):
"""
A custom AsyncTransport that handles sending requests directly to an ASGI app.
Expand Down Expand Up @@ -95,18 +164,20 @@ async def request(
}
status_code = None
response_headers = None
body_parts = []
consume_response_body_chunk, produce_response_body_chunk = create_channel(1)
request_complete = False
response_started = False
response_started = create_event()
response_complete = create_event()
app_crashed = create_event()
app_exception: typing.Optional[Exception] = None

headers = [] if headers is None else headers
stream = ByteStream(b"") if stream is None else stream

request_body_chunks = stream.__aiter__()

async def receive() -> dict:
nonlocal request_complete, response_complete
nonlocal request_complete

if request_complete:
await response_complete.wait()
Expand All @@ -120,38 +191,76 @@ async def receive() -> dict:
return {"type": "http.request", "body": body, "more_body": True}

async def send(message: dict) -> None:
nonlocal status_code, response_headers, body_parts
nonlocal response_started, response_complete
nonlocal status_code, response_headers

if message["type"] == "http.response.start":
assert not response_started
assert not response_started.is_set()

status_code = message["status"]
response_headers = message.get("headers", [])
response_started = True
response_started.set()

elif message["type"] == "http.response.body":
assert not response_complete.is_set()
body = message.get("body", b"")
more_body = message.get("more_body", False)

if body and method != b"HEAD":
body_parts.append(body)
await produce_response_body_chunk(body)

if not more_body:
response_complete.set()

try:
await self.app(scope, receive, send)
except Exception:
if self.raise_app_exceptions or not response_complete:
raise
async def run_app() -> None:
nonlocal app_exception
try:
await self.app(scope, receive, send)
except Exception as exc:
app_exception = exc
app_crashed.set()

aclose_app = await create_background_task(run_app)

await run_until_first_complete(app_crashed.wait, response_started.wait)

assert response_complete.is_set()
if app_crashed.is_set():
assert app_exception is not None
await aclose_app(app_exception)
if self.raise_app_exceptions or not response_started.is_set():
raise app_exception

assert response_started.is_set()
assert status_code is not None
assert response_headers is not None

stream = ByteStream(b"".join(body_parts))
async def aiter_response_body_chunks() -> typing.AsyncIterator[bytes]:
chunk = b""

async def consume_chunk() -> None:
nonlocal chunk
chunk = await consume_response_body_chunk()

while True:
await run_until_first_complete(
app_crashed.wait, consume_chunk, response_complete.wait
)

if app_crashed.is_set():
assert app_exception is not None
if self.raise_app_exceptions:
raise app_exception
else:
break

yield chunk

if response_complete.is_set():
break

async def aclose() -> None:
await aclose_app(app_exception)

stream = AsyncIteratorStream(aiter_response_body_chunks(), close_func=aclose)

return (b"HTTP/1.1", status_code, b"", response_headers, stream)

Expand Down
31 changes: 30 additions & 1 deletion tests/test_asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import httpx

from .concurrency import sleep


async def hello_world(scope, receive, send):
status = 200
Expand Down Expand Up @@ -35,7 +37,8 @@ async def raise_exc_after_response(scope, receive, send):
headers = [(b"content-type", "text/plain"), (b"content-length", str(len(output)))]

await send({"type": "http.response.start", "status": status, "headers": headers})
await send({"type": "http.response.body", "body": output})
await send({"type": "http.response.body", "body": output, "more_body": True})
await sleep(0.001) # Let the transport detect that the response has started.
raise ValueError()


Expand Down Expand Up @@ -99,3 +102,29 @@ async def read_body(scope, receive, send):
response = await client.post("http://www.example.org/", data=b"example")
assert response.status_code == 200
assert disconnect


@pytest.mark.asyncio
async def test_asgi_streaming():
client = httpx.AsyncClient(app=hello_world)
async with client.stream("GET", "http://www.example.org/") as response:
assert response.status_code == 200
text = "".join([chunk async for chunk in response.aiter_text()])
assert text == "Hello, World!"


@pytest.mark.asyncio
async def test_asgi_streaming_exc():
client = httpx.AsyncClient(app=raise_exc)
with pytest.raises(ValueError):
async with client.stream("GET", "http://www.example.org/"):
pass # pragma: no cover


@pytest.mark.asyncio
async def test_asgi_streaming_exc_after_response():
client = httpx.AsyncClient(app=raise_exc_after_response)
async with client.stream("GET", "http://www.example.org/") as response:
with pytest.raises(ValueError):
async for _ in response.aiter_bytes():
pass # pragma: no cover

0 comments on commit c70d345

Please sign in to comment.