Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
MSC2918: do not invalidate refresh token immediately & fix tests
Browse files Browse the repository at this point in the history
This checks for child token usage to validate the refresh token validity.
This means that a token can be refreshed multiple times until one of the child tokens gets used.

Signed-off-by: Quentin Gliech <quentingliech@gmail.com>
  • Loading branch information
sandhose committed Apr 9, 2021
1 parent 324d7bf commit 450a962
Show file tree
Hide file tree
Showing 10 changed files with 74 additions and 36 deletions.
4 changes: 4 additions & 0 deletions docs/sample_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1249,6 +1249,10 @@ account_validity:
#
#session_lifetime: 24h

# MSC2918
# TODO: docs
access_token_lifetime: 5m

# The user must provide all of the below types of 3PID when registering.
#
#registrations_require_3pid:
Expand Down
8 changes: 8 additions & 0 deletions synapse/config/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,10 @@ def read_config(self, config, **kwargs):
session_lifetime = self.parse_duration(session_lifetime)
self.session_lifetime = session_lifetime

access_token_lifetime = config.get("access_token_lifetime", "5m")
access_token_lifetime = self.parse_duration(access_token_lifetime)
self.access_token_lifetime = access_token_lifetime # type: int

# The success template used during fallback auth.
self.fallback_success_template = self.read_template("auth_success.html")

