Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for streaming responses to ASGITransport #998

Closed
wants to merge 9 commits into from
198 changes: 113 additions & 85 deletions httpx/_transports/asgi.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,12 @@
from typing import TYPE_CHECKING, Callable, List, Mapping, Optional, Tuple, Union
import sys
from typing import AsyncIterator, Callable, List, Mapping, Optional, Tuple

import httpcore
import sniffio

if TYPE_CHECKING: # pragma: no cover
import asyncio

import trio

Event = Union[asyncio.Event, trio.Event]


def create_event() -> "Event":
if sniffio.current_async_library() == "trio":
import trio

return trio.Event()
else:
import asyncio

return asyncio.Event()
try:
from contextlib import asynccontextmanager # type: ignore # Python 3.6.
except ImportError: # pragma: no cover # Python 3.6.
from async_generator import asynccontextmanager # type: ignore


class ASGITransport(httpcore.AsyncHTTPTransport):
Expand Down Expand Up @@ -62,6 +49,11 @@ def __init__(
root_path: str = "",
client: Tuple[str, int] = ("127.0.0.1", 123),
) -> None:
try:
import anyio # noqa
except ImportError: # pragma: no cover
raise ImportError("ASGITransport requires anyio. (Hint: pip install anyio)")

self.app = app
self.raise_app_exceptions = raise_app_exceptions
self.root_path = root_path
Expand All @@ -75,84 +67,120 @@ async def request(
stream: httpcore.AsyncByteStream = None,
timeout: Mapping[str, Optional[float]] = None,
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], httpcore.AsyncByteStream]:

headers = [] if headers is None else headers
stream = httpcore.PlainByteStream(content=b"") if stream is None else stream

# ASGI scope.
scheme, host, port, full_path = url
path, _, query = full_path.partition(b"?")
scope = {
"type": "http",
"asgi": {"version": "3.0"},
"http_version": "1.1",
"method": method.decode(),
"headers": headers,
"scheme": scheme.decode("ascii"),
"path": path.decode("ascii"),
"query_string": query,
"server": (host.decode("ascii"), port),
"client": self.client,
"root_path": self.root_path,
}

# Request.
request_body_chunks = stream.__aiter__()
request_complete = False

# Response.
status_code = None
response_headers = None
body_parts = []
response_started = False
response_complete = create_event()

# ASGI callables.

async def receive() -> dict:
nonlocal request_complete

if request_complete:
await response_complete.wait()
return {"type": "http.disconnect"}

try:
body = await request_body_chunks.__anext__()
except StopAsyncIteration:
request_complete = True
return {"type": "http.request", "body": b"", "more_body": False}
return {"type": "http.request", "body": body, "more_body": True}
app_context = run_asgi(
self.app,
method,
url,
headers,
stream,
client=self.client,
root_path=self.root_path,
)

async def send(message: dict) -> None:
nonlocal status_code, response_headers, response_started
status_code, response_headers, response_body = await app_context.__aenter__()

if message["type"] == "http.response.start":
assert not response_started
async def aclose() -> None:
await app_context.__aexit__(*sys.exc_info())

status_code = message["status"]
response_headers = message.get("headers", [])
response_started = True
stream = httpcore.AsyncIteratorByteStream(response_body, aclose_func=aclose)

elif message["type"] == "http.response.body":
assert not response_complete.is_set()
body = message.get("body", b"")
more_body = message.get("more_body", False)
return (b"HTTP/1.1", status_code, b"", response_headers, stream)

if body and method != b"HEAD":
body_parts.append(body)

if not more_body:
response_complete.set()
@asynccontextmanager
async def run_asgi(
app: Callable,
method: bytes,
url: Tuple[bytes, bytes, Optional[int], bytes],
headers: List[Tuple[bytes, bytes]],
stream: httpcore.AsyncByteStream,
*,
client: str,
root_path: str,
) -> AsyncIterator[Tuple[int, List[Tuple[bytes, bytes]], AsyncIterator[bytes]]]:
import anyio

