Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stream response body in ASGITransport #3059

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 81 additions & 14 deletions httpx/_transports/asgi.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import types
import typing

import sniffio
Expand All @@ -16,11 +17,9 @@

_Message = typing.Dict[str, typing.Any]
_Receive = typing.Callable[[], typing.Awaitable[_Message]]
_Send = typing.Callable[
[typing.Dict[str, typing.Any]], typing.Coroutine[None, None, None]
]
_Send = typing.Callable[[_Message], typing.Awaitable[None]]
_ASGIApp = typing.Callable[
[typing.Dict[str, typing.Any], _Receive, _Send], typing.Coroutine[None, None, None]
[typing.Dict[str, typing.Any], _Receive, _Send], typing.Awaitable[None]
]


Expand All @@ -35,12 +34,75 @@ def create_event() -> "Event":
return asyncio.Event()


class _AwaitableRunner:
def __init__(self, awaitable: typing.Awaitable[typing.Any]):
self._generator = awaitable.__await__()
self._started = False
self._next_item: typing.Any = None
self._finished = False

@types.coroutine
def __call__(
self, *, until: typing.Optional[typing.Callable[[], bool]] = None
) -> typing.Generator[typing.Any, typing.Any, typing.Any]:
while not self._finished and (until is None or not until()):
send_value, throw_value = None, None
if self._started:
try:
send_value = yield self._next_item
except BaseException as e:
throw_value = e

self._started = True
try:
if throw_value is not None:
self._next_item = self._generator.throw(throw_value)
else:
self._next_item = self._generator.send(send_value)
except StopIteration as e:
self._finished = True
return e.value
except BaseException:
self._generator.close()
self._finished = True
raise


class ASGIResponseStream(AsyncByteStream):
def __init__(self, body: typing.List[bytes]) -> None:
def __init__(
self,
body: typing.List[bytes],
raise_app_exceptions: bool,
response_complete: "Event",
app_runner: _AwaitableRunner,
) -> None:
self._body = body
self._raise_app_exceptions = raise_app_exceptions
self._response_complete = response_complete
self._app_runner = app_runner

async def __aiter__(self) -> typing.AsyncIterator[bytes]:
yield b"".join(self._body)
try:
while bool(self._body) or not self._response_complete.is_set():
if self._body:
yield b"".join(self._body)
self._body.clear()
await self._app_runner(
until=lambda: bool(self._body) or self._response_complete.is_set()
)
except Exception: # noqa: PIE786
if self._raise_app_exceptions:
raise
finally:
await self.aclose()

async def aclose(self) -> None:
self._response_complete.set()
try:
await self._app_runner()
except Exception: # noqa: PIE786
if self._raise_app_exceptions:
raise


