diff --git a/auth_lib/aiomethods.py b/auth_lib/aiomethods.py index 26994b5..25951e0 100644 --- a/auth_lib/aiomethods.py +++ b/auth_lib/aiomethods.py @@ -3,21 +3,25 @@ import aiohttp -from .exceptions import AuthFailed, SessionExpired +from .exceptions import AuthFailed, IncorrectData, NotFound, SessionExpired # See docs on https://api.test.profcomff.com/?urls.primaryName=auth class AsyncAuthLib: - url: str + auth_url: str + userdata_url: str - def __init__(self, url: str): - self.url = url + def __init__(self, *, auth_url: str | None = None, userdata_url: str | None = None): + self.auth_url = auth_url + self.userdata_url = userdata_url async def email_login(self, email: str, password: str) -> dict[str, Any]: json = {"email": email, "password": password} async with aiohttp.ClientSession() as session: - response = await session.post(url=f"{self.url}/email/login", json=json) + response = await session.post( + url=urljoin(self.auth_url, "email/login"), json=json + ) match response.status: case 200: return await response.json() @@ -28,7 +32,7 @@ async def check_token(self, token: str) -> dict[str, Any] | None: headers = {"Authorization": token} async with aiohttp.request( "GET", - urljoin(self.url, "me"), + urljoin(self.auth_url, "me"), headers={"Authorization": token}, params={ "info": [ @@ -45,7 +49,9 @@ async def check_token(self, token: str) -> dict[str, Any] | None: async def logout(self, token: str) -> bool: headers = {"Authorization": token} async with aiohttp.ClientSession() as session: - response = await session.post(url=f"{self.url}/logout", headers=headers) + response = await session.post( + url=urljoin(self.auth_url, "logout"), headers=headers + ) match response.status: case 200: @@ -54,3 +60,13 @@ async def logout(self, token: str) -> bool: raise AuthFailed(response=await response.json()) case 403: raise SessionExpired(response=await response.json()) + + async def get_user_data(self, token: str, user_id: int) -> dict[str | Any] | None: + headers = {"Authorization": token} + async with aiohttp.ClientSession() as session: + response = await session.get( + url=urljoin(self.userdata_url, f"user/{user_id}"), headers=headers + ) + if response.ok: + return await response.json() + return None diff --git a/auth_lib/fastapi.py b/auth_lib/fastapi.py index d61cfad..7555840 100644 --- a/auth_lib/fastapi.py +++ b/auth_lib/fastapi.py @@ -14,13 +14,15 @@ class UnionAuthSettings(BaseSettings): AUTH_URL: str = "https://api.test.profcomff.com/auth/" + USERDATA_URL: str = "https://api.test.profcomff.com/userdata/" AUTH_AUTO_ERROR: bool = True AUTH_ALLOW_NONE: bool = False + ENABLE_USERDATA: bool = False model_config = ConfigDict(case_sensitive=True, env_file=".env", extra="ignore") class UnionAuth(SecurityBase): - model = APIKey.construct(in_=APIKeyIn.header, name="Authorization") + model = APIKey.model_construct(in_=APIKeyIn.header, name="Authorization") scheme_name = "token" settings = UnionAuthSettings() @@ -29,23 +31,38 @@ def __init__( scopes: list[str] = [], auto_error: bool | None = None, allow_none: bool | None = None, + enable_userdata: bool | None = None, auth_url=None, # Для обратной совместимости + userdata_url=None, ) -> None: if auth_url is not None: warn( "auth_url in args deprecated, use AUTH_URL env instead", DeprecationWarning, ) + if userdata_url is not None: + warn( + "userdata_url in args deprecated, use USERDATA_URL env instead", + DeprecationWarning, + ) super().__init__() self.auth_url = auth_url or self.settings.AUTH_URL if not self.auth_url.endswith("/"): self.auth_url = self.auth_url + "/" + self.userdata_url = userdata_url or self.settings.USERDATA_URL + if not self.userdata_url.endswith("/"): + self.userdata_url = self.userdata_url + "/" self.auto_error = ( auto_error if auto_error is not None else self.settings.AUTH_AUTO_ERROR ) self.allow_none = ( allow_none if allow_none is not None else self.settings.AUTH_ALLOW_NONE ) + self.enable_userdata = ( + enable_userdata + if enable_userdata is not None + else self.settings.ENABLE_USERDATA + ) self.scopes = scopes def _except(self): @@ -61,20 +78,38 @@ async def _get_session(self, token: str | None) -> dict[str, Any] | None: return None if not token: return self._except() - return await AsyncAuthLib(url=self.auth_url).check_token(token) + return await AsyncAuthLib(auth_url=self.auth_url).check_token(token) + + async def _get_userdata( + self, token: str | None, user_id: int + ) -> dict[str, Any] | None: + if not token and self.allow_none: + return None + if not token: + return self._except() + if self.enable_userdata: + return await AsyncAuthLib(userdata_url=self.userdata_url).get_user_data( + token, user_id + ) + return None async def __call__( self, request: Request, ) -> dict[str, Any] | None: token = request.headers.get("Authorization") - user_session = await self._get_session(token) - if user_session is None: + result = await self._get_session(token) + if result is None: return self._except() + if self.enable_userdata: + user_data_info = await self._get_userdata(token, result["id"]) + result["userdata"] = [] + if user_data_info is not None: + result["userdata"] = user_data_info["items"] session_scopes = set( - [scope["name"].lower() for scope in user_session["session_scopes"]] + [scope["name"].lower() for scope in result["session_scopes"]] ) required_scopes = set([scope.lower() for scope in self.scopes]) if required_scopes - session_scopes: self._except() - return user_session + return result diff --git a/auth_lib/methods.py b/auth_lib/methods.py index f9b9c02..48cbe67 100644 --- a/auth_lib/methods.py +++ b/auth_lib/methods.py @@ -3,20 +3,22 @@ import requests -from .exceptions import AuthFailed, SessionExpired +from .exceptions import AuthFailed, IncorrectData, NotFound, SessionExpired # See docs on https://api.test.profcomff.com/?urls.primaryName=auth class AuthLib: - url: str + auth_url: str + userdata_url: str - def __init__(self, url: str): - self.url = url + def __init__(self, *, auth_url: str | None = None, userdata_url: str | None = None): + self.auth_url = auth_url + self.userdata_url = userdata_url def email_login(self, email: str, password: str) -> dict[str, Any]: json = {"email": email, "password": password} - response = requests.post(url=f"{self.url}/email/login", json=json) + response = requests.post(url=urljoin(self.auth_url, "email/login"), json=json) match response.status_code: case 200: return response.json() @@ -26,7 +28,7 @@ def email_login(self, email: str, password: str) -> dict[str, Any]: def check_token(self, token: str) -> dict[str, Any] | None: headers = {"Authorization": token} response = requests.get( - url=urljoin(self.url, "me"), + url=urljoin(self.auth_url, "me"), headers=headers, params={ "info": [ @@ -41,8 +43,7 @@ def check_token(self, token: str) -> dict[str, Any] | None: def logout(self, token: str) -> bool: headers = {"Authorization": token} - response = requests.post(url=f"{self.url}/logout", headers=headers) - + response = requests.post(url=urljoin(self.auth_url, "logout"), headers=headers) match response.status_code: case 200: return True @@ -50,3 +51,12 @@ def logout(self, token: str) -> bool: raise AuthFailed(response=response.json()["body"]) case 403: raise SessionExpired(response=response.json()["body"]) + + def get_user_data(self, token: str, user_id: int) -> dict[str | Any] | None: + headers = {"Authorization": token} + response = requests.get( + url=urljoin(self.userdata_url, f"user/{user_id}"), headers=headers + ) + if response.ok: + return response.json() + return None diff --git a/setup.py b/setup.py index 043430f..f959531 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ from setuptools import find_packages, setup -with open("README.md", "r") as readme_file: +with open("README.md", "r", encoding="utf-8") as readme_file: readme = readme_file.read() setup(