Add support for streaming responses to ASGITransport #998

Closed
wants to merge 9 commits
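This change makes ASGITransport surface the response body chunk by chunk as the ASGI app sends it, instead of collecting every chunk and buffering the whole body into a single ByteStream. A minimal usage sketch follows; the inline app is illustrative only and not part of the diff (the new tests in tests/test_asgi.py exercise the same behaviour):

import asyncio

import httpx


async def app(scope, receive, send):
    # Illustrative ASGI app: sends the body in two chunks.
    await send({"type": "http.response.start", "status": 200, "headers": []})
    await send({"type": "http.response.body", "body": b"Hello, ", "more_body": True})
    await send({"type": "http.response.body", "body": b"World!"})


async def main() -> None:
    client = httpx.AsyncClient(app=app)
    async with client.stream("GET", "http://www.example.org/") as response:
        # With this PR, chunks arrive as the app produces them rather than
        # after the transport has read the entire body.
        async for chunk in response.aiter_bytes():
            print(chunk)
    await client.aclose()


asyncio.run(main())

Without this PR the same code still runs, but the transport joins all body parts into one ByteStream before the response is returned, so nothing is yielded until the app has finished sending.
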
142 changes: 126 additions & 16 deletions httpx/_transports/asgi.py
@@ -4,7 +4,7 @@
import httpcore
import sniffio

from .._content_streams import ByteStream
from .._content_streams import AsyncIteratorStream, ByteStream
from .._utils import warn_deprecated

if typing.TYPE_CHECKING: # pragma: no cover
@@ -25,6 +25,76 @@ def create_event() -> "Event":
return asyncio.Event()


async def create_background_task(async_fn: typing.Callable) -> typing.Callable:
if sniffio.current_async_library() == "trio":
import trio

nursery_manager = trio.open_nursery()
nursery = await nursery_manager.__aenter__()
nursery.start_soon(async_fn)

async def aclose(exc: Exception = None) -> None:
if exc is not None:
await nursery_manager.__aexit__(type(exc), exc, exc.__traceback__)
else:
await nursery_manager.__aexit__(None, None, None)

return aclose

else:
import asyncio

loop = asyncio.get_event_loop()
task = loop.create_task(async_fn())

async def aclose(exc: Exception = None) -> None:
if not task.done():
task.cancel()

return aclose


def create_channel(
capacity: int,
) -> typing.Tuple[
typing.Callable[[], typing.Awaitable[bytes]],
typing.Callable[[bytes], typing.Awaitable[None]],
]:
if sniffio.current_async_library() == "trio":
import trio

send_channel, receive_channel = trio.open_memory_channel[bytes](capacity)
return receive_channel.receive, send_channel.send

else:
import asyncio

queue: asyncio.Queue[bytes] = asyncio.Queue(capacity)
return queue.get, queue.put


async def run_until_first_complete(*async_fns: typing.Callable) -> None:
if sniffio.current_async_library() == "trio":
import trio

async with trio.open_nursery() as nursery:

async def run(async_fn: typing.Callable) -> None:
await async_fn()
nursery.cancel_scope.cancel()

for async_fn in async_fns:
nursery.start_soon(run, async_fn)

else:
import asyncio

coros = [async_fn() for async_fn in async_fns]
done, pending = await asyncio.wait(coros, return_when=asyncio.FIRST_COMPLETED)
for task in pending:
task.cancel()


class ASGITransport(httpcore.AsyncHTTPTransport):
"""
A custom AsyncTransport that handles sending requests directly to an ASGI app.
@@ -95,18 +165,20 @@ async def request(
}
status_code = None
response_headers = None
body_parts = []
consume_response_body_chunk, produce_response_body_chunk = create_channel(1)
request_complete = False
response_started = False
response_started = create_event()
response_complete = create_event()
app_crashed = create_event()
app_exception: typing.Optional[Exception] = None

headers = [] if headers is None else headers
stream = ByteStream(b"") if stream is None else stream

request_body_chunks = stream.__aiter__()

async def receive() -> dict:
nonlocal request_complete, response_complete
nonlocal request_complete

if request_complete:
await response_complete.wait()
@@ -120,38 +192,76 @@ async def receive() -> dict:
return {"type": "http.request", "body": body, "more_body": True}

async def send(message: dict) -> None:
nonlocal status_code, response_headers, body_parts
nonlocal response_started, response_complete
nonlocal status_code, response_headers

if message["type"] == "http.response.start":
assert not response_started
assert not response_started.is_set()

status_code = message["status"]
response_headers = message.get("headers", [])
response_started = True
response_started.set()

elif message["type"] == "http.response.body":
assert not response_complete.is_set()
body = message.get("body", b"")
more_body = message.get("more_body", False)

if body and method != b"HEAD":
body_parts.append(body)
await produce_response_body_chunk(body)

if not more_body:
response_complete.set()

try:
await self.app(scope, receive, send)
except Exception:
if self.raise_app_exceptions or not response_complete:
raise
async def run_app() -> None:
nonlocal app_exception
try:
await self.app(scope, receive, send)
except Exception as exc:
app_exception = exc
app_crashed.set()

aclose_app = await create_background_task(run_app)

await run_until_first_complete(app_crashed.wait, response_started.wait)

assert response_complete.is_set()
if app_crashed.is_set():
assert app_exception is not None
await aclose_app(app_exception)
if self.raise_app_exceptions or not response_started.is_set():
raise app_exception

assert response_started.is_set()
assert status_code is not None
assert response_headers is not None

stream = ByteStream(b"".join(body_parts))
async def aiter_response_body_chunks() -> typing.AsyncIterator[bytes]:
chunk = b""

async def consume_chunk() -> None:
nonlocal chunk
chunk = await consume_response_body_chunk()

while True:
await run_until_first_complete(
app_crashed.wait, consume_chunk, response_complete.wait
)

if app_crashed.is_set():
assert app_exception is not None
if self.raise_app_exceptions:
raise app_exception
else:
break

yield chunk

if response_complete.is_set():
break

async def aclose() -> None:
await aclose_app(app_exception)

stream = AsyncIteratorStream(aiter_response_body_chunks(), close_func=aclose)

return (b"HTTP/1.1", status_code, b"", response_headers, stream)

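The rewritten request() runs the ASGI app as a background task, waits until either the app crashes or the response starts, and then exposes the body as an async iterator that races "next chunk" against "response complete" so a reader never blocks on an app that has already finished. A minimal asyncio-only sketch of that producer/consumer pattern (the names here are illustrative and not part of the diff):

import asyncio


async def main() -> None:
    chunks: asyncio.Queue = asyncio.Queue(1)  # bounded, like create_channel(1)
    response_complete = asyncio.Event()

    async def fake_app() -> None:
        # Stand-in for the ASGI app pushing body chunks through send().
        for chunk in (b"Hello, ", b"World!"):
            await chunks.put(chunk)
        response_complete.set()

    app_task = asyncio.ensure_future(fake_app())

    body = b""
    while True:
        get_task = asyncio.ensure_future(chunks.get())
        complete_task = asyncio.ensure_future(response_complete.wait())
        # Mirror run_until_first_complete(): whichever finishes first wins.
        done, pending = await asyncio.wait(
            {get_task, complete_task}, return_when=asyncio.FIRST_COMPLETED
        )
        for task in pending:
            task.cancel()
        if get_task in done:
            body += get_task.result()
        if complete_task in done and chunks.empty():
            break

    assert body == b"Hello, World!"
    await app_task


asyncio.run(main())

On trio the same roles are filled by a nursery and a bounded memory channel, which is what create_background_task(), create_channel(), and run_until_first_complete() dispatch to via sniffio.
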
31 changes: 30 additions & 1 deletion tests/test_asgi.py
@@ -2,6 +2,8 @@

import httpx

from .concurrency import sleep


async def hello_world(scope, receive, send):
status = 200
@@ -35,7 +37,8 @@ async def raise_exc_after_response(scope, receive, send):
headers = [(b"content-type", "text/plain"), (b"content-length", str(len(output)))]

await send({"type": "http.response.start", "status": status, "headers": headers})
await send({"type": "http.response.body", "body": output})
await send({"type": "http.response.body", "body": output, "more_body": True})
await sleep(0.001) # Let the transport detect that the response has started.
raise ValueError()


@@ -99,3 +102,29 @@ async def read_body(scope, receive, send):
response = await client.post("http://www.example.org/", data=b"example")
assert response.status_code == 200
assert disconnect


@pytest.mark.asyncio
async def test_asgi_streaming():
client = httpx.AsyncClient(app=hello_world)
async with client.stream("GET", "http://www.example.org/") as response:
assert response.status_code == 200
text = "".join([chunk async for chunk in response.aiter_text()])
assert text == "Hello, World!"


@pytest.mark.asyncio
async def test_asgi_streaming_exc():
client = httpx.AsyncClient(app=raise_exc)
with pytest.raises(ValueError):
async with client.stream("GET", "http://www.example.org/"):
pass # pragma: no cover


@pytest.mark.asyncio
async def test_asgi_streaming_exc_after_response():
client = httpx.AsyncClient(app=raise_exc_after_response)
async with client.stream("GET", "http://www.example.org/") as response:
with pytest.raises(ValueError):
async for _ in response.aiter_bytes():
pass # pragma: no cover