Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Make the http server handle coroutine-making REST servlets #5475

Merged
merged 11 commits into from
Jun 29, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/5475.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Synapse can now handle RestServlets that return coroutines.
77 changes: 41 additions & 36 deletions synapse/http/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@

import cgi
import collections
import http.client
import logging

from six import PY3
from six.moves import http_client, urllib
import types
import urllib
from io import BytesIO

from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json

Expand All @@ -41,11 +42,6 @@
from synapse.util.caches import intern_dict
from synapse.util.logcontext import preserve_fn

if PY3:
from io import BytesIO
else:
from cStringIO import StringIO as BytesIO

logger = logging.getLogger(__name__)

HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
Expand Down Expand Up @@ -75,10 +71,9 @@ def wrap_json_request_handler(h):
deferred fails with any other type of error we send a 500 reponse.
"""

@defer.inlineCallbacks
def wrapped_request_handler(self, request):
async def wrapped_request_handler(self, request):
try:
yield h(self, request)
await h(self, request)
except SynapseError as e:
code = e.code
logger.info("%s SynapseError: %s - %s", request, code, e.msg)
Expand Down Expand Up @@ -142,10 +137,12 @@ def wrap_html_request_handler(h):
where "request" must be a SynapseRequest.
"""

def wrapped_request_handler(self, request):
d = defer.maybeDeferred(h, self, request)
d.addErrback(_return_html_error, request)
return d
async def wrapped_request_handler(self, request):
try:
return await h(self, request)
except Exception:
f = failure.Failure()
return _return_html_error(f, request)

return wrap_async_request_handler(wrapped_request_handler)

Expand All @@ -171,7 +168,7 @@ def _return_html_error(f, request):
exc_info=(f.type, f.value, f.getTracebackObject()),
)
else:
code = http_client.INTERNAL_SERVER_ERROR
code = http.client.INTERNAL_SERVER_ERROR
msg = "Internal server error"

logger.error(
Expand Down Expand Up @@ -201,10 +198,9 @@ def wrap_async_request_handler(h):
logged until the deferred completes.
"""

@defer.inlineCallbacks
def wrapped_async_request_handler(self, request):
async def wrapped_async_request_handler(self, request):
with request.processing():
yield h(self, request)
await h(self, request)

# we need to preserve_fn here, because the synchronous render method won't yield for
# us (obviously)
Expand Down Expand Up @@ -270,12 +266,11 @@ def register_paths(self, method, path_patterns, callback):
def render(self, request):
""" This gets called by twisted every time someone sends us a request.
"""
self._async_render(request)
defer.ensureDeferred(self._async_render(request))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Presumably we need to wrap this in a deferred to make it run on the reactor?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, the coroutine itself does not do anything until it's either awaited or wrapped in a Deferred.

return NOT_DONE_YET

@wrap_json_request_handler
@defer.inlineCallbacks
def _async_render(self, request):
async def _async_render(self, request):
""" This gets called from render() every time someone sends us a request.
This checks if anyone has registered a callback for that method and
path.
Expand All @@ -292,26 +287,19 @@ def _async_render(self, request):
# Now trigger the callback. If it returns a response, we send it
# here. If it throws an exception, that is handled by the wrapper
# installed by @request_handler.

def _unquote(s):
if PY3:
# On Python 3, unquote is unicode -> unicode
return urllib.parse.unquote(s)
else:
# On Python 2, unquote is bytes -> bytes We need to encode the
# URL again (as it was decoded by _get_handler_for request), as
# ASCII because it's a URL, and then decode it to get the UTF-8
# characters that were quoted.
return urllib.parse.unquote(s.encode("ascii")).decode("utf8")

kwargs = intern_dict(
{
name: _unquote(value) if value else value
name: urllib.parse.unquote(value) if value else value
for name, value in group_dict.items()
}
)

callback_return = yield callback(request, **kwargs)
callback_return = callback(request, **kwargs)

# Is it synchronous? We'll allow this for now.
if isinstance(callback_return, (defer.Deferred, types.CoroutineType)):
callback_return = await callback_return

if callback_return is not None:
code, response = callback_return
self._send_response(request, code, response)
Expand Down Expand Up @@ -360,6 +348,23 @@ def _send_response(
)


class DirectServeResource(resource.Resource):
def render(self, request):
"""
Render the request, using an asynchronous render handler if it exists.
"""
render_callback_name = "_async_render_" + request.method.decode("ascii")

if hasattr(self, render_callback_name):
# Call the handler
callback = getattr(self, render_callback_name)
defer.ensureDeferred(callback(request))

return NOT_DONE_YET
else:
super().render(request)


def _options_handler(request):
"""Request handler for OPTIONS requests

Expand Down
35 changes: 12 additions & 23 deletions synapse/rest/consent/consent_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@
import jinja2
from jinja2 import TemplateNotFound

from twisted.internet import defer
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET

from synapse.api.errors import NotFoundError, StoreError, SynapseError
from synapse.config import ConfigError
from synapse.http.server import finish_request, wrap_html_request_handler
from synapse.http.server import (
DirectServeResource,
finish_request,
wrap_html_request_handler,
)
from synapse.http.servlet import parse_string
from synapse.types import UserID

Expand All @@ -47,7 +47,7 @@ def compare_digest(a, b):
return a == b


class ConsentResource(Resource):
class ConsentResource(DirectServeResource):
"""A twisted Resource to display a privacy policy and gather consent to it

When accessed via GET, returns the privacy policy via a template.
Expand Down Expand Up @@ -87,7 +87,7 @@ def __init__(self, hs):
Args:
hs (synapse.server.HomeServer): homeserver
"""
Resource.__init__(self)
super().__init__()

self.hs = hs
self.store = hs.get_datastore()
Expand Down Expand Up @@ -118,18 +118,12 @@ def __init__(self, hs):

self._hmac_secret = hs.config.form_secret.encode("utf-8")

def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET

@wrap_html_request_handler
@defer.inlineCallbacks
def _async_render_GET(self, request):
async def _async_render_GET(self, request):
"""
Args:
request (twisted.web.http.Request):
"""

version = parse_string(request, "v", default=self._default_consent_version)
username = parse_string(request, "u", required=False, default="")
userhmac = None
Expand All @@ -145,7 +139,7 @@ def _async_render_GET(self, request):
else:
qualified_user_id = UserID(username, self.hs.hostname).to_string()

u = yield self.store.get_user_by_id(qualified_user_id)
u = await self.store.get_user_by_id(qualified_user_id)
if u is None:
raise NotFoundError("Unknown user")

Expand All @@ -165,13 +159,8 @@ def _async_render_GET(self, request):
except TemplateNotFound:
raise NotFoundError("Unknown policy version")

def render_POST(self, request):
self._async_render_POST(request)
return NOT_DONE_YET

@wrap_html_request_handler
@defer.inlineCallbacks
def _async_render_POST(self, request):
async def _async_render_POST(self, request):
"""
Args:
request (twisted.web.http.Request):
Expand All @@ -188,12 +177,12 @@ def _async_render_POST(self, request):
qualified_user_id = UserID(username, self.hs.hostname).to_string()

try:
yield self.store.user_set_consent_version(qualified_user_id, version)
await self.store.user_set_consent_version(qualified_user_id, version)
except StoreError as e:
if e.code != 404:
raise
raise NotFoundError("Unknown user")
yield self.registration_handler.post_consent_actions(qualified_user_id)
await self.registration_handler.post_consent_actions(qualified_user_id)

try:
self._render_template(request, "success.html")
Expand Down
28 changes: 10 additions & 18 deletions synapse/rest/key/v2/remote_key_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,20 @@
from io import BytesIO

from twisted.internet import defer
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET

from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import respond_with_json_bytes, wrap_json_request_handler
from synapse.http.server import (
DirectServeResource,
respond_with_json_bytes,
wrap_json_request_handler,
)
from synapse.http.servlet import parse_integer, parse_json_object_from_request

logger = logging.getLogger(__name__)


class RemoteKey(Resource):
class RemoteKey(DirectServeResource):
"""HTTP resource for retreiving the TLS certificate and NACL signature
verification keys for a collection of servers. Checks that the reported
X.509 TLS certificate matches the one used in the HTTPS connection. Checks
Expand Down Expand Up @@ -94,13 +96,8 @@ def __init__(self, hs):
self.clock = hs.get_clock()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist

def render_GET(self, request):
self.async_render_GET(request)
return NOT_DONE_YET

@wrap_json_request_handler
@defer.inlineCallbacks
def async_render_GET(self, request):
async def _async_render_GET(self, request):
if len(request.postpath) == 1:
server, = request.postpath
query = {server.decode("ascii"): {}}
Expand All @@ -114,20 +111,15 @@ def async_render_GET(self, request):
else:
raise SynapseError(404, "Not found %r" % request.postpath, Codes.NOT_FOUND)

yield self.query_keys(request, query, query_remote_on_cache_miss=True)

def render_POST(self, request):
self.async_render_POST(request)
return NOT_DONE_YET
await self.query_keys(request, query, query_remote_on_cache_miss=True)

@wrap_json_request_handler
@defer.inlineCallbacks
def async_render_POST(self, request):
async def _async_render_POST(self, request):
content = parse_json_object_from_request(request)

query = content["server_keys"]

yield self.query_keys(request, query, query_remote_on_cache_miss=True)
await self.query_keys(request, query, query_remote_on_cache_miss=True)

@defer.inlineCallbacks
def query_keys(self, request, query, query_remote_on_cache_miss=False):
Expand Down
21 changes: 9 additions & 12 deletions synapse/rest/media/v1/config_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,31 +14,28 @@
# limitations under the License.
#

from twisted.internet import defer
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET

from synapse.http.server import respond_with_json, wrap_json_request_handler
from synapse.http.server import (
DirectServeResource,
respond_with_json,
wrap_json_request_handler,
)


class MediaConfigResource(Resource):
class MediaConfigResource(DirectServeResource):
isLeaf = True

def __init__(self, hs):
Resource.__init__(self)
super().__init__()
config = hs.get_config()
self.clock = hs.get_clock()
self.auth = hs.get_auth()
self.limits_dict = {"m.upload.size": config.max_upload_size}

def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET

@wrap_json_request_handler
@defer.inlineCallbacks
def _async_render_GET(self, request):
yield self.auth.get_user_by_req(request)
async def _async_render_GET(self, request):
await self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True)

def render_OPTIONS(self, request):
Expand Down
Loading