Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't poll for disconnects in BaseHTTPMiddleware via StreamingResponse #2620

Merged
merged 12 commits into from
Sep 1, 2024
35 changes: 24 additions & 11 deletions starlette/middleware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
from anyio.abc import ObjectReceiveStream, ObjectSendStream

from starlette._utils import collapse_excgroups
from starlette.background import BackgroundTask
from starlette.requests import ClientDisconnect, Request
from starlette.responses import ContentStream, Response, StreamingResponse
from starlette.responses import AsyncContentStream, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send

RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
Expand Down Expand Up @@ -56,6 +55,7 @@ async def wrapped_receive(self) -> Message:
# at this point a disconnect is all that we should be receiving
# if we get something else, things went wrong somewhere
raise RuntimeError(f"Unexpected message received: {msg['type']}")
self._wrapped_rcv_disconnected = True
return msg

# wrapped_rcv state 3: not yet consumed
Expand Down Expand Up @@ -198,20 +198,33 @@ async def dispatch(
raise NotImplementedError() # pragma: no cover


class _StreamingResponse(StreamingResponse):
class _StreamingResponse(Response):
def __init__(
self,
content: ContentStream,
content: AsyncContentStream,
status_code: int = 200,
headers: typing.Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is the removal of background here intentional? We use background for both actual background tasks but also to detect when a streaming request was cancelled by the client (request.is_disconnected had some minor overhead).

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cutting a PR for this

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

info: typing.Mapping[str, typing.Any] | None = None,
) -> None:
self._info = info
super().__init__(content, status_code, headers, media_type, background)
self.info = info
self.body_iterator = content
self.status_code = status_code
self.media_type = media_type
self.init_headers(headers)

async def stream_response(self, send: Send) -> None:
if self._info:
await send({"type": "http.response.debug", "info": self._info})
return await super().stream_response(send)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.info is not None:
await send({"type": "http.response.debug", "info": self.info})
await send(
{
"type": "http.response.start",
"status": self.status_code,
"headers": self.raw_headers,
}
)

async for chunk in self.body_iterator:
await send({"type": "http.response.body", "body": chunk, "more_body": True})

await send({"type": "http.response.body", "body": b"", "more_body": False})
169 changes: 144 additions & 25 deletions tests/middleware/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Generator,
)

Expand All @@ -16,7 +17,7 @@
from starlette.background import BackgroundTask
from starlette.middleware import Middleware, _MiddlewareClass
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.requests import ClientDisconnect, Request
from starlette.responses import PlainTextResponse, Response, StreamingResponse
from starlette.routing import Route, WebSocketRoute
from starlette.testclient import TestClient
Expand Down Expand Up @@ -260,7 +261,6 @@ async def homepage(request: Request) -> PlainTextResponse:
@pytest.mark.anyio
async def test_run_background_tasks_even_if_client_disconnects() -> None:
# test for https://github.com/encode/starlette/issues/1438
request_body_sent = False
response_complete = anyio.Event()
background_task_run = anyio.Event()

Expand Down Expand Up @@ -293,13 +293,7 @@ async def passthrough(
}

async def receive() -> Message:
nonlocal request_body_sent
if not request_body_sent:
request_body_sent = True
return {"type": "http.request", "body": b"", "more_body": False}
# We simulate a client that disconnects immediately after receiving the response
await response_complete.wait()
return {"type": "http.disconnect"}
raise NotImplementedError("Should not be called!") # pragma: no cover

async def send(message: Message) -> None:
if message["type"] == "http.response.body":
Expand All @@ -313,7 +307,6 @@ async def send(message: Message) -> None:

@pytest.mark.anyio
async def test_do_not_block_on_background_tasks() -> None:
request_body_sent = False
response_complete = anyio.Event()
events: list[str | Message] = []

Expand Down Expand Up @@ -345,12 +338,7 @@ async def passthrough(
}

async def receive() -> Message:
nonlocal request_body_sent
if not request_body_sent:
request_body_sent = True
return {"type": "http.request", "body": b"", "more_body": False}
await response_complete.wait()
return {"type": "http.disconnect"}
raise NotImplementedError("Should not be called!") # pragma: no cover

async def send(message: Message) -> None:
if message["type"] == "http.response.body":
Expand Down Expand Up @@ -379,7 +367,6 @@ async def send(message: Message) -> None:
@pytest.mark.anyio
async def test_run_context_manager_exit_even_if_client_disconnects() -> None:
# test for https://github.com/encode/starlette/issues/1678#issuecomment-1172916042
request_body_sent = False
response_complete = anyio.Event()
context_manager_exited = anyio.Event()

