Skip to content

Commit

Permalink
Add middleware per Router
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed Nov 28, 2023
1 parent 1fd4b20 commit de102eb
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
10 changes: 10 additions & 0 deletions starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,8 @@ def __init__(
# the generic to Lifespan[AppType] is the type of the top level application
# which the router cannot know statically, so we use typing.Any
lifespan: typing.Optional[Lifespan[typing.Any]] = None,
*,
middleware: typing.Optional[typing.Sequence[Middleware]] = None,
) -> None:
self.routes = [] if routes is None else list(routes)
self.redirect_slashes = redirect_slashes
Expand Down Expand Up @@ -650,6 +652,11 @@ def __init__(
else:
self.lifespan_context = lifespan

self.middleware_stack = self.app
if middleware:
for cls, options in reversed(middleware):
self.middleware_stack = cls(self.middleware_stack, **options)

async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "websocket":
websocket_close = WebSocketClose()
Expand Down Expand Up @@ -726,6 +733,9 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""
The main entry point to the Router class.
"""
await self.middleware_stack(scope, receive, send)

async def app(self, scope: Scope, receive: Receive, send: Send) -> None:
assert scope["type"] in ("http", "websocket", "lifespan")

if "router" not in scope:
Expand Down
20 changes: 20 additions & 0 deletions tests/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,26 @@ def test_router_add_websocket_route(client):
assert text == "Hello, test!"


def test_router_middleware(test_client_factory: typing.Callable[..., TestClient]):
class CustomMiddleware:
def __init__(self, app: ASGIApp) -> None:
self.app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send):
response = PlainTextResponse("OK")
await response(scope, receive, send)

app = Router(
routes=[Route("/", homepage)],
middleware=[Middleware(CustomMiddleware)],
)

client = test_client_factory(app)
response = client.get("/")
assert response.status_code == 200
assert response.text == "OK"


def http_endpoint(request):
url = request.url_for("http_endpoint")
return Response(f"URL: {url}", media_type="text/plain")
Expand Down

0 comments on commit de102eb

Please sign in to comment.