Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow inspecting response body #1695

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 58 additions & 21 deletions starlette/middleware/http.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import AsyncGenerator, Callable, Optional, Union
from contextlib import AsyncExitStack
from typing import AsyncGenerator, AsyncIterable, Callable, Optional, Union

from .._compat import aclosing
from ..datastructures import MutableHeaders
from ..requests import HTTPConnection
from ..responses import Response
from ..responses import Response, StreamingResponse
from ..types import ASGIApp, Message, Receive, Scope, Send

# This type hint not exposed, as it exists mostly for our own documentation purposes.
Expand Down Expand Up @@ -49,7 +50,9 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:

conn = HTTPConnection(scope)

async with aclosing(self._dispatch_func(conn)) as flow:
async with AsyncExitStack() as stack:
flow = await stack.enter_async_context(aclosing(self._dispatch_func(conn)))

# Kick the flow until the first `yield`.
# Might respond early before we call into the app.
maybe_early_response = await flow.__anext__()
Expand All @@ -67,30 +70,64 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:

response_started = False

async def wrapped_send(message: Message) -> None:
async def _wrapped_send() -> AsyncGenerator[None, Optional[Message]]:
nonlocal response_started

if message["type"] == "http.response.start":
response_started = True

response = Response(status_code=message["status"])
response.raw_headers.clear()

try:
await flow.asend(response)
except StopAsyncIteration:
pass
else:
raise RuntimeError("dispatch() should yield exactly once")
message = yield
assert message is not None
assert message["type"] == "http.response.start"
response_started = True
sent_start_response = False
start_message = message

async def ensure_start_response() -> None:
if not sent_start_response:
await send(start_message)

stack.push_async_callback(ensure_start_response)

message = yield
assert message is not None
assert message["type"] == "http.response.body"
headers = MutableHeaders(raw=start_message["headers"])
if message.get("more_body", False) is False:
response = Response(
status_code=start_message["status"],
headers=headers,
content=message.get("body", b""),
)
else:
async def _resp_stream() -> AsyncGenerator[bytes, None]:
raise NotImplementedError
yield

resp_stream = await stack.enter_async_context(aclosing(_resp_stream()))
response = StreamingResponse(
content=resp_stream,
status_code=start_message["status"],
headers=headers,
)
try:
await flow.asend(response)
except StopAsyncIteration:
pass
else:
raise RuntimeError("dispatch() should yield exactly once")
start_message["headers"] = response.headers.raw
await send(start_message)
sent_start_response = True
await send(message)

headers = MutableHeaders(raw=message["headers"])
headers.update(response.headers)
message["headers"] = headers.raw
while True:
message = yield
assert message is not None
await send(message)

await send(message)
wrapped_send = await stack.enter_async_context(aclosing(_wrapped_send()))
await wrapped_send.asend(None)

try:
await self.app(scope, receive, wrapped_send)
await self.app(scope, receive, wrapped_send.asend)
except Exception as exc:
if response_started:
raise
Expand Down
26 changes: 25 additions & 1 deletion tests/middleware/test_http.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import AsyncGenerator, Callable, Iterator, Optional
from typing import AsyncGenerator, AsyncIterable, Callable, Iterator, Optional

import pytest

Expand Down Expand Up @@ -231,3 +231,27 @@ def test_no_dispatch_given(
client = test_client_factory(app)
with pytest.raises(NotImplementedError, match="No dispatch implementation"):
client.get("/")


def test_response_body_not_streaming(
test_client_factory: Callable[[ASGIApp], TestClient]
) -> None:
async def index(request: Request) -> Response:
return Response(b"foo")

class CustomMiddleware(HTTPMiddleware):
async def dispatch(
self, conn: HTTPConnection
) -> AsyncGenerator[None, Response]:
resp = yield
assert resp.body == b"foo"

app = Starlette(
routes=[Route("/", index)],
middleware=[Middleware(CustomMiddleware)],
)

client = test_client_factory(app)
resp = client.get("/")
assert resp.status_code == 200
assert resp.content == b"foo"