Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Improve register performance #8009

Merged
merged 7 commits into from
Aug 6, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion synapse/api/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,14 +238,16 @@ class InteractiveAuthIncompleteError(Exception):
(This indicates we should return a 401 with 'result' as the body)

Attributes:
session_id: The ID of the ongoing interactive auth session.
result: the server response to the request, which should be
passed back to the client
"""

def __init__(self, result: "JsonDict"):
def __init__(self, session_id: str, result: "JsonDict"):
super(InteractiveAuthIncompleteError, self).__init__(
"Interactive auth not yet complete"
)
self.session_id = session_id
self.result = result


Expand Down
4 changes: 2 additions & 2 deletions synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ async def check_auth(

if not authdict:
raise InteractiveAuthIncompleteError(
self._auth_dict_for_flows(flows, session.session_id)
session.session_id, self._auth_dict_for_flows(flows, session.session_id)
)

# check auth type currently being presented
Expand Down Expand Up @@ -410,7 +410,7 @@ async def check_auth(
ret = self._auth_dict_for_flows(flows, session.session_id)
ret["completed"] = list(creds)
ret.update(errordict)
raise InteractiveAuthIncompleteError(ret)
raise InteractiveAuthIncompleteError(session.session_id, ret)

async def add_oob_auth(
self, stagetype: str, authdict: Dict[str, Any], clientip: str
Expand Down
108 changes: 69 additions & 39 deletions synapse/rest/client/v2_alpha/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from synapse.api.constants import LoginType
from synapse.api.errors import (
Codes,
InteractiveAuthIncompleteError,
SynapseError,
ThreepidValidationError,
UnrecognizedRequestError,
Expand Down Expand Up @@ -387,6 +388,7 @@ def __init__(self, hs):
self.ratelimiter = hs.get_registration_ratelimiter()
self.password_policy_handler = hs.get_password_policy_handler()
self.clock = hs.get_clock()
self._registration_enabled = self.hs.config.enable_registration

self._registration_flows = _calculate_registration_flows(
hs.config, self.auth_handler
Expand All @@ -412,20 +414,8 @@ async def on_POST(self, request):
"Do not understand membership kind: %s" % (kind.decode("utf8"),)
)

# we do basic sanity checks here because the auth layer will store these
# in sessions. Pull out the username/password provided to us.
if "password" in body:
password = body.pop("password")
if not isinstance(password, str) or len(password) > 512:
raise SynapseError(400, "Invalid password")
self.password_policy_handler.validate_password(password)

# If the password is valid, hash it and store it back on the body.
# This ensures that only the hashed password is handled everywhere.
if "password_hash" in body:
raise SynapseError(400, "Unexpected property: password_hash")
body["password_hash"] = await self.auth_handler.hash(password)

# Pull out the provided username and do basic sanity checks early since
# the auth layer will store these in sessions.
desired_username = None
if "username" in body:
if not isinstance(body["username"], str) or len(body["username"]) > 512:
Expand Down Expand Up @@ -459,22 +449,35 @@ async def on_POST(self, request):
)
return 200, result # we throw for non 200 responses

# for regular registration, downcase the provided username before
# attempting to register it. This should mean
# that people who try to register with upper-case in their usernames
# don't get a nasty surprise. (Note that we treat username
# case-insenstively in login, so they are free to carry on imagining
# that their username is CrAzYh4cKeR if that keeps them happy)
if desired_username is not None:
desired_username = desired_username.lower()

# == Normal User Registration == (everyone else)
if not self.hs.config.enable_registration:
if not self._registration_enabled:
raise SynapseError(403, "Registration has been disabled")

# For regular registration, convert the provided username to lowercase
# before attempting to register it. This should mean that people who try
# to register with upper-case in their usernames don't get a nasty surprise.
#
# Note that we treat usernames case-insensitively in login, so they are
# free to carry on imagining that their username is CrAzYh4cKeR if that
# keeps them happy.
if desired_username is not None:
desired_username = desired_username.lower()

# Check if this account is upgrading from a guest account.
guest_access_token = body.get("guest_access_token", None)

if "initial_device_display_name" in body and "password_hash" not in body:
# Pull out the provided password and do basic sanity checks early.
#
# Note that we remove the password from the body since the auth layer
# will store the body in the session and we don't want a plaintext
# password store there.
password = body.pop("password", None)
if password is not None:
if not isinstance(password, str) or len(password) > 512:
raise SynapseError(400, "Invalid password")
self.password_policy_handler.validate_password(password)

if "initial_device_display_name" in body and password is None:
# ignore 'initial_device_display_name' if sent without
# a password to work around a client bug where it sent
# the 'initial_device_display_name' param alone, wiping out
Expand All @@ -484,6 +487,7 @@ async def on_POST(self, request):

session_id = self.auth_handler.get_session_id(body)
registered_user_id = None
password_hash = None
if session_id:
# if we get a registered user id out of here, it means we previously
# registered a user for this session, so we could just return the
Expand All @@ -492,28 +496,53 @@ async def on_POST(self, request):
registered_user_id = await self.auth_handler.get_session_data(
session_id, "registered_user_id", None
)
# If a password hash was previously stored we will not attempt to
# re-hash and store it.
#
# Note that if the password changes throughout the authentication
# flow this might break, but the data is meant to be consistent
# throughout the flow.
password_hash = await self.auth_handler.get_session_data(
session_id, "password_hash", None
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I struggled for a while to understand what was going on here, and tbh the comment didn't do much to help.

I think this comment - especially the bit about what happens if the password changes mid-flow - should move down to inside the except InteractiveAuthIncompleteError block at line 533ish, and we can just say "extract the previously-hashed password from the session" here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I struggled with where to put this, moving it further up sounds like it would be an improvement.

Is there a particular piece of knowledge that should be documented in comments?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I struggled with where to put this, moving it further up sounds like it would be an improvement.

I'm mostly suggesting the code gets left where it is, and the comment moves down.

Is there a particular piece of knowledge that should be documented in comments?

I don't think so, tbh. I think this code speaks for itself, but I only realised that after I stopped trying to understand the comment, which doesn't really describe what is happening here :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I re-arranged and re-worded the comments in e05118b, I'm hoping that helps?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yup, lgtm.


# Ensure that the username is valid.
if desired_username is not None:
await self.registration_handler.check_username(
desired_username,
guest_access_token=guest_access_token,
assigned_user_id=registered_user_id,
)

auth_result, params, session_id = await self.auth_handler.check_auth(
self._registration_flows,
request,
body,
self.hs.get_ip_from_request(request),
"register a new account",
)
# Check if the user-interactive authentication flows are complete, if
# not this will raise a user-interactive auth error.
try:
auth_result, params, session_id = await self.auth_handler.check_auth(
self._registration_flows,
request,
body,
self.hs.get_ip_from_request(request),
"register a new account",
)
except InteractiveAuthIncompleteError as e:
# The user needs to provide more steps to complete auth, but they're
# not required to provide the password again.
#
# If a password hash was not provided with a previous request and a
# password is available now, hash the provided password and store it
# for later.
if not password_hash and password:
password_hash = await self.auth_handler.hash(password)
await self.auth_handler.set_session_data(
e.session_id, "password_hash", password_hash
)
raise

# Check that we're not trying to register a denied 3pid.
#
# the user-facing checks will probably already have happened in
# /register/email/requestToken when we requested a 3pid, but that's not
# guaranteed.

if auth_result:
for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
if login_type in auth_result:
Expand All @@ -535,12 +564,13 @@ async def on_POST(self, request):
# don't re-register the threepids
registered = False
else:
# NB: This may be from the auth handler and NOT from the POST
assert_params_in_dict(params, ["password_hash"])
# If we have a password in this request, prefer it. Otherwise, there
# might be a password hash from an earlier request.
if password:
password_hash = await self.auth_handler.hash(password)

desired_username = params.get("username", None)
guest_access_token = params.get("guest_access_token", None)
new_password_hash = params.get("password_hash", None)

if desired_username is not None:
desired_username = desired_username.lower()
Expand Down Expand Up @@ -582,7 +612,7 @@ async def on_POST(self, request):

registered_user_id = await self.registration_handler.register_user(
localpart=desired_username,
password_hash=new_password_hash,
password_hash=password_hash,
guest_access_token=guest_access_token,
threepid=threepid,
address=client_addr,
Expand All @@ -595,8 +625,8 @@ async def on_POST(self, request):
):
await self.store.upsert_monthly_active_user(registered_user_id)

# remember that we've now registered that user account, and with
# what user ID (since the user may not have specified)
# Remember that the user account has been registered (and the user
# ID it was registered with, since it might not have been specified).
await self.auth_handler.set_session_data(
session_id, "registered_user_id", registered_user_id
)
Expand Down
2 changes: 1 addition & 1 deletion tests/rest/client/v2_alpha/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ def test_POST_user_valid(self):
self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertDictContainsSubset(det_data, channel.json_body)

@override_config({"enable_registration": False})
def test_POST_disabled_registration(self):
self.hs.config.enable_registration = False
request_data = json.dumps({"username": "kermit", "password": "monkey"})
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)

Expand Down