Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add *args to Middleware and improve its type hints #2381

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions starlette/applications.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import sys
import typing
import warnings

Expand All @@ -14,7 +15,14 @@
from starlette.types import ASGIApp, ExceptionHandler, Lifespan, Receive, Scope, Send
from starlette.websockets import WebSocket

if sys.version_info >= (3, 10): # pragma: no cover
from typing import Concatenate, ParamSpec
else: # pragma: no cover
from typing_extensions import Concatenate, ParamSpec


AppType = typing.TypeVar("AppType", bound="Starlette")
P = ParamSpec("P")


class Starlette:
Expand Down Expand Up @@ -124,10 +132,15 @@ def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None:
def host(self, host: str, app: ASGIApp, name: str | None = None) -> None:
self.router.host(host, app=app, name=name) # pragma: no cover

def add_middleware(self, middleware_class: type, **options: typing.Any) -> None:
def add_middleware(
self,
middleware_class: typing.Callable[Concatenate[ASGIApp, P], typing.Any],
Kludex marked this conversation as resolved.
Show resolved Hide resolved
*args: P.args,
pawelrubin marked this conversation as resolved.
Show resolved Hide resolved
**options: P.kwargs,
) -> None:
if self.middleware_stack is not None: # pragma: no cover
raise RuntimeError("Cannot add middleware after an application has started")
self.user_middleware.insert(0, Middleware(middleware_class, **options))
self.user_middleware.insert(0, Middleware(middleware_class, *args, **options))

def add_exception_handler(
self,
Expand Down
22 changes: 19 additions & 3 deletions starlette/middleware/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,28 @@
import typing
import sys
from typing import Any, Callable, Iterator

from starlette.types import ASGIApp

if sys.version_info >= (3, 10): # pragma: no cover
from typing import Concatenate, ParamSpec
else: # pragma: no cover
from typing_extensions import Concatenate, ParamSpec


P = ParamSpec("P")


class Middleware:
def __init__(self, cls: type, **options: typing.Any) -> None:
def __init__(
self,
cls: Callable[Concatenate[ASGIApp, P], Any],
Kludex marked this conversation as resolved.
Show resolved Hide resolved
*args: P.args,
pawelrubin marked this conversation as resolved.
Show resolved Hide resolved
**options: P.kwargs,
) -> None:
self.cls = cls
self.options = options

def __iter__(self) -> typing.Iterator[typing.Any]:
def __iter__(self) -> Iterator[Any]:
as_tuple = (self.cls, self.options)
return iter(as_tuple)

Expand Down
4 changes: 3 additions & 1 deletion tests/middleware/test_middleware.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from starlette.middleware import Middleware
from starlette.types import ASGIApp


class CustomMiddleware:
pass
def __init__(self, app: ASGIApp) -> None:
self.app = app # pragma: no cover


def test_middleware_repr():
Expand Down
4 changes: 2 additions & 2 deletions tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from starlette.endpoints import HTTPEndpoint
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.requests import Request
from starlette.requests import HTTPConnection
from starlette.responses import JSONResponse
from starlette.routing import Route, WebSocketRoute
from starlette.websockets import WebSocketDisconnect
Expand Down Expand Up @@ -327,7 +327,7 @@ def test_authentication_redirect(test_client_factory):
assert response.json() == {"authenticated": True, "user": "tomchristie"}


def on_auth_error(request: Request, exc: Exception):
def on_auth_error(request: HTTPConnection, exc: AuthenticationError):
return JSONResponse({"error": str(exc)}, status_code=401)


Expand Down