From 6152052a36ba5f804b00dec5eac12d3e6fcfead4 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 8 Nov 2022 13:35:15 +0000 Subject: [PATCH 1/2] Add test cases --- tests/test_asgi.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/test_asgi.py b/tests/test_asgi.py index 60f55dfd6f..8e6c6f2c62 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -191,3 +191,29 @@ async def read_body(scope, receive, send): assert response.status_code == 200 assert disconnect + + +@pytest.mark.usefixtures("async_environment") +async def test_asgi_streaming(): + client = httpx.AsyncClient(app=hello_world) + async with client.stream("GET", "http://www.example.org/") as response: + assert response.status_code == 200 + text = "".join([chunk async for chunk in response.aiter_text()]) + assert text == "Hello, World!" + + +@pytest.mark.usefixtures("async_environment") +async def test_asgi_streaming_exc(): + client = httpx.AsyncClient(app=raise_exc) + with pytest.raises(ValueError): + async with client.stream("GET", "http://www.example.org/"): + pass # pragma: no cover + + +@pytest.mark.usefixtures("async_environment") +async def test_asgi_streaming_exc_after_response(): + client = httpx.AsyncClient(app=raise_exc_after_response) + with pytest.raises(ValueError): + async with client.stream("GET", "http://www.example.org/") as response: + async for _ in response.aiter_bytes(): + pass # pragma: no cover From 3b7717b246300f2bcae3a55b46320e60fbd021e2 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 8 Nov 2022 14:13:04 +0000 Subject: [PATCH 2/2] Push implementation into 'run_asgi' function --- httpx/_transports/asgi.py | 169 +++++++++++++++++++++----------------- 1 file changed, 95 insertions(+), 74 deletions(-) diff --git a/httpx/_transports/asgi.py b/httpx/_transports/asgi.py index 711a6f6ce7..478680caf9 100644 --- a/httpx/_transports/asgi.py +++ b/httpx/_transports/asgi.py @@ -82,82 +82,103 @@ async def handle_async_request( self, request: Request, ) -> Response: - assert isinstance(request.stream, AsyncByteStream) - - # ASGI scope. - scope = { - "type": "http", - "asgi": {"version": "3.0"}, - "http_version": "1.1", - "method": request.method, - "headers": [(k.lower(), v) for (k, v) in request.headers.raw], - "scheme": request.url.scheme, - "path": request.url.path, - "raw_path": request.url.raw_path, - "query_string": request.url.query, - "server": (request.url.host, request.url.port), - "client": self.client, - "root_path": self.root_path, - } - - # Request. - request_body_chunks = request.stream.__aiter__() - request_complete = False - - # Response. - status_code = None - response_headers = None - body_parts = [] - response_started = False - response_complete = create_event() - - # ASGI callables. - - async def receive() -> dict: - nonlocal request_complete - - if request_complete: - await response_complete.wait() - return {"type": "http.disconnect"} - - try: - body = await request_body_chunks.__anext__() - except StopAsyncIteration: - request_complete = True - return {"type": "http.request", "body": b"", "more_body": False} - return {"type": "http.request", "body": body, "more_body": True} - - async def send(message: dict) -> None: - nonlocal status_code, response_headers, response_started - - if message["type"] == "http.response.start": - assert not response_started - - status_code = message["status"] - response_headers = message.get("headers", []) - response_started = True - - elif message["type"] == "http.response.body": - assert not response_complete.is_set() - body = message.get("body", b"") - more_body = message.get("more_body", False) - - if body and request.method != "HEAD": - body_parts.append(body) - - if not more_body: - response_complete.set() + try: + import anyio # noqa + except ImportError: # pragma: no cover + raise ImportError("ASGITransport requires anyio. (Hint: pip install anyio)") + + return await run_asgi( + request, + app=self.app, + raise_app_exceptions=self.raise_app_exceptions, + root_path=self.root_path, + client=self.client, + ) + + +async def run_asgi( + request: Request, + app: typing.Callable, + raise_app_exceptions: bool, + root_path: str, + client: typing.Tuple[str, int], +) -> Response: + assert isinstance(request.stream, AsyncByteStream) + + # ASGI scope. + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": request.method, + "headers": [(k.lower(), v) for (k, v) in request.headers.raw], + "scheme": request.url.scheme, + "path": request.url.path, + "raw_path": request.url.raw_path, + "query_string": request.url.query, + "server": (request.url.host, request.url.port), + "client": client, + "root_path": root_path, + } + + # Request. + request_body_chunks = request.stream.__aiter__() + request_complete = False + + # Response. + status_code = None + response_headers = None + body_parts = [] + response_started = False + response_complete = create_event() + + # ASGI callables. + + async def receive() -> dict: + nonlocal request_complete + + if request_complete: + await response_complete.wait() + return {"type": "http.disconnect"} try: - await self.app(scope, receive, send) - except Exception: # noqa: PIE-786 - if self.raise_app_exceptions or not response_complete.is_set(): - raise + body = await request_body_chunks.__anext__() + except StopAsyncIteration: + request_complete = True + return {"type": "http.request", "body": b"", "more_body": False} + return {"type": "http.request", "body": body, "more_body": True} + + async def send(message: dict) -> None: + nonlocal status_code, response_headers, response_started + + if message["type"] == "http.response.start": + assert not response_started + + status_code = message["status"] + response_headers = message.get("headers", []) + response_started = True + + elif message["type"] == "http.response.body": + assert not response_complete.is_set() + body = message.get("body", b"") + more_body = message.get("more_body", False) + + if body and request.method != "HEAD": + body_parts.append(body) + + if not more_body: + response_complete.set() + + try: + await app(scope, receive, send) + except Exception: # noqa: PIE-786 + if raise_app_exceptions or not response_complete.is_set(): + raise - assert response_complete.is_set() - assert status_code is not None - assert response_headers is not None + assert response_complete.is_set() + assert status_code is not None + assert response_headers is not None - stream = ASGIResponseStream(body_parts) + stream = ASGIResponseStream(body_parts) - return Response(status_code, headers=response_headers, stream=stream) + return Response(status_code, headers=response_headers, stream=stream)