Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Experiment a high-level HTTPMiddleware #1691

Closed
wants to merge 13 commits into from
22 changes: 22 additions & 0 deletions starlette/_compat.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
import hashlib
florimondmanca marked this conversation as resolved.
Show resolved Hide resolved
import sys
from typing import Any, AsyncContextManager

__all__ = [
"md5_hexdigest",
"aclosing",
]

# Compat wrapper to always include the `usedforsecurity=...` parameter,
# which is only added from Python 3.9 onwards.
Expand Down Expand Up @@ -27,3 +34,18 @@ def md5_hexdigest(

def md5_hexdigest(data: bytes, *, usedforsecurity: bool = True) -> str:
return hashlib.md5(data).hexdigest()


if sys.version_info >= (3, 10): # pragma: no cover
from contextlib import aclosing
else: # pragma: no cover

class aclosing(AsyncContextManager):
def __init__(self, thing: Any) -> None:
self.thing = thing

async def __aenter__(self) -> Any:
return self.thing

async def __aexit__(self, *exc_info: Any) -> None:
await self.thing.aclose()
170 changes: 170 additions & 0 deletions starlette/middleware/http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
from typing import Any, AsyncGenerator, Callable, Optional, Union

from .._compat import aclosing
from ..datastructures import Headers
from ..requests import HTTPConnection
from ..responses import Response
from ..types import ASGIApp, Message, Receive, Scope, Send

# This type hint not exposed, as it exists mostly for our own documentation purposes.
# End users should use one of these type hints explicitly when overriding '.dispatch()'.
_DispatchFlow = Union[
# Default case:
# response = yield
AsyncGenerator[None, Response],
# Early response and/or error handling:
# if condition:
# yield Response(...)
# return
# try:
# response = yield None
# except Exception:
# yield Response(...)
# else:
# ...
AsyncGenerator[Optional[Response], Response],
]


class HTTPMiddleware:
def __init__(
self,
app: ASGIApp,
dispatch: Optional[Callable[[HTTPConnection], _DispatchFlow]] = None,
) -> None:
self.app = app
self._dispatch_func = self.dispatch if dispatch is None else dispatch

def dispatch(self, __conn: HTTPConnection) -> _DispatchFlow:
raise NotImplementedError(
"No dispatch implementation was given. "
"Either pass 'dispatch=...' to HTTPMiddleware, "
"or subclass HTTPMiddleware and override the 'dispatch()' method."
)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return

conn = HTTPConnection(scope)

async with aclosing(self._dispatch_func(conn)) as flow:
# Kick the flow until the first `yield`.
# Might respond early before we call into the app.
maybe_early_response = await flow.__anext__()

if maybe_early_response is not None:
try:
await flow.__anext__()
except StopAsyncIteration:
pass
else:
raise RuntimeError("dispatch() should yield exactly once")

await maybe_early_response(scope, receive, send)
return

response_started = False

async def wrapped_send(message: Message) -> None:
nonlocal response_started

if message["type"] == "http.response.start":
response_started = True

headers = Headers(raw=message["headers"])
response = _StubResponse(
status_code=message["status"],
media_type=headers.get("content-type"),
)
response.raw_headers = headers.raw

try:
await flow.asend(response)
except StopAsyncIteration:
pass
else:
raise RuntimeError("dispatch() should yield exactly once")

message["headers"] = response.raw_headers

await send(message)

try:
await self.app(scope, receive, wrapped_send)
except Exception as exc:
if response_started:
raise

try:
response = await flow.athrow(exc)
except StopAsyncIteration:
response = None
except Exception:
# Exception was not handled, or they raised another one.
raise

if response is None:
raise RuntimeError(
f"dispatch() handled exception {exc!r}, "
"but no response was returned"
)

await response(scope, receive, send)
return


# This customized stub response helps prevent users from shooting themselves
# in the foot, doing things that don't actually have any effect.


class _StubResponse(Response):
def __init__(self, status_code: int, media_type: Optional[str] = None) -> None:
self._status_code = status_code
self._media_type = media_type
self.raw_headers = []

@property # type: ignore
def status_code(self) -> int: # type: ignore
return self._status_code

@status_code.setter
def status_code(self, value: Any) -> None:
raise RuntimeError(
"Setting .status_code in HTTPMiddleware is not supported. "
"If you're writing middleware that requires modifying the response "
"status code or sending another response altogether, please consider "
"writing pure ASGI middleware. "
"See: https://starlette.io/middleware/#pure-asgi-middleware"
)

@property # type: ignore
def media_type(self) -> Optional[str]: # type: ignore
return self._media_type

@media_type.setter
def media_type(self, value: Any) -> None:
raise RuntimeError(
"Setting .media_type in HTTPMiddleware is not supported, as it has "
"no effect. If you do need to tweak the response "
"content type, consider: response.headers['Content-Type'] = ..."
)

@property # type: ignore
def body(self) -> bytes: # type: ignore
raise RuntimeError(
"Accessing the response body in HTTPMiddleware is not supported. "
"If you're writing middleware that requires peeking into the response "
"body, please consider writing pure ASGI middleware and wrapping send(). "
"See: https://starlette.io/middleware/#pure-asgi-middleware"
)

@body.setter
def body(self, body: bytes) -> None:
raise RuntimeError(
"Setting the response body in HTTPMiddleware is not supported."
"If you're writing middleware that requires modifying the response "
"body, please consider writing pure ASGI middleware and wrapping send(). "
"See: https://starlette.io/middleware/#pure-asgi-middleware"
)
Loading