Skip to content

Commit

Permalink
Use Protocol instead of Callable.
Browse files Browse the repository at this point in the history
  • Loading branch information
Paweł Rubin authored and pawelrubin committed Dec 20, 2023
1 parent 0d36db7 commit 2d7eb8c
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 21 deletions.
6 changes: 3 additions & 3 deletions starlette/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import typing
import warnings

from typing_extensions import Concatenate, ParamSpec
from typing_extensions import ParamSpec

from starlette.datastructures import State, URLPath
from starlette.middleware import Middleware
from starlette.middleware import Middleware, _MiddlewareClass
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.middleware.exceptions import ExceptionMiddleware
Expand Down Expand Up @@ -129,7 +129,7 @@ def host(self, host: str, app: ASGIApp, name: str | None = None) -> None:

def add_middleware(
self,
middleware_class: typing.Callable[Concatenate[ASGIApp, P], typing.Any],
middleware_class: typing.Type[_MiddlewareClass[P]],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
Expand Down
16 changes: 12 additions & 4 deletions starlette/middleware/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
from typing import Any, Callable, Iterator
from typing import Any, Iterator, Protocol, Type

from typing_extensions import Concatenate, ParamSpec
from typing_extensions import ParamSpec

from starlette.types import ASGIApp
from starlette.types import ASGIApp, Receive, Scope, Send

P = ParamSpec("P")


class _MiddlewareClass(Protocol[P]):
def __init__(self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs) -> None:
... # pragma: no cover

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
... # pragma: no cover


class Middleware:
def __init__(
self,
cls: Callable[Concatenate[ASGIApp, P], Any],
cls: Type[_MiddlewareClass[P]],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
Expand Down
6 changes: 3 additions & 3 deletions tests/middleware/test_base.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import contextvars
from contextlib import AsyncExitStack
from typing import AsyncGenerator, Awaitable, Callable, List, Union
from typing import Any, AsyncGenerator, Awaitable, Callable, List, Type, Union

import anyio
import pytest

from starlette.applications import Starlette
from starlette.background import BackgroundTask
from starlette.middleware import Middleware
from starlette.middleware import Middleware, _MiddlewareClass
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response, StreamingResponse
Expand Down Expand Up @@ -196,7 +196,7 @@ async def dispatch(self, request, call_next):
),
],
)
def test_contextvars(test_client_factory, middleware_cls: type):
def test_contextvars(test_client_factory, middleware_cls: Type[_MiddlewareClass[Any]]):
# this has to be an async endpoint because Starlette calls run_in_threadpool
# on sync endpoints which has it's own set of peculiarities w.r.t propagating
# contextvars (it propagates them forwards but not backwards)
Expand Down
13 changes: 8 additions & 5 deletions tests/middleware/test_middleware.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
from starlette.middleware import Middleware
from starlette.types import ASGIApp
from starlette.types import ASGIApp, Receive, Scope, Send


class CustomMiddleware:
def __init__(self, app: ASGIApp, foo: str, *, bar: int) -> None: # pragma: no cover
class CustomMiddleware: # pragma: no cover
def __init__(self, app: ASGIApp, foo: str, *, bar: int) -> None:
self.app = app
self.foo = foo
self.bar = bar

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)

def test_middleware_repr():

def test_middleware_repr() -> None:
middleware = Middleware(CustomMiddleware, "foo", bar=123)
assert repr(middleware) == "Middleware(CustomMiddleware, 'foo', bar=123)"


def test_middleware_iter():
def test_middleware_iter() -> None:
cls, args, kwargs = Middleware(CustomMiddleware, "foo", bar=123)
assert (cls, args, kwargs) == (CustomMiddleware, ("foo",), {"bar": 123})
12 changes: 6 additions & 6 deletions tests/test_applications.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Callable
from typing import AsyncIterator, Callable

import anyio
import httpx
Expand All @@ -15,7 +15,7 @@
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
from starlette.staticfiles import StaticFiles
from starlette.types import ASGIApp
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.websockets import WebSocket


Expand Down Expand Up @@ -499,8 +499,8 @@ class NoOpMiddleware:
def __init__(self, app: ASGIApp):
self.app = app

async def __call__(self, *args: Any):
await self.app(*args)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)

class SimpleInitializableMiddleware:
counter = 0
Expand All @@ -509,8 +509,8 @@ def __init__(self, app: ASGIApp):
self.app = app
SimpleInitializableMiddleware.counter += 1

async def __call__(self, *args: Any):
await self.app(*args)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)

def get_app() -> ASGIApp:
app = Starlette()
Expand Down

0 comments on commit 2d7eb8c

Please sign in to comment.