# ASGI scope.
scheme, host, port, full_path = url
path, _, query = full_path.partition(b"?")
scope = {
"type": "http",
"asgi": {"version": "3.0"},
"http_version": "1.1",
"method": method.decode(),
"headers": headers,
"scheme": scheme.decode("ascii"),
"path": path.decode("ascii"),
"query_string": query,
"server": (host.decode("ascii"), port),
"client": client,
"root_path": root_path,
}

# Request.
request_body_chunks = stream.__aiter__()
request_complete = False

# Response.
status_code: Optional[int] = None
response_headers: Optional[List[Tuple[bytes, bytes]]] = None
response_body_queue = anyio.create_queue(1)
response_started = anyio.create_event()
response_complete = anyio.create_event()

async def receive() -> dict:
nonlocal request_complete

if request_complete:
await response_complete.wait()
return {"type": "http.disconnect"}

try:
await self.app(scope, receive, send)
except Exception:
if self.raise_app_exceptions or not response_complete.is_set():
raise
body = await request_body_chunks.__anext__()
except StopAsyncIteration:
request_complete = True
return {"type": "http.request", "body": b"", "more_body": False}
else:
return {"type": "http.request", "body": body, "more_body": True}

async def send(message: dict) -> None:
nonlocal status_code, response_headers

if message["type"] == "http.response.start":
assert not response_started.is_set()
status_code = message["status"]
response_headers = message.get("headers", [])
await response_started.set()

elif message["type"] == "http.response.body":
assert not response_complete.is_set()
body = message.get("body", b"")
more_body = message.get("more_body", False)

if body and method != b"HEAD":
await response_body_queue.put(body)

if not more_body:
await response_body_queue.put(None)
await response_complete.set()

async def body_iterator() -> AsyncIterator[bytes]:
while True:
chunk = await response_body_queue.get()
if chunk is None:
break
yield chunk

async with anyio.create_task_group() as task_group:
await task_group.spawn(app, scope, receive, send)

await response_started.wait()

assert response_complete.is_set()
assert status_code is not None
assert response_headers is not None

stream = httpcore.PlainByteStream(content=b"".join(body_parts))

return (b"HTTP/1.1", status_code, b"", response_headers, stream)
yield status_code, response_headers, body_iterator()
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
-e .[http2]

# Optional
async_generator; python_version < '3.7'
anyio
brotlipy==0.7.*

# Documentation
Expand Down
31 changes: 30 additions & 1 deletion tests/test_asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import httpx

from .concurrency import sleep


async def hello_world(scope, receive, send):
status = 200
Expand Down Expand Up @@ -37,7 +39,8 @@ async def raise_exc_after_response(scope, receive, send):
headers = [(b"content-type", "text/plain"), (b"content-length", str(len(output)))]

await send({"type": "http.response.start", "status": status, "headers": headers})
await send({"type": "http.response.body", "body": output})
await send({"type": "http.response.body", "body": output, "more_body": True})
await sleep(0.001) # Let the transport detect that the response has started.
raise ValueError()


Expand Down Expand Up @@ -109,3 +112,29 @@ async def read_body(scope, receive, send):
response = await client.post("http://www.example.org/", data=b"example")
assert response.status_code == 200
assert disconnect


@pytest.mark.usefixtures("async_environment")
async def test_asgi_streaming():
client = httpx.AsyncClient(app=hello_world)
async with client.stream("GET", "http://www.example.org/") as response:
assert response.status_code == 200
text = "".join([chunk async for chunk in response.aiter_text()])
assert text == "Hello, World!"


@pytest.mark.usefixtures("async_environment")
async def test_asgi_streaming_exc():
client = httpx.AsyncClient(app=raise_exc)
with pytest.raises(ValueError):
async with client.stream("GET", "http://www.example.org/"):
pass # pragma: no cover


@pytest.mark.usefixtures("async_environment")
async def test_asgi_streaming_exc_after_response():
client = httpx.AsyncClient(app=raise_exc_after_response)
with pytest.raises(ValueError):
async with client.stream("GET", "http://www.example.org/") as response:
async for _ in response.aiter_bytes():
pass # pragma: no cover