From 43fc39ad7dc4cbab55e4a51fcd919b322557c759 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 20 Jul 2023 10:42:46 +0200 Subject: [PATCH] Try to make Request generic over State --- starlette/requests.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/starlette/requests.py b/starlette/requests.py index fff451e23..1b15b1f03 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -3,6 +3,7 @@ from http import cookies as http_cookies import anyio +import typing_extensions from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State @@ -60,15 +61,27 @@ class ClientDisconnect(Exception): pass -class HTTPConnection(typing.Mapping[str, typing.Any]): +_StateType = typing_extensions.TypeVar("_StateType", default=State) + + +class HTTPConnection(typing.Mapping[str, typing.Any], typing.Generic[_StateType]): """ A base class for incoming HTTP connections, that is used to provide any functionality that is common to both `Request` and `WebSocket`. """ - def __init__(self, scope: Scope, receive: typing.Optional[Receive] = None) -> None: + def __init__( + self, + scope: Scope, + receive: typing.Optional[Receive] = None, + *, + state_factory: typing.Callable[ + [typing.Dict[str, typing.Any]], _StateType + ] = State, + ) -> None: assert scope["type"] in ("http", "websocket") self.scope = scope + self._state_factory = state_factory def __getitem__(self, key: str) -> typing.Any: return self.scope[key] @@ -164,13 +177,13 @@ def user(self) -> typing.Any: return self.scope["user"] @property - def state(self) -> State: + def state(self) -> _StateType: if not hasattr(self, "_state"): # Ensure 'state' has an empty dict if it's not already populated. self.scope.setdefault("state", {}) # Create a state instance with a reference to the dict in which it should # store info - self._state = State(self.scope["state"]) + self._state = self._state_factory(self.scope["state"]) return self._state def url_for(self, name: str, /, **path_params: typing.Any) -> URL: