Skip to content

Commit

Permalink
Fix a race when registering via email 3pid
Browse files Browse the repository at this point in the history
  • Loading branch information
MatMaul committed Jan 19, 2024
1 parent 2927008 commit c2cfa8b
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 2 deletions.
1 change: 1 addition & 0 deletions changelog.d/16827.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a race when registering via email 3pid where 2 different user ids would be created.
22 changes: 21 additions & 1 deletion synapse/rest/client/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@

logger = logging.getLogger(__name__)

USER_REGISTRATION_LOCK_NAME = "user_registration"


class EmailRegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/register/email/requestToken$")
Expand Down Expand Up @@ -415,6 +417,7 @@ def __init__(self, hs: "HomeServer"):
self.macaroon_gen = hs.get_macaroon_generator()
self.ratelimiter = hs.get_registration_ratelimiter()
self.password_policy_handler = hs.get_password_policy_handler()
self._worker_lock_handler = hs.get_worker_locks_handler()
self.clock = hs.get_clock()
self.password_auth_provider = hs.get_password_auth_provider()
self._registration_enabled = self.hs.config.registration.enable_registration
Expand Down Expand Up @@ -506,6 +509,23 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"An access token should not be provided on requests to /register (except if type is m.login.application_service)",
)

# Take a global lock when doing user registration to avoid races,
# for example when doing 3pid email binding.
async with self._worker_lock_handler.acquire_lock(
USER_REGISTRATION_LOCK_NAME, USER_REGISTRATION_LOCK_NAME
):
return await self._do_user_register(
desired_username, client_addr, body, should_issue_refresh_token, request
)

async def _do_user_register(
self,
desired_username: Optional[str],
address: str,
body: JsonDict,
should_issue_refresh_token: bool,
request: SynapseRequest,
) -> Tuple[int, JsonDict]:
# == Normal User Registration == (everyone else)
if not self._registration_enabled:
raise SynapseError(403, "Registration has been disabled", Codes.FORBIDDEN)
Expand Down Expand Up @@ -700,7 +720,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
guest_access_token=guest_access_token,
threepid=threepid,
default_display_name=display_name,
address=client_addr,
address=address,
user_agent_ips=entries,
)
# Necessary due to auth checks prior to the threepid being
Expand Down
92 changes: 91 additions & 1 deletion tests/rest/client/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
#
import datetime
import os
from typing import Any, Dict, List, Tuple
import re
from typing import Any, Dict, List, Optional, Tuple

import pkg_resources

Expand All @@ -40,6 +41,7 @@
from synapse.util import Clock

from tests import unittest
from tests.server import ThreadedMemoryReactorClock
from tests.unittest import override_config


Expand Down Expand Up @@ -1246,3 +1248,91 @@ def test_GET_ratelimiting(self) -> None:
f"{self.url}?token={token}",
)
self.assertEqual(channel.code, 200, msg=channel.result)


class EmailRegisterRestServletTestCase(unittest.HomeserverTestCase):
servlets = [register.register_servlets]

def make_homeserver(
self, reactor: ThreadedMemoryReactorClock, clock: Clock
) -> HomeServer:
hs = super().make_homeserver(reactor, clock)

async def send_email(
email_address: str,
subject: str,
app_name: str,
html: str,
text: str,
additional_headers: Optional[Dict[str, str]] = None,
) -> None:
self.email_attempts.append(text)

self.email_attempts: List[str] = []
hs.get_send_email_handler().send_email = send_email # type: ignore[method-assign]
return hs

@unittest.override_config(
{
"public_baseurl": "https://test_server",
"registrations_require_3pid": ["email"],
"disable_msisdn_registration": True,
"email": {
"smtp_host": "mail_server",
"smtp_port": 2525,
"notif_from": "sender@host",
},
}
)
def test_email_3pid_registration_race(self) -> None:
channel = self.make_request("POST", b"register", {"password": "password"})
session = channel.json_body["session"]

# request a token to be sent by email for validation
channel = self.make_request(
"POST",
b"register/email/requestToken",
{
"client_secret": "client_secret",
"email": "email@email",
"send_attempt": 1,
},
)
sid = channel.json_body["sid"]

email_text = self.email_attempts[0]
match = re.search("https://test_server(.*)", email_text)
assert match is not None
validation_url = match.group(1)

# "Click" the link in the email to validate the adress
self.make_request("GET", validation_url.encode("utf-8"))

# launch 2 simultaneous register request, only one account
# should be created after that.
register_content = {
"auth": {
"session": session,
"threepid_creds": {
"client_secret": "client_secret",
"sid": sid,
},
"type": "m.login.email.identity",
},
"password": "password",
}
register1_channel = self.make_request(
"POST", b"register", register_content, await_result=False
)
register2_channel = self.make_request(
"POST", b"register", register_content, await_result=False
)
while (
not register1_channel.is_finished() or not register2_channel.is_finished()
):
self.pump()

self.assertEqual(
register1_channel.json_body["user_id"],
register2_channel.json_body["user_id"],
)

0 comments on commit c2cfa8b

Please sign in to comment.