From 672f4597c7c81954c7b42b15daad4a6ab1c34ed6 Mon Sep 17 00:00:00 2001 From: Jean Hominal Date: Tue, 16 Jan 2024 01:15:55 +0100 Subject: [PATCH] Stream response body in ASGITransport Fixes #2186 --- httpx/_transports/asgi.py | 85 +++++++++++++++++++++++++--- tests/test_asgi.py | 114 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 191 insertions(+), 8 deletions(-) diff --git a/httpx/_transports/asgi.py b/httpx/_transports/asgi.py index ed8b3f1c38..a6d3125e0d 100644 --- a/httpx/_transports/asgi.py +++ b/httpx/_transports/asgi.py @@ -1,3 +1,4 @@ +import types import typing import sniffio @@ -33,12 +34,75 @@ def create_event() -> "Event": return asyncio.Event() +class _AwaitableRunner: + def __init__(self, awaitable: typing.Awaitable[typing.Any]): + self._generator = awaitable.__await__() + self._started = False + self._next_item: typing.Any = None + self._finished = False + + @types.coroutine + def __call__( + self, *, until: typing.Optional[typing.Callable[[], bool]] = None + ) -> typing.Generator[typing.Any, typing.Any, typing.Any]: + while not self._finished and (until is None or not until()): + send_value, throw_value = None, None + if self._started: + try: + send_value = yield self._next_item + except BaseException as e: + throw_value = e + + self._started = True + try: + if throw_value is not None: + self._next_item = self._generator.throw(throw_value) + else: + self._next_item = self._generator.send(send_value) + except StopIteration as e: + self._finished = True + return e.value + except BaseException: + self._generator.close() + self._finished = True + raise + + class ASGIResponseStream(AsyncByteStream): - def __init__(self, body: typing.List[bytes]) -> None: + def __init__( + self, + body: typing.List[bytes], + raise_app_exceptions: bool, + response_complete: "Event", + app_runner: _AwaitableRunner, + ) -> None: self._body = body + self._raise_app_exceptions = raise_app_exceptions + self._response_complete = response_complete + self._app_runner = app_runner async def __aiter__(self) -> typing.AsyncIterator[bytes]: - yield b"".join(self._body) + try: + while bool(self._body) or not self._response_complete.is_set(): + if self._body: + yield b"".join(self._body) + self._body.clear() + await self._app_runner( + until=lambda: bool(self._body) or self._response_complete.is_set() + ) + except Exception: # noqa: PIE786 + if self._raise_app_exceptions: + raise + finally: + await self.aclose() + + async def aclose(self) -> None: + self._response_complete.set() + try: + await self._app_runner() + except Exception: # noqa: PIE786 + if self._raise_app_exceptions: + raise class ASGITransport(AsyncBaseTransport): @@ -145,8 +209,10 @@ async def send(message: _Message) -> None: response_headers = message.get("headers", []) response_started = True - elif message["type"] == "http.response.body": - assert not response_complete.is_set() + elif ( + message["type"] == "http.response.body" + and not response_complete.is_set() + ): body = message.get("body", b"") more_body = message.get("more_body", False) @@ -156,9 +222,11 @@ async def send(message: _Message) -> None: if not more_body: response_complete.set() + app_runner = _AwaitableRunner(self.app(scope, receive, send)) + try: - await self.app(scope, receive, send) - except Exception: # noqa: PIE-786 + await app_runner(until=lambda: response_started) + except Exception: # noqa: PIE786 if self.raise_app_exceptions: raise @@ -168,10 +236,11 @@ async def send(message: _Message) -> None: if response_headers is None: response_headers = {} - assert response_complete.is_set() assert status_code is not None assert response_headers is not None - stream = ASGIResponseStream(body_parts) + stream = ASGIResponseStream( + body_parts, self.raise_app_exceptions, response_complete, app_runner + ) return Response(status_code, headers=response_headers, stream=stream) diff --git a/tests/test_asgi.py b/tests/test_asgi.py index 2971506097..21a2d7cf30 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -1,5 +1,6 @@ import json +import anyio import pytest import httpx @@ -60,6 +61,16 @@ async def raise_exc(scope, receive, send): raise RuntimeError() +async def raise_exc_after_response_start(scope, receive, send): + status = 200 + output = b"Hello, World!" + headers = [(b"content-type", "text/plain"), (b"content-length", str(len(output)))] + + await send({"type": "http.response.start", "status": status, "headers": headers}) + await anyio.sleep(0) + raise RuntimeError() + + async def raise_exc_after_response(scope, receive, send): status = 200 output = b"Hello, World!" @@ -67,6 +78,7 @@ async def raise_exc_after_response(scope, receive, send): await send({"type": "http.response.start", "status": status, "headers": headers}) await send({"type": "http.response.body", "body": output}) + await anyio.sleep(0) raise RuntimeError() @@ -165,6 +177,14 @@ async def test_asgi_exc(): await client.get("http://www.example.org/") +@pytest.mark.anyio +async def test_asgi_exc_after_response_start(): + transport = httpx.ASGITransport(app=raise_exc_after_response_start) + async with httpx.AsyncClient(transport=transport) as client: + with pytest.raises(RuntimeError): + await client.get("http://www.example.org/") + + @pytest.mark.anyio async def test_asgi_exc_after_response(): async with httpx.AsyncClient(app=raise_exc_after_response) as client: @@ -213,3 +233,97 @@ async def test_asgi_exc_no_raise(): response = await client.get("http://www.example.org/") assert response.status_code == 500 + + +@pytest.mark.anyio +async def test_asgi_exc_no_raise_after_response_start(): + transport = httpx.ASGITransport( + app=raise_exc_after_response_start, raise_app_exceptions=False + ) + async with httpx.AsyncClient(transport=transport) as client: + response = await client.get("http://www.example.org/") + + assert response.status_code == 200 + + +@pytest.mark.anyio +async def test_asgi_exc_no_raise_after_response(): + transport = httpx.ASGITransport( + app=raise_exc_after_response, raise_app_exceptions=False + ) + async with httpx.AsyncClient(transport=transport) as client: + response = await client.get("http://www.example.org/") + + assert response.status_code == 200 + + +@pytest.mark.anyio +async def test_asgi_stream_returns_before_waiting_for_body(): + start_response_body = anyio.Event() + + async def send_response_body_after_event(scope, receive, send): + status = 200 + headers = [(b"content-type", b"text/plain")] + await send( + {"type": "http.response.start", "status": status, "headers": headers} + ) + await start_response_body.wait() + await send({"type": "http.response.body", "body": b"body", "more_body": False}) + + transport = httpx.ASGITransport(app=send_response_body_after_event) + async with httpx.AsyncClient(transport=transport) as client: + async with client.stream("GET", "http://www.example.org/") as response: + assert response.status_code == 200 + start_response_body.set() + await response.aread() + assert response.text == "body" + + +@pytest.mark.anyio +async def test_asgi_stream_allows_iterative_streaming(): + stream_events = [anyio.Event() for i in range(4)] + + async def send_response_body_after_event(scope, receive, send): + status = 200 + headers = [(b"content-type", b"text/plain")] + await send( + {"type": "http.response.start", "status": status, "headers": headers} + ) + for e in stream_events: + await e.wait() + await send( + { + "type": "http.response.body", + "body": b"chunk", + "more_body": e is not stream_events[-1], + } + ) + + transport = httpx.ASGITransport(app=send_response_body_after_event) + async with httpx.AsyncClient(transport=transport) as client: + async with client.stream("GET", "http://www.example.org/") as response: + assert response.status_code == 200 + iterator = response.aiter_raw() + for e in stream_events: + e.set() + assert await iterator.__anext__() == b"chunk" + with pytest.raises(StopAsyncIteration): + await iterator.__anext__() + + +@pytest.mark.anyio +async def test_asgi_can_be_canceled(): + # This test exists to cover transmission of the cancellation exception through + # _AwaitableRunner + app_started = anyio.Event() + + async def never_return(scope, receive, send): + app_started.set() + await anyio.sleep_forever() + + transport = httpx.ASGITransport(app=never_return) + async with httpx.AsyncClient(transport=transport) as client: + async with anyio.create_task_group() as task_group: + task_group.start_soon(client.get, "http://www.example.org/") + await app_started.wait() + task_group.cancel_scope.cancel()