diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 7eb07e9a5523..55aa73944b24 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -285,7 +285,7 @@ def generate_config_section(self, generate_secrets=False, **kwargs): # By default, this is infinite. # #session_lifetime: 24h - + # MSC2918 # TODO: docs access_token_lifetime: 5m diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 0575fbc779a0..3a6251d0543e 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -781,7 +781,10 @@ async def refresh_token( if existing_token is None: raise SynapseError(400, "refresh token does not exist") - if existing_token.has_next_access_token_been_used or existing_token.has_next_refresh_token_been_refreshed: + if ( + existing_token.has_next_access_token_been_used + or existing_token.has_next_refresh_token_been_refreshed + ): raise SynapseError(400, "refresh token isn't valid anymore") ( @@ -796,7 +799,9 @@ async def refresh_token( valid_until_ms=valid_until_ms, refresh_token_id=new_refresh_token_id, ) - await self.store.replace_refresh_token(existing_token.token_id, new_refresh_token_id) + await self.store.replace_refresh_token( + existing_token.token_id, new_refresh_token_id + ) return access_token, new_refresh_token async def get_refresh_token_for_user_id( diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 94a035b5d736..a2d7612b29cf 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -691,6 +691,7 @@ async def register_device( is_guest: Whether this is a guest account auth_provider_id: The SSO IdP the user used, if any (just used for the prometheus metrics). + should_issue_refresh_token: Whether it should also issue a refresh token Returns: Tuple of device ID, access token, access token expiration time and refresh token """ diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index ab8a91db6514..e76010ce27a5 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -15,7 +15,7 @@ import logging import re -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, cast +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional from typing_extensions import TypedDict @@ -158,11 +158,12 @@ def on_GET(self, request: SynapseRequest): async def on_POST(self, request: SynapseRequest): login_submission = parse_json_object_from_request(request) - refresh_token_param = cast( - List[bytes], - request.args.get(bytes(LoginRestServlet.REFRESH_TOKEN_PARAM, "utf-8"), []), + param = bytes(LoginRestServlet.REFRESH_TOKEN_PARAM, "utf-8") + should_issue_refresh_token = ( + request.args is not None + and param in request.args + and request.args[param][0] == b"true" ) - should_issue_refresh_token = any((i == b"true" for i in refresh_token_param)) try: if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE: diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 96830162c4ed..00a4e85dfadd 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -724,7 +724,12 @@ async def _do_guest_registration(self, params, address=None): # we have nowhere to store it. device_id = synapse.api.auth.GUEST_DEVICE_ID initial_display_name = params.get("initial_device_display_name") - device_id, access_token, valid_until_ms, refresh_token = await self.registration_handler.register_device( + ( + device_id, + access_token, + valid_until_ms, + refresh_token, + ) = await self.registration_handler.register_device( user_id, device_id, initial_display_name, is_guest=True ) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index ec9a1d84ab5d..1f53fb940fb2 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1052,12 +1052,14 @@ async def update_access_token_last_validated(self, token_id: int) -> None: desc="update_access_token_last_validated", ) - async def lookup_refresh_token(self, token: str) -> Optional[RefreshTokenLookupResult]: - """Lookup a refresh token with hints about its validity. - """ + async def lookup_refresh_token( + self, token: str + ) -> Optional[RefreshTokenLookupResult]: + """Lookup a refresh token with hints about its validity.""" def _lookup_refresh_token_txn(txn) -> Optional[RefreshTokenLookupResult]: - txn.execute(""" + txn.execute( + """ SELECT rt.id token_id, rt.user_id, @@ -1069,7 +1071,9 @@ def _lookup_refresh_token_txn(txn) -> Optional[RefreshTokenLookupResult]: LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id WHERE rt.token = ? - """, (token,)) + """, + (token,), + ) row = txn.fetchone() if row is None: @@ -1084,7 +1088,9 @@ def _lookup_refresh_token_txn(txn) -> Optional[RefreshTokenLookupResult]: has_next_access_token_been_used=row[5], ) - return await self.db_pool.runInteraction("lookup_refresh_token", _lookup_refresh_token_txn) + return await self.db_pool.runInteraction( + "lookup_refresh_token", _lookup_refresh_token_txn + ) async def replace_refresh_token(self, token_id: int, next_token_id: int) -> None: await self.db_pool.simple_update_one(