class ASGITransport(AsyncBaseTransport):
Expand Down Expand Up @@ -123,7 +185,7 @@ async def handle_async_request(

# ASGI callables.

async def receive() -> typing.Dict[str, typing.Any]:
async def receive() -> _Message:
nonlocal request_complete

if request_complete:
Expand All @@ -137,7 +199,7 @@ async def receive() -> typing.Dict[str, typing.Any]:
return {"type": "http.request", "body": b"", "more_body": False}
return {"type": "http.request", "body": body, "more_body": True}

async def send(message: typing.Dict[str, typing.Any]) -> None:
async def send(message: _Message) -> None:
nonlocal status_code, response_headers, response_started

if message["type"] == "http.response.start":
Expand All @@ -147,8 +209,10 @@ async def send(message: typing.Dict[str, typing.Any]) -> None:
response_headers = message.get("headers", [])
response_started = True

elif message["type"] == "http.response.body":
assert not response_complete.is_set()
elif (
message["type"] == "http.response.body"
and not response_complete.is_set()
):
body = message.get("body", b"")
more_body = message.get("more_body", False)

Expand All @@ -158,9 +222,11 @@ async def send(message: typing.Dict[str, typing.Any]) -> None:
if not more_body:
response_complete.set()

app_runner = _AwaitableRunner(self.app(scope, receive, send))

try:
await self.app(scope, receive, send)
except Exception: # noqa: PIE-786
await app_runner(until=lambda: response_started)
except Exception: # noqa: PIE786
if self.raise_app_exceptions:
raise

Expand All @@ -170,10 +236,11 @@ async def send(message: typing.Dict[str, typing.Any]) -> None:
if response_headers is None:
response_headers = {}

assert response_complete.is_set()
assert status_code is not None
assert response_headers is not None

stream = ASGIResponseStream(body_parts)
stream = ASGIResponseStream(
body_parts, self.raise_app_exceptions, response_complete, app_runner
)

return Response(status_code, headers=response_headers, stream=stream)
114 changes: 114 additions & 0 deletions tests/test_asgi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json

import anyio
import pytest

import httpx
Expand Down Expand Up @@ -60,13 +61,24 @@ async def raise_exc(scope, receive, send):
raise RuntimeError()


async def raise_exc_after_response_start(scope, receive, send):
status = 200
output = b"Hello, World!"
headers = [(b"content-type", "text/plain"), (b"content-length", str(len(output)))]

await send({"type": "http.response.start", "status": status, "headers": headers})
await anyio.sleep(0)
raise RuntimeError()


async def raise_exc_after_response(scope, receive, send):
status = 200
output = b"Hello, World!"
headers = [(b"content-type", "text/plain"), (b"content-length", str(len(output)))]

await send({"type": "http.response.start", "status": status, "headers": headers})
await send({"type": "http.response.body", "body": output})
await anyio.sleep(0)
raise RuntimeError()


Expand Down Expand Up @@ -165,6 +177,14 @@ async def test_asgi_exc():
await client.get("http://www.example.org/")


@pytest.mark.anyio
async def test_asgi_exc_after_response_start():
transport = httpx.ASGITransport(app=raise_exc_after_response_start)
async with httpx.AsyncClient(transport=transport) as client:
with pytest.raises(RuntimeError):
await client.get("http://www.example.org/")


@pytest.mark.anyio
async def test_asgi_exc_after_response():
async with httpx.AsyncClient(app=raise_exc_after_response) as client:
Expand Down Expand Up @@ -213,3 +233,97 @@ async def test_asgi_exc_no_raise():
response = await client.get("http://www.example.org/")

assert response.status_code == 500


@pytest.mark.anyio
async def test_asgi_exc_no_raise_after_response_start():
transport = httpx.ASGITransport(
app=raise_exc_after_response_start, raise_app_exceptions=False
)
async with httpx.AsyncClient(transport=transport) as client:
response = await client.get("http://www.example.org/")

assert response.status_code == 200


@pytest.mark.anyio
async def test_asgi_exc_no_raise_after_response():
transport = httpx.ASGITransport(
app=raise_exc_after_response, raise_app_exceptions=False
)
async with httpx.AsyncClient(transport=transport) as client:
response = await client.get("http://www.example.org/")

assert response.status_code == 200


@pytest.mark.anyio
async def test_asgi_stream_returns_before_waiting_for_body():
start_response_body = anyio.Event()

async def send_response_body_after_event(scope, receive, send):
status = 200
headers = [(b"content-type", b"text/plain")]
await send(
{"type": "http.response.start", "status": status, "headers": headers}
)
await start_response_body.wait()
await send({"type": "http.response.body", "body": b"body", "more_body": False})

transport = httpx.ASGITransport(app=send_response_body_after_event)
async with httpx.AsyncClient(transport=transport) as client:
async with client.stream("GET", "http://www.example.org/") as response:
assert response.status_code == 200
start_response_body.set()
await response.aread()
assert response.text == "body"


@pytest.mark.anyio
async def test_asgi_stream_allows_iterative_streaming():
stream_events = [anyio.Event() for i in range(4)]

async def send_response_body_after_event(scope, receive, send):
status = 200
headers = [(b"content-type", b"text/plain")]
await send(
{"type": "http.response.start", "status": status, "headers": headers}
)
for e in stream_events:
await e.wait()
await send(
{
"type": "http.response.body",
"body": b"chunk",
"more_body": e is not stream_events[-1],
}
)

transport = httpx.ASGITransport(app=send_response_body_after_event)
async with httpx.AsyncClient(transport=transport) as client:
async with client.stream("GET", "http://www.example.org/") as response:
assert response.status_code == 200
iterator = response.aiter_raw()
for e in stream_events:
e.set()
assert await iterator.__anext__() == b"chunk"
with pytest.raises(StopAsyncIteration):
await iterator.__anext__()


@pytest.mark.anyio
async def test_asgi_can_be_canceled():
# This test exists to cover transmission of the cancellation exception through
# _AwaitableRunner
app_started = anyio.Event()

async def never_return(scope, receive, send):
app_started.set()
await anyio.sleep_forever()

transport = httpx.ASGITransport(app=never_return)
async with httpx.AsyncClient(transport=transport) as client:
async with anyio.create_task_group() as task_group:
task_group.start_soon(client.get, "http://www.example.org/")
await app_started.wait()
task_group.cancel_scope.cancel()