Expand Down Expand Up @@ -424,13 +411,7 @@ async def passthrough(
}

async def receive() -> Message:
nonlocal request_body_sent
if not request_body_sent:
request_body_sent = True
return {"type": "http.request", "body": b"", "more_body": False}
# We simulate a client that disconnects immediately after receiving the response
await response_complete.wait()
return {"type": "http.disconnect"}
raise NotImplementedError("Should not be called!") # pragma: no cover

async def send(message: Message) -> None:
if message["type"] == "http.response.body":
Expand Down Expand Up @@ -778,7 +759,9 @@ async def rcv() -> AsyncGenerator[Message, None]:
yield {"type": "http.request", "body": b"1", "more_body": True}
yield {"type": "http.request", "body": b"2", "more_body": True}
yield {"type": "http.request", "body": b"3"}
await anyio.sleep(float("inf"))
raise AssertionError( # pragma: no cover
"Should not be called, no need to poll for disconnect"
)

sent: list[Message] = []

Expand Down Expand Up @@ -1033,3 +1016,139 @@ async def endpoint(request: Request) -> Response:
resp.raise_for_status()

assert bodies == [b"Hello, World!-foo"]


@pytest.mark.anyio
async def test_multiple_middlewares_stacked_client_disconnected() -> None:
class MyMiddleware(BaseHTTPMiddleware):
def __init__(self, app: ASGIApp, version: int, events: list[str]) -> None:
self.version = version
self.events = events
super().__init__(app)

async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
self.events.append(f"{self.version}:STARTED")
res = await call_next(request)
self.events.append(f"{self.version}:COMPLETED")
return res

async def sleepy(request: Request) -> Response:
try:
await request.body()
except ClientDisconnect:
pass
else: # pragma: no cover
raise AssertionError("Should have raised ClientDisconnect")
return Response(b"")

events: list[str] = []

app = Starlette(
routes=[Route("/", sleepy)],
middleware=[
Middleware(MyMiddleware, version=_ + 1, events=events) for _ in range(10)
],
)

scope = {
"type": "http",
"version": "3",
"method": "GET",
"path": "/",
}

async def receive() -> AsyncIterator[Message]:
yield {"type": "http.disconnect"}

sent: list[Message] = []

async def send(message: Message) -> None:
sent.append(message)

await app(scope, receive().__anext__, send)

assert events == [
"1:STARTED",
"2:STARTED",
"3:STARTED",
"4:STARTED",
"5:STARTED",
"6:STARTED",
"7:STARTED",
"8:STARTED",
"9:STARTED",
"10:STARTED",
"10:COMPLETED",
"9:COMPLETED",
"8:COMPLETED",
"7:COMPLETED",
"6:COMPLETED",
"5:COMPLETED",
"4:COMPLETED",
"3:COMPLETED",
"2:COMPLETED",
"1:COMPLETED",
]

assert sent == [
{
"type": "http.response.start",
"status": 200,
"headers": [(b"content-length", b"0")],
},
{"type": "http.response.body", "body": b"", "more_body": False},
]


@pytest.mark.anyio
@pytest.mark.parametrize("send_body", [True, False])
async def test_poll_for_disconnect_repeated(send_body: bool) -> None:
async def app_poll_disconnect(scope: Scope, receive: Receive, send: Send) -> None:
for _ in range(2):
msg = await receive()
while msg["type"] == "http.request":
msg = await receive()
assert msg["type"] == "http.disconnect"
await Response(b"good!")(scope, receive, send)

class MyMiddleware(BaseHTTPMiddleware):
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
return await call_next(request)

app = MyMiddleware(app_poll_disconnect)

scope = {
"type": "http",
"version": "3",
"method": "GET",
"path": "/",
}

async def receive() -> AsyncIterator[Message]:
# the key here is that we only ever send 1 htt.disconnect message
if send_body:
yield {"type": "http.request", "body": b"hello", "more_body": True}
yield {"type": "http.request", "body": b"", "more_body": False}
yield {"type": "http.disconnect"}
raise AssertionError("Should not be called, would hang") # pragma: no cover

sent: list[Message] = []

async def send(message: Message) -> None:
sent.append(message)

await app(scope, receive().__anext__, send)

assert sent == [
{
"type": "http.response.start",
"status": 200,
"headers": [(b"content-length", b"5")],
},
{"type": "http.response.body", "body": b"good!", "more_body": True},
{"type": "http.response.body", "body": b"", "more_body": False},
]