diff --git a/changelog.d/9436.bugfix b/changelog.d/9436.bugfix new file mode 100644 index 000000000000..a530516eed4e --- /dev/null +++ b/changelog.d/9436.bugfix @@ -0,0 +1 @@ +Fix a bug in single sign-on which could cause a "No session cookie found" error. diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py index c658862fe65f..142b007d010e 100644 --- a/synapse/http/__init__.py +++ b/synapse/http/__init__.py @@ -14,8 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import re +from typing import Union -from twisted.internet import task +from twisted.internet import address, task from twisted.web.client import FileBodyProducer from twisted.web.iweb import IRequest @@ -53,6 +54,40 @@ def stopProducing(self): pass +def get_request_uri(request: IRequest) -> bytes: + """Return the full URI that was requested by the client""" + return b"%s://%s%s" % ( + b"https" if request.isSecure() else b"http", + _get_requested_host(request), + # despite its name, "request.uri" is only the path and query-string. + request.uri, + ) + + +def _get_requested_host(request: IRequest) -> bytes: + hostname = request.getHeader(b"host") + if hostname: + return hostname + + # no Host header, use the address/port that the request arrived on + host = request.getHost() # type: Union[address.IPv4Address, address.IPv6Address] + + hostname = host.host.encode("ascii") + + if request.isSecure() and host.port == 443: + # default port for https + return hostname + + if not request.isSecure() and host.port == 80: + # default port for http + return hostname + + return b"%s:%i" % ( + hostname, + host.port, + ) + + def get_request_user_agent(request: IRequest, default: str = "") -> str: """Return the last User-Agent header, or the given default.""" # There could be raw utf-8 bytes in the User-Agent header. diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 6e2fbedd99bf..925edfc40239 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -20,6 +20,7 @@ from synapse.api.ratelimiting import Ratelimiter from synapse.appservice import ApplicationService from synapse.handlers.sso import SsoIdentityProvider +from synapse.http import get_request_uri from synapse.http.server import HttpServer, finish_request from synapse.http.servlet import ( RestServlet, @@ -354,6 +355,7 @@ def __init__(self, hs: "HomeServer"): hs.get_oidc_handler() self._sso_handler = hs.get_sso_handler() self._msc2858_enabled = hs.config.experimental.msc2858_enabled + self._public_baseurl = hs.config.public_baseurl def register(self, http_server: HttpServer) -> None: super().register(http_server) @@ -373,6 +375,32 @@ def register(self, http_server: HttpServer) -> None: async def on_GET( self, request: SynapseRequest, idp_id: Optional[str] = None ) -> None: + if not self._public_baseurl: + raise SynapseError(400, "SSO requires a valid public_baseurl") + + # if this isn't the expected hostname, redirect to the right one, so that we + # get our cookies back. + requested_uri = get_request_uri(request) + baseurl_bytes = self._public_baseurl.encode("utf-8") + if not requested_uri.startswith(baseurl_bytes): + # swap out the incorrect base URL for the right one. + # + # The idea here is to redirect from + # https://foo.bar/whatever/_matrix/... + # to + # https://public.baseurl/_matrix/... + # + i = requested_uri.index(b"/_matrix") + new_uri = baseurl_bytes[:-1] + requested_uri[i:] + logger.info( + "Requested URI %s is not canonical: redirecting to %s", + requested_uri.decode("utf-8", errors="replace"), + new_uri.decode("utf-8", errors="replace"), + ) + request.redirect(new_uri) + finish_request(request) + return + client_redirect_url = parse_string( request, "redirectUrl", required=True, encoding=None ) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index fb29eaed6f08..744d8d0941bd 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -15,7 +15,7 @@ import time import urllib.parse -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union from urllib.parse import urlencode from mock import Mock @@ -47,8 +47,14 @@ HAS_JWT = False -# public_base_url used in some tests -BASE_URL = "https://synapse/" +# synapse server name: used to populate public_baseurl in some tests +SYNAPSE_SERVER_PUBLIC_HOSTNAME = "synapse" + +# public_baseurl for some tests. It uses an http:// scheme because +# FakeChannel.isSecure() returns False, so synapse will see the requested uri as +# http://..., so using http in the public_baseurl stops Synapse trying to redirect to +# https://.... +BASE_URL = "http://%s/" % (SYNAPSE_SERVER_PUBLIC_HOSTNAME,) # CAS server used in some tests CAS_SERVER = "https://fake.test" @@ -480,11 +486,7 @@ def test_get_msc2858_login_flows(self): def test_multi_sso_redirect(self): """/login/sso/redirect should redirect to an identity picker""" # first hit the redirect url, which should redirect to our idp picker - channel = self.make_request( - "GET", - "/_matrix/client/r0/login/sso/redirect?redirectUrl=" - + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), - ) + channel = self._make_sso_redirect_request(False, None) self.assertEqual(channel.code, 302, channel.result) uri = channel.headers.getRawHeaders("Location")[0] @@ -628,34 +630,21 @@ def test_multi_sso_redirect_to_unknown(self): def test_client_idp_redirect_msc2858_disabled(self): """If the client tries to pick an IdP but MSC2858 is disabled, return a 400""" - channel = self.make_request( - "GET", - "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl=" - + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), - ) + channel = self._make_sso_redirect_request(True, "oidc") self.assertEqual(channel.code, 400, channel.result) self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") @override_config({"experimental_features": {"msc2858_enabled": True}}) def test_client_idp_redirect_to_unknown(self): """If the client tries to pick an unknown IdP, return a 404""" - channel = self.make_request( - "GET", - "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/xxx?redirectUrl=" - + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), - ) + channel = self._make_sso_redirect_request(True, "xxx") self.assertEqual(channel.code, 404, channel.result) self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") @override_config({"experimental_features": {"msc2858_enabled": True}}) def test_client_idp_redirect_to_oidc(self): """If the client pick a known IdP, redirect to it""" - channel = self.make_request( - "GET", - "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl=" - + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), - ) - + channel = self._make_sso_redirect_request(True, "oidc") self.assertEqual(channel.code, 302, channel.result) oidc_uri = channel.headers.getRawHeaders("Location")[0] oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) @@ -663,6 +652,30 @@ def test_client_idp_redirect_to_oidc(self): # it should redirect us to the auth page of the OIDC server self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + def _make_sso_redirect_request( + self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None + ): + """Send a request to /_matrix/client/r0/login/sso/redirect + + ... or the unstable equivalent + + ... possibly specifying an IDP provider + """ + endpoint = ( + "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect" + if unstable_endpoint + else "/_matrix/client/r0/login/sso/redirect" + ) + if idp_prov is not None: + endpoint += "/" + idp_prov + endpoint += "?redirectUrl=" + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + + return self.make_request( + "GET", + endpoint, + custom_headers=[("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME)], + ) + @staticmethod def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: prefix = key + " = " diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 8231a423f336..946740aa5d51 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -542,13 +542,30 @@ def initiate_sso_login( if client_redirect_url: params["redirectUrl"] = client_redirect_url - # hit the redirect url (which will issue a cookie and state) + # hit the redirect url (which should redirect back to the redirect url. This + # is the easiest way of figuring out what the Host header ought to be set to + # to keep Synapse happy. channel = make_request( self.hs.get_reactor(), self.site, "GET", "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params), ) + assert channel.code == 302 + + # hit the redirect url again with the right Host header, which should now issue + # a cookie and redirect to the SSO provider. + location = channel.headers.getRawHeaders("Location")[0] + parts = urllib.parse.urlsplit(location) + channel = make_request( + self.hs.get_reactor(), + self.site, + "GET", + urllib.parse.urlunsplit(("", "") + parts[2:]), + custom_headers=[ + ("Host", parts[1]), + ], + ) assert channel.code == 302 channel.extract_cookies(cookies) diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index c26ad824f7a4..9734a2159a1a 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -161,7 +161,11 @@ class UIAuthTests(unittest.HomeserverTestCase): def default_config(self): config = super().default_config() - config["public_baseurl"] = "https://synapse.test" + + # public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns + # False, so synapse will see the requested uri as http://..., so using http in + # the public_baseurl stops Synapse trying to redirect to https. + config["public_baseurl"] = "http://synapse.test" if HAS_OIDC: # we enable OIDC as a way of testing SSO flows diff --git a/tests/server.py b/tests/server.py index d4ece5c448ac..939a0008ca2e 100644 --- a/tests/server.py +++ b/tests/server.py @@ -124,7 +124,11 @@ def getPeer(self): return address.IPv4Address("TCP", self._ip, 3423) def getHost(self): - return None + # this is called by Request.__init__ to configure Request.host. + return address.IPv4Address("TCP", "127.0.0.1", 8888) + + def isSecure(self): + return False @property def transport(self):