diff --git a/httpx/_transports/asgi.py b/httpx/_transports/asgi.py index a58e10a6d6..c894b0bdfa 100644 --- a/httpx/_transports/asgi.py +++ b/httpx/_transports/asgi.py @@ -1,25 +1,12 @@ -from typing import TYPE_CHECKING, Callable, List, Mapping, Optional, Tuple, Union +import sys +from typing import AsyncIterator, Callable, List, Mapping, Optional, Tuple import httpcore -import sniffio -if TYPE_CHECKING: # pragma: no cover - import asyncio - - import trio - - Event = Union[asyncio.Event, trio.Event] - - -def create_event() -> "Event": - if sniffio.current_async_library() == "trio": - import trio - - return trio.Event() - else: - import asyncio - - return asyncio.Event() +try: + from contextlib import asynccontextmanager # type: ignore # Python 3.6. +except ImportError: # pragma: no cover # Python 3.6. + from async_generator import asynccontextmanager # type: ignore class ASGITransport(httpcore.AsyncHTTPTransport): @@ -62,6 +49,11 @@ def __init__( root_path: str = "", client: Tuple[str, int] = ("127.0.0.1", 123), ) -> None: + try: + import anyio # noqa + except ImportError: # pragma: no cover + raise ImportError("ASGITransport requires anyio. (Hint: pip install anyio)") + self.app = app self.raise_app_exceptions = raise_app_exceptions self.root_path = root_path @@ -75,84 +67,120 @@ async def request( stream: httpcore.AsyncByteStream = None, timeout: Mapping[str, Optional[float]] = None, ) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], httpcore.AsyncByteStream]: + headers = [] if headers is None else headers stream = httpcore.PlainByteStream(content=b"") if stream is None else stream - # ASGI scope. - scheme, host, port, full_path = url - path, _, query = full_path.partition(b"?") - scope = { - "type": "http", - "asgi": {"version": "3.0"}, - "http_version": "1.1", - "method": method.decode(), - "headers": headers, - "scheme": scheme.decode("ascii"), - "path": path.decode("ascii"), - "query_string": query, - "server": (host.decode("ascii"), port), - "client": self.client, - "root_path": self.root_path, - } - - # Request. - request_body_chunks = stream.__aiter__() - request_complete = False - - # Response. - status_code = None - response_headers = None - body_parts = [] - response_started = False - response_complete = create_event() - - # ASGI callables. - - async def receive() -> dict: - nonlocal request_complete - - if request_complete: - await response_complete.wait() - return {"type": "http.disconnect"} - - try: - body = await request_body_chunks.__anext__() - except StopAsyncIteration: - request_complete = True - return {"type": "http.request", "body": b"", "more_body": False} - return {"type": "http.request", "body": body, "more_body": True} + app_context = run_asgi( + self.app, + method, + url, + headers, + stream, + client=self.client, + root_path=self.root_path, + ) - async def send(message: dict) -> None: - nonlocal status_code, response_headers, response_started + status_code, response_headers, response_body = await app_context.__aenter__() - if message["type"] == "http.response.start": - assert not response_started + async def aclose() -> None: + await app_context.__aexit__(*sys.exc_info()) - status_code = message["status"] - response_headers = message.get("headers", []) - response_started = True + stream = httpcore.AsyncIteratorByteStream(response_body, aclose_func=aclose) - elif message["type"] == "http.response.body": - assert not response_complete.is_set() - body = message.get("body", b"") - more_body = message.get("more_body", False) + return (b"HTTP/1.1", status_code, b"", response_headers, stream) - if body and method != b"HEAD": - body_parts.append(body) - if not more_body: - response_complete.set() +@asynccontextmanager +async def run_asgi( + app: Callable, + method: bytes, + url: Tuple[bytes, bytes, Optional[int], bytes], + headers: List[Tuple[bytes, bytes]], + stream: httpcore.AsyncByteStream, + *, + client: str, + root_path: str, +) -> AsyncIterator[Tuple[int, List[Tuple[bytes, bytes]], AsyncIterator[bytes]]]: + import anyio + + # ASGI scope. + scheme, host, port, full_path = url + path, _, query = full_path.partition(b"?") + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": method.decode(), + "headers": headers, + "scheme": scheme.decode("ascii"), + "path": path.decode("ascii"), + "query_string": query, + "server": (host.decode("ascii"), port), + "client": client, + "root_path": root_path, + } + + # Request. + request_body_chunks = stream.__aiter__() + request_complete = False + + # Response. + status_code: Optional[int] = None + response_headers: Optional[List[Tuple[bytes, bytes]]] = None + response_body_queue = anyio.create_queue(1) + response_started = anyio.create_event() + response_complete = anyio.create_event() + + async def receive() -> dict: + nonlocal request_complete + + if request_complete: + await response_complete.wait() + return {"type": "http.disconnect"} try: - await self.app(scope, receive, send) - except Exception: - if self.raise_app_exceptions or not response_complete.is_set(): - raise + body = await request_body_chunks.__anext__() + except StopAsyncIteration: + request_complete = True + return {"type": "http.request", "body": b"", "more_body": False} + else: + return {"type": "http.request", "body": body, "more_body": True} + + async def send(message: dict) -> None: + nonlocal status_code, response_headers + + if message["type"] == "http.response.start": + assert not response_started.is_set() + status_code = message["status"] + response_headers = message.get("headers", []) + await response_started.set() + + elif message["type"] == "http.response.body": + assert not response_complete.is_set() + body = message.get("body", b"") + more_body = message.get("more_body", False) + + if body and method != b"HEAD": + await response_body_queue.put(body) + + if not more_body: + await response_body_queue.put(None) + await response_complete.set() + + async def body_iterator() -> AsyncIterator[bytes]: + while True: + chunk = await response_body_queue.get() + if chunk is None: + break + yield chunk + + async with anyio.create_task_group() as task_group: + await task_group.spawn(app, scope, receive, send) + + await response_started.wait() - assert response_complete.is_set() assert status_code is not None assert response_headers is not None - stream = httpcore.PlainByteStream(content=b"".join(body_parts)) - - return (b"HTTP/1.1", status_code, b"", response_headers, stream) + yield status_code, response_headers, body_iterator() diff --git a/requirements.txt b/requirements.txt index a901dbeaa8..23d23bcdc1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,8 @@ -e .[http2] # Optional +async_generator; python_version < '3.7' +anyio brotlipy==0.7.* # Documentation diff --git a/tests/test_asgi.py b/tests/test_asgi.py index c59ef7c30e..35d08b95ff 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -4,6 +4,8 @@ import httpx +from .concurrency import sleep + async def hello_world(scope, receive, send): status = 200 @@ -37,7 +39,8 @@ async def raise_exc_after_response(scope, receive, send): headers = [(b"content-type", "text/plain"), (b"content-length", str(len(output)))] await send({"type": "http.response.start", "status": status, "headers": headers}) - await send({"type": "http.response.body", "body": output}) + await send({"type": "http.response.body", "body": output, "more_body": True}) + await sleep(0.001) # Let the transport detect that the response has started. raise ValueError() @@ -109,3 +112,29 @@ async def read_body(scope, receive, send): response = await client.post("http://www.example.org/", data=b"example") assert response.status_code == 200 assert disconnect + + +@pytest.mark.usefixtures("async_environment") +async def test_asgi_streaming(): + client = httpx.AsyncClient(app=hello_world) + async with client.stream("GET", "http://www.example.org/") as response: + assert response.status_code == 200 + text = "".join([chunk async for chunk in response.aiter_text()]) + assert text == "Hello, World!" + + +@pytest.mark.usefixtures("async_environment") +async def test_asgi_streaming_exc(): + client = httpx.AsyncClient(app=raise_exc) + with pytest.raises(ValueError): + async with client.stream("GET", "http://www.example.org/"): + pass # pragma: no cover + + +@pytest.mark.usefixtures("async_environment") +async def test_asgi_streaming_exc_after_response(): + client = httpx.AsyncClient(app=raise_exc_after_response) + with pytest.raises(ValueError): + async with client.stream("GET", "http://www.example.org/") as response: + async for _ in response.aiter_bytes(): + pass # pragma: no cover