Expand Down Expand Up @@ -281,6 +285,10 @@ def generate_config_section(self, generate_secrets=False, **kwargs):
# By default, this is infinite.
#
#session_lifetime: 24h
# MSC2918
# TODO: docs
access_token_lifetime: 5m
# The user must provide all of the below types of 3PID when registering.
#
Expand Down
4 changes: 2 additions & 2 deletions synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,7 +781,7 @@ async def refresh_token(
if existing_token is None:
raise SynapseError(400, "refresh token does not exist")

if not existing_token.valid:
if existing_token.has_next_access_token_been_used or existing_token.has_next_refresh_token_been_refreshed:
raise SynapseError(400, "refresh token isn't valid anymore")

(
Expand All @@ -796,7 +796,7 @@ async def refresh_token(
valid_until_ms=valid_until_ms,
refresh_token_id=new_refresh_token_id,
)
await self.store.invalidate_refresh_token(existing_token.token_id)
await self.store.replace_refresh_token(existing_token.token_id, new_refresh_token_id)
return access_token, new_refresh_token

async def get_refresh_token_for_user_id(
Expand Down
3 changes: 2 additions & 1 deletion synapse/handlers/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def __init__(self, hs: "HomeServer"):
self.pusher_pool = hs.get_pusherpool()

self.session_lifetime = hs.config.session_lifetime
self.access_token_lifetime = hs.config.access_token_lifetime

async def check_username(
self,
Expand Down Expand Up @@ -757,7 +758,7 @@ class and RegisterDeviceReplicationServlet.
user_id,
device_id=registered_device_id,
)
valid_until_ms = self.clock.time_msec() + 60 * 1000
valid_until_ms = self.clock.time_msec() + self.access_token_lifetime

access_token = await self._auth_handler.get_access_token_for_user_id(
user_id,
Expand Down
26 changes: 16 additions & 10 deletions synapse/rest/client/v2_alpha/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,19 +724,25 @@ async def _do_guest_registration(self, params, address=None):
# we have nowhere to store it.
device_id = synapse.api.auth.GUEST_DEVICE_ID
initial_display_name = params.get("initial_device_display_name")
device_id, access_token = await self.registration_handler.register_device(
device_id, access_token, valid_until_ms, refresh_token = await self.registration_handler.register_device(
user_id, device_id, initial_display_name, is_guest=True
)

return (
200,
{
"user_id": user_id,
"device_id": device_id,
"access_token": access_token,
"home_server": self.hs.hostname,
},
)
result = {
"user_id": user_id,
"device_id": device_id,
"access_token": access_token,
"home_server": self.hs.hostname,
}

if valid_until_ms is not None:
expires_in_ms = valid_until_ms - self.clock.time_msec()
result["expires_in"] = int(expires_in_ms / 1000)

if refresh_token is not None:
result["refresh_token"] = refresh_token

return 200, result


def _calculate_registration_flows(
Expand Down
58 changes: 38 additions & 20 deletions synapse/storage/databases/main/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ class RefreshTokenLookupResult:
user_id = attr.ib(type=str)
device_id = attr.ib(type=str)
token_id = attr.ib(type=int)
valid = attr.ib(type=bool)
next_token_id = attr.ib(type=int)
has_next_refresh_token_been_refreshed = attr.ib(type=bool)
has_next_access_token_been_used = attr.ib(type=bool)


class RegistrationWorkerStore(CacheInvalidationWorkerStore):
Expand Down Expand Up @@ -1050,30 +1052,46 @@ async def update_access_token_last_validated(self, token_id: int) -> None:
desc="update_access_token_last_validated",
)

async def lookup_refresh_token(
self, token: str
) -> Optional[RefreshTokenLookupResult]:
d = await self.db_pool.simple_select_one(
"refresh_tokens",
{"token": token},
["id", "user_id", "device_id", "valid"],
allow_none=True,
desc="lookup_refresh_token",
)
async def lookup_refresh_token(self, token: str) -> Optional[RefreshTokenLookupResult]:
"""Lookup a refresh token with hints about its validity.
"""

if d is not None:
d["token_id"] = d["id"]
del d["id"]
return RefreshTokenLookupResult(**d)
def _lookup_refresh_token_txn(txn) -> Optional[RefreshTokenLookupResult]:
txn.execute("""
SELECT
rt.id token_id,
rt.user_id,
rt.device_id,
rt.next_token_id,
(nrt.next_token_id IS NOT NULL) has_next_refresh_token_been_refreshed,
(at.last_validated IS NOT NULL) has_next_access_token_been_used
FROM refresh_tokens rt
LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id
LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id
WHERE rt.token = ?
""", (token,))
row = txn.fetchone()

if row is None:
return None

return None
return RefreshTokenLookupResult(
token_id=row[0],
user_id=row[1],
device_id=row[2],
next_token_id=row[3],
has_next_refresh_token_been_refreshed=row[4],
has_next_access_token_been_used=row[5],
)

return await self.db_pool.runInteraction("lookup_refresh_token", _lookup_refresh_token_txn)

async def invalidate_refresh_token(self, token_id: int) -> None:
async def replace_refresh_token(self, token_id: int, next_token_id: int) -> None:
await self.db_pool.simple_update_one(
"refresh_tokens",
{"id": token_id},
{"valid": False},
desc="invalidate_refresh_token",
{"next_token_id": next_token_id},
desc="replace_refresh_token",
)


Expand Down Expand Up @@ -1325,7 +1343,7 @@ async def add_refresh_token_to_user(
"user_id": user_id,
"device_id": device_id,
"token": token,
"valid": True,
"next_token_id": None,
},
desc="add_access_token_to_user",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ CREATE TABLE refresh_tokens (
user_id TEXT NOT NULL,
device_id TEXT,
token TEXT NOT NULL,
valid BOOLEAN NOT NULL,
replaced_by BIGINT REFERENCES refresh_tokens (id) ON DELETE CASCADE,
UNIQUE(token)
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ CREATE TABLE refresh_tokens (
user_id TEXT NOT NULL,
device_id TEXT,
token TEXT NOT NULL,
valid BOOLEAN NOT NULL,
replaced_by BIGINT REFERENCES refresh_tokens (id) ON DELETE CASCADE,
UNIQUE(token)
);

Expand Down
1 change: 1 addition & 0 deletions tests/api/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ def test_cannot_use_regular_token_as_guest(self):
device_id="DEVICE",
valid_until_ms=None,
puppets_user_id=None,
refresh_token_id=None,
)

async def get_user(tok):
Expand Down
2 changes: 1 addition & 1 deletion tests/handlers/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def test_dehydrate_and_rehydrate_device(self):
self.assertEqual(device_data, {"device_data": {"foo": "bar"}})

# Create a new login for the user and dehydrated the device
device_id, access_token = self.get_success(
device_id, access_token, _expiration_time, _refresh_token = self.get_success(
self.registration.register_device(
user_id=user_id,
device_id=None,
Expand Down

0 comments on commit 450a962

Please sign in to comment.