Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Save the OIDC session ID (sid) with the device on login (#11482)
Browse files Browse the repository at this point in the history
As a step towards allowing back-channel logout for OIDC.
  • Loading branch information
sandhose authored Dec 6, 2021
1 parent 8b4b153 commit a15a893
Show file tree
Hide file tree
Showing 15 changed files with 370 additions and 65 deletions.
1 change: 1 addition & 0 deletions changelog.d/11482.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Save the OpenID Connect session ID on login.
34 changes: 31 additions & 3 deletions synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import bcrypt
import pymacaroons
import unpaddedbase64
from pymacaroons.exceptions import MacaroonVerificationFailedException

from twisted.web.server import Request

Expand Down Expand Up @@ -182,8 +183,11 @@ class LoginTokenAttributes:

user_id = attr.ib(type=str)

# the SSO Identity Provider that the user authenticated with, to get this token
auth_provider_id = attr.ib(type=str)
"""The SSO Identity Provider that the user authenticated with, to get this token."""

auth_provider_session_id = attr.ib(type=Optional[str])
"""The session ID advertised by the SSO Identity Provider."""


class AuthHandler:
Expand Down Expand Up @@ -1650,6 +1654,7 @@ async def complete_sso_login(
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
new_user: bool = False,
auth_provider_session_id: Optional[str] = None,
) -> None:
"""Having figured out a mxid for this user, complete the HTTP request
Expand All @@ -1665,6 +1670,7 @@ async def complete_sso_login(
during successful login. Must be JSON serializable.
new_user: True if we should use wording appropriate to a user who has just
registered.
auth_provider_session_id: The session ID from the SSO IdP received during login.
"""
# If the account has been deactivated, do not proceed with the login
# flow.
Expand All @@ -1685,6 +1691,7 @@ async def complete_sso_login(
extra_attributes,
new_user=new_user,
user_profile_data=profile,
auth_provider_session_id=auth_provider_session_id,
)

def _complete_sso_login(
Expand All @@ -1696,6 +1703,7 @@ def _complete_sso_login(
extra_attributes: Optional[JsonDict] = None,
new_user: bool = False,
user_profile_data: Optional[ProfileInfo] = None,
auth_provider_session_id: Optional[str] = None,
) -> None:
"""
The synchronous portion of complete_sso_login.
Expand All @@ -1717,7 +1725,9 @@ def _complete_sso_login(

# Create a login token
login_token = self.macaroon_gen.generate_short_term_login_token(
registered_user_id, auth_provider_id=auth_provider_id
registered_user_id,
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
)

# Append the login token to the original redirect URL (i.e. with its query
Expand Down Expand Up @@ -1822,6 +1832,7 @@ def generate_short_term_login_token(
self,
user_id: str,
auth_provider_id: str,
auth_provider_session_id: Optional[str] = None,
duration_in_ms: int = (2 * 60 * 1000),
) -> str:
macaroon = self._generate_base_macaroon(user_id)
Expand All @@ -1830,6 +1841,10 @@ def generate_short_term_login_token(
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
if auth_provider_session_id is not None:
macaroon.add_first_party_caveat(
"auth_provider_session_id = %s" % (auth_provider_session_id,)
)
return macaroon.serialize()

def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
Expand All @@ -1851,15 +1866,28 @@ def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
user_id = get_value_from_macaroon(macaroon, "user_id")
auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")

auth_provider_session_id: Optional[str] = None
try:
auth_provider_session_id = get_value_from_macaroon(
macaroon, "auth_provider_session_id"
)
except MacaroonVerificationFailedException:
pass

v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")
v.satisfy_exact("type = login")
v.satisfy_general(lambda c: c.startswith("user_id = "))
v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = "))
satisfy_expiry(v, self.hs.get_clock().time_msec)
v.verify(macaroon, self.hs.config.key.macaroon_secret_key)

return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id)
return LoginTokenAttributes(
user_id=user_id,
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
)

def generate_delete_pusher_token(self, user_id: str) -> str:
macaroon = self._generate_base_macaroon(user_id)
Expand Down
8 changes: 8 additions & 0 deletions synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,8 @@ async def check_device_registered(
user_id: str,
device_id: Optional[str],
initial_device_display_name: Optional[str] = None,
auth_provider_id: Optional[str] = None,
auth_provider_session_id: Optional[str] = None,
) -> str:
"""
If the given device has not been registered, register it with the
Expand All @@ -312,6 +314,8 @@ async def check_device_registered(
user_id: @user:id
device_id: device id supplied by client
initial_device_display_name: device display name from client
auth_provider_id: The SSO IdP the user used, if any.
auth_provider_session_id: The session ID (sid) got from the SSO IdP.
Returns:
device id (generated if none was supplied)
"""
Expand All @@ -323,6 +327,8 @@ async def check_device_registered(
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
)
if new_device:
await self.notify_device_update(user_id, [device_id])
Expand All @@ -337,6 +343,8 @@ async def check_device_registered(
user_id=user_id,
device_id=new_device_id,
initial_device_display_name=initial_device_display_name,
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
)
if new_device:
await self.notify_device_update(user_id, [new_device_id])
Expand Down
58 changes: 35 additions & 23 deletions synapse/handlers/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from authlib.jose import JsonWebToken, jwt
from authlib.oauth2.auth import ClientAuth
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
from authlib.oidc.core import CodeIDToken, UserInfo
from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
from jinja2 import Environment, Template
from pymacaroons.exceptions import (
Expand Down Expand Up @@ -117,7 +117,8 @@ async def load_metadata(self) -> None:
for idp_id, p in self._providers.items():
try:
await p.load_metadata()
await p.load_jwks()
if not p._uses_userinfo:
await p.load_jwks()
except Exception as e:
raise Exception(
"Error while initialising OIDC provider %r" % (idp_id,)
Expand Down Expand Up @@ -498,10 +499,6 @@ async def load_jwks(self, force: bool = False) -> JWKS:
return await self._jwks.get()

async def _load_jwks(self) -> JWKS:
if self._uses_userinfo:
# We're not using jwt signing, return an empty jwk set
return {"keys": []}

metadata = await self.load_metadata()

# Load the JWKS using the `jwks_uri` metadata.
Expand Down Expand Up @@ -663,7 +660,7 @@ async def _fetch_userinfo(self, token: Token) -> UserInfo:

return UserInfo(resp)

async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken:
"""Return an instance of UserInfo from token's ``id_token``.
Args:
Expand All @@ -673,7 +670,7 @@ async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
request. This value should match the one inside the token.
Returns:
An object representing the user.
The decoded claims in the ID token.
"""
metadata = await self.load_metadata()
claims_params = {
Expand All @@ -684,9 +681,6 @@ async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
# If we got an `access_token`, there should be an `at_hash` claim
# in the `id_token` that we can check against.
claims_params["access_token"] = token["access_token"]
claims_cls = CodeIDToken
else:
claims_cls = ImplicitIDToken

alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
jwt = JsonWebToken(alg_values)
Expand All @@ -703,7 +697,7 @@ async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
claims = jwt.decode(
id_token,
key=jwk_set,
claims_cls=claims_cls,
claims_cls=CodeIDToken,
claims_options=claim_options,
claims_params=claims_params,
)
Expand All @@ -713,15 +707,16 @@ async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
claims = jwt.decode(
id_token,
key=jwk_set,
claims_cls=claims_cls,
claims_cls=CodeIDToken,
claims_options=claim_options,
claims_params=claims_params,
)

logger.debug("Decoded id_token JWT %r; validating", claims)

claims.validate(leeway=120) # allows 2 min of clock skew
return UserInfo(claims)

return claims

async def handle_redirect_request(
self,
Expand Down Expand Up @@ -837,22 +832,37 @@ async def handle_oidc_callback(

logger.debug("Successfully obtained OAuth2 token data: %r", token)

# Now that we have a token, get the userinfo, either by decoding the
# `id_token` or by fetching the `userinfo_endpoint`.
# If there is an id_token, it should be validated, regardless of the
# userinfo endpoint is used or not.
if token.get("id_token") is not None:
try:
id_token = await self._parse_id_token(token, nonce=session_data.nonce)
sid = id_token.get("sid")
except Exception as e:
logger.exception("Invalid id_token")
self._sso_handler.render_error(request, "invalid_token", str(e))
return
else:
id_token = None
sid = None

# Now that we have a token, get the userinfo either from the `id_token`
# claims or by fetching the `userinfo_endpoint`.
if self._uses_userinfo:
try:
userinfo = await self._fetch_userinfo(token)
except Exception as e:
logger.exception("Could not fetch userinfo")
self._sso_handler.render_error(request, "fetch_error", str(e))
return
elif id_token is not None:
userinfo = UserInfo(id_token)
else:
try:
userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
except Exception as e:
logger.exception("Invalid id_token")
self._sso_handler.render_error(request, "invalid_token", str(e))
return
logger.error("Missing id_token in token response")
self._sso_handler.render_error(
request, "invalid_token", "Missing id_token in token response"
)
return

# first check if we're doing a UIA
if session_data.ui_auth_session_id:
Expand Down Expand Up @@ -884,7 +894,7 @@ async def handle_oidc_callback(
# Call the mapper to register/login the user
try:
await self._complete_oidc_login(
userinfo, token, request, session_data.client_redirect_url
userinfo, token, request, session_data.client_redirect_url, sid
)
except MappingException as e:
logger.exception("Could not map user")
Expand All @@ -896,6 +906,7 @@ async def _complete_oidc_login(
token: Token,
request: SynapseRequest,
client_redirect_url: str,
sid: Optional[str],
) -> None:
"""Given a UserInfo response, complete the login flow
Expand Down Expand Up @@ -1008,6 +1019,7 @@ async def grandfather_existing_users() -> Optional[str]:
oidc_response_to_user_attributes,
grandfather_existing_users,
extra_attributes,
auth_provider_session_id=sid,
)

def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
Expand Down
15 changes: 12 additions & 3 deletions synapse/handlers/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,7 @@ async def register_device(
is_appservice_ghost: bool = False,
auth_provider_id: Optional[str] = None,
should_issue_refresh_token: bool = False,
auth_provider_session_id: Optional[str] = None,
) -> Tuple[str, str, Optional[int], Optional[str]]:
"""Register a device for a user and generate an access token.
Expand All @@ -756,9 +757,9 @@ async def register_device(
device_id: The device ID to check, or None to generate a new one.
initial_display_name: An optional display name for the device.
is_guest: Whether this is a guest account
auth_provider_id: The SSO IdP the user used, if any (just used for the
prometheus metrics).
auth_provider_id: The SSO IdP the user used, if any.
should_issue_refresh_token: Whether it should also issue a refresh token
auth_provider_session_id: The session ID received during login from the SSO IdP.
Returns:
Tuple of device ID, access token, access token expiration time and refresh token
"""
Expand All @@ -769,6 +770,8 @@ async def register_device(
is_guest=is_guest,
is_appservice_ghost=is_appservice_ghost,
should_issue_refresh_token=should_issue_refresh_token,
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
)

login_counter.labels(
Expand All @@ -791,6 +794,8 @@ async def register_device_inner(
is_guest: bool = False,
is_appservice_ghost: bool = False,
should_issue_refresh_token: bool = False,
auth_provider_id: Optional[str] = None,
auth_provider_session_id: Optional[str] = None,
) -> LoginDict:
"""Helper for register_device
Expand Down Expand Up @@ -822,7 +827,11 @@ class and RegisterDeviceReplicationServlet.
refresh_token_id = None

registered_device_id = await self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
user_id,
device_id,
initial_display_name,
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
)
if is_guest:
assert access_token_expiry is None
Expand Down
4 changes: 4 additions & 0 deletions synapse/handlers/sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@ async def complete_sso_login_request(
sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
grandfather_existing_users: Callable[[], Awaitable[Optional[str]]],
extra_login_attributes: Optional[JsonDict] = None,
auth_provider_session_id: Optional[str] = None,
) -> None:
"""
Given an SSO ID, retrieve the user ID for it and possibly register the user.
Expand Down Expand Up @@ -415,6 +416,8 @@ async def complete_sso_login_request(
extra_login_attributes: An optional dictionary of extra
attributes to be provided to the client in the login response.
auth_provider_session_id: An optional session ID from the IdP.
Raises:
MappingException if there was a problem mapping the response to a user.
RedirectException: if the mapping provider needs to redirect the user
Expand Down Expand Up @@ -490,6 +493,7 @@ async def complete_sso_login_request(
client_redirect_url,
extra_login_attributes,
new_user=new_user,
auth_provider_session_id=auth_provider_session_id,
)

async def _call_attribute_mapper(
Expand Down
2 changes: 2 additions & 0 deletions synapse/module_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,7 @@ def generate_short_term_login_token(
user_id: str,
duration_in_ms: int = (2 * 60 * 1000),
auth_provider_id: str = "",
auth_provider_session_id: Optional[str] = None,
) -> str:
"""Generate a login token suitable for m.login.token authentication
Expand All @@ -643,6 +644,7 @@ def generate_short_term_login_token(
return self._hs.get_macaroon_generator().generate_short_term_login_token(
user_id,
auth_provider_id,
auth_provider_session_id,
duration_in_ms,
)

Expand Down
Loading

0 comments on commit a15a893

Please sign in to comment.