Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Fix destination_is errors seen in sentry. #13041

Merged
merged 8 commits into from
Jun 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions changelog.d/13041.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fix a bug introduced in Synapse 1.58 where profile requests for a malformed user ID would ccause an internal error. Synapse now returns 400 Bad Request in this situation.

7 changes: 5 additions & 2 deletions synapse/http/matrixfederationclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,8 +731,11 @@ def build_auth_headers(
Returns:
A list of headers to be added as "Authorization:" headers
"""
if destination is None and destination_is is None:
raise ValueError("destination and destination_is cannot both be None!")
if not destination and not destination_is:
raise ValueError(
"At least one of the arguments destination and destination_is "
"must be a nonempty bytestring."
)

request: JsonDict = {
"method": method.decode("ascii"),
Expand Down
20 changes: 16 additions & 4 deletions synapse/rest/client/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

""" This module contains REST servlets to do with profile: /profile/<paths> """

from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple

from synapse.api.errors import Codes, SynapseError
Expand Down Expand Up @@ -45,8 +45,12 @@ async def on_GET(
requester = await self.auth.get_user_by_req(request)
requester_user = requester.user

user = UserID.from_string(user_id)
if not UserID.is_valid(user_id):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
)

user = UserID.from_string(user_id)
await self.profile_handler.check_profile_query_allowed(user, requester_user)

displayname = await self.profile_handler.get_displayname(user)
Expand Down Expand Up @@ -98,8 +102,12 @@ async def on_GET(
requester = await self.auth.get_user_by_req(request)
requester_user = requester.user

user = UserID.from_string(user_id)
if not UserID.is_valid(user_id):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
)

user = UserID.from_string(user_id)
await self.profile_handler.check_profile_query_allowed(user, requester_user)

avatar_url = await self.profile_handler.get_avatar_url(user)
Expand Down Expand Up @@ -150,8 +158,12 @@ async def on_GET(
requester = await self.auth.get_user_by_req(request)
requester_user = requester.user

user = UserID.from_string(user_id)
if not UserID.is_valid(user_id):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
)

user = UserID.from_string(user_id)
await self.profile_handler.check_profile_query_allowed(user, requester_user)

displayname = await self.profile_handler.get_displayname(user)
Expand Down
3 changes: 2 additions & 1 deletion synapse/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,6 @@ def from_string(cls: Type[DS], s: str) -> DS:
)

domain = parts[1]

# This code will need changing if we want to support multiple domain
# names on one HS
return cls(localpart=parts[0], domain=domain)
Expand All @@ -279,6 +278,8 @@ def to_string(self) -> str:
@classmethod
def is_valid(cls: Type[DS], s: str) -> bool:
"""Parses the input string and attempts to ensure it is valid."""
# TODO: this does not reject an empty localpart or an overly-long string.
# See https://spec.matrix.org/v1.2/appendices/#identifier-grammar
try:
obj = cls.from_string(s)
# Apply additional validation to the domain. This is only done
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -617,3 +617,17 @@ def test_too_big(self):
self.assertIsInstance(f.value, RequestSendFailed)

self.assertTrue(transport.disconnecting)

def test_build_auth_headers_rejects_falsey_destinations(self) -> None:
with self.assertRaises(ValueError):
self.cl.build_auth_headers(None, b"GET", b"https://example.com")
with self.assertRaises(ValueError):
self.cl.build_auth_headers(b"", b"GET", b"https://example.com")
with self.assertRaises(ValueError):
self.cl.build_auth_headers(
None, b"GET", b"https://example.com", destination_is=b""
)
with self.assertRaises(ValueError):
self.cl.build_auth_headers(
b"", b"GET", b"https://example.com", destination_is=b""
)
8 changes: 8 additions & 0 deletions tests/rest/client/test_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.

"""Tests REST events for /profile paths."""
import urllib.parse
from http import HTTPStatus
from typing import Any, Dict, Optional

from twisted.test.proto_helpers import MemoryReactor
Expand Down Expand Up @@ -49,6 +51,12 @@ def test_get_displayname(self) -> None:
res = self._get_displayname()
self.assertEqual(res, "owner")

def test_get_displayname_rejects_bad_username(self) -> None:
channel = self.make_request(
"GET", f"/profile/{urllib.parse.quote('@alice:')}/displayname"
)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)

def test_set_displayname(self) -> None:
channel = self.make_request(
"PUT",
Expand Down
13 changes: 12 additions & 1 deletion tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,21 @@ def test_parse(self):
self.assertEqual("test", user.domain)
self.assertEqual(True, self.hs.is_mine(user))

def test_pase_empty(self):
def test_parse_rejects_empty_id(self):
with self.assertRaises(SynapseError):
UserID.from_string("")

def test_parse_rejects_missing_sigil(self):
with self.assertRaises(SynapseError):
UserID.from_string("alice:example.com")

def test_parse_rejects_missing_separator(self):
with self.assertRaises(SynapseError):
UserID.from_string("@alice.example.com")

def test_validation_rejects_missing_domain(self):
self.assertFalse(UserID.is_valid("@alice:"))

def test_build(self):
user = UserID("5678efgh", "my.domain")

Expand Down