Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add missing type hints to synapse.http. (#11571)
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep authored Dec 14, 2021
1 parent ff6fd52 commit 33abbc3
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 51 deletions.
1 change: 1 addition & 0 deletions changelog.d/11571.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints to `synapse.http`.
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ disallow_untyped_defs = False
[mypy-synapse.handlers.*]
disallow_untyped_defs = True

[mypy-synapse.http.server]
disallow_untyped_defs = True

[mypy-synapse.metrics.*]
disallow_untyped_defs = True

Expand Down
6 changes: 3 additions & 3 deletions synapse/http/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@
class RequestTimedOutError(SynapseError):
"""Exception representing timeout of an outbound request"""

def __init__(self, msg):
def __init__(self, msg: str):
super().__init__(504, msg)


ACCESS_TOKEN_RE = re.compile(r"(\?.*access(_|%5[Ff])token=)[^&]*(.*)$")
CLIENT_SECRET_RE = re.compile(r"(\?.*client(_|%5[Ff])secret=)[^&]*(.*)$")


def redact_uri(uri):
def redact_uri(uri: str) -> str:
"""Strips sensitive information from the uri replaces with <redacted>"""
uri = ACCESS_TOKEN_RE.sub(r"\1<redacted>\3", uri)
return CLIENT_SECRET_RE.sub(r"\1<redacted>\3", uri)
Expand All @@ -46,7 +46,7 @@ class QuieterFileBodyProducer(FileBodyProducer):
https://twistedmatrix.com/trac/ticket/6528
"""

def stopProducing(self):
def stopProducing(self) -> None:
try:
FileBodyProducer.stopProducing(self)
except task.TaskStopped:
Expand Down
12 changes: 8 additions & 4 deletions synapse/http/additional_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple

from twisted.web.server import Request

Expand All @@ -32,7 +32,11 @@ class AdditionalResource(DirectServeJsonResource):
and exception handling.
"""

def __init__(self, hs: "HomeServer", handler):
def __init__(
self,
hs: "HomeServer",
handler: Callable[[Request], Awaitable[Optional[Tuple[int, Any]]]],
):
"""Initialise AdditionalResource
The ``handler`` should return a deferred which completes when it has
Expand All @@ -47,7 +51,7 @@ def __init__(self, hs: "HomeServer", handler):
super().__init__()
self._handler = handler

def _async_render(self, request: Request):
async def _async_render(self, request: Request) -> Optional[Tuple[int, Any]]:
# Cheekily pass the result straight through, so we don't need to worry
# if its an awaitable or not.
return self._handler(request)
return await self._handler(request)
90 changes: 53 additions & 37 deletions synapse/http/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
Iterable,
Iterator,
List,
NoReturn,
Optional,
Pattern,
Tuple,
Expand Down Expand Up @@ -170,7 +171,9 @@ def return_html_error(
respond_with_html(request, code, body)


def wrap_async_request_handler(h):
def wrap_async_request_handler(
h: Callable[["_AsyncResource", SynapseRequest], Awaitable[None]]
) -> Callable[["_AsyncResource", SynapseRequest], "defer.Deferred[None]"]:
"""Wraps an async request handler so that it calls request.processing.
This helps ensure that work done by the request handler after the request is completed
Expand All @@ -183,7 +186,9 @@ def wrap_async_request_handler(h):
logged until the deferred completes.
"""

async def wrapped_async_request_handler(self, request):
async def wrapped_async_request_handler(
self: "_AsyncResource", request: SynapseRequest
) -> None:
with request.processing():
await h(self, request)

Expand Down Expand Up @@ -240,18 +245,18 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
context from the request the servlet is handling.
"""

def __init__(self, extract_context=False):
def __init__(self, extract_context: bool = False):
super().__init__()

self._extract_context = extract_context

def render(self, request):
def render(self, request: SynapseRequest) -> int:
"""This gets called by twisted every time someone sends us a request."""
defer.ensureDeferred(self._async_render_wrapper(request))
return NOT_DONE_YET

@wrap_async_request_handler
async def _async_render_wrapper(self, request: SynapseRequest):
async def _async_render_wrapper(self, request: SynapseRequest) -> None:
"""This is a wrapper that delegates to `_async_render` and handles
exceptions, return values, metrics, etc.
"""
Expand All @@ -271,7 +276,7 @@ async def _async_render_wrapper(self, request: SynapseRequest):
f = failure.Failure()
self._send_error_response(f, request)

async def _async_render(self, request: Request):
async def _async_render(self, request: SynapseRequest) -> Optional[Tuple[int, Any]]:
"""Delegates to `_async_render_<METHOD>` methods, or returns a 400 if
no appropriate method exists. Can be overridden in sub classes for
different routing.
Expand Down Expand Up @@ -318,7 +323,7 @@ class DirectServeJsonResource(_AsyncResource):
formatting responses and errors as JSON.
"""

def __init__(self, canonical_json=False, extract_context=False):
def __init__(self, canonical_json: bool = False, extract_context: bool = False):
super().__init__(extract_context)
self.canonical_json = canonical_json

Expand All @@ -327,7 +332,7 @@ def _send_response(
request: SynapseRequest,
code: int,
response_object: Any,
):
) -> None:
"""Implements _AsyncResource._send_response"""
# TODO: Only enable CORS for the requests that need it.
respond_with_json(
Expand Down Expand Up @@ -368,34 +373,45 @@ class JsonResource(DirectServeJsonResource):

isLeaf = True

def __init__(self, hs: "HomeServer", canonical_json=True, extract_context=False):
def __init__(
self,
hs: "HomeServer",
canonical_json: bool = True,
extract_context: bool = False,
):
super().__init__(canonical_json, extract_context)
self.clock = hs.get_clock()
self.path_regexs: Dict[bytes, List[_PathEntry]] = {}
self.hs = hs

def register_paths(self, method, path_patterns, callback, servlet_classname):
def register_paths(
self,
method: str,
path_patterns: Iterable[Pattern],
callback: ServletCallback,
servlet_classname: str,
) -> None:
"""
Registers a request handler against a regular expression. Later request URLs are
checked against these regular expressions in order to identify an appropriate
handler for that request.
Args:
method (str): GET, POST etc
method: GET, POST etc
path_patterns (Iterable[str]): A list of regular expressions to which
the request URLs are compared.
path_patterns: A list of regular expressions to which the request
URLs are compared.
callback (function): The handler for the request. Usually a Servlet
callback: The handler for the request. Usually a Servlet
servlet_classname (str): The name of the handler to be used in prometheus
servlet_classname: The name of the handler to be used in prometheus
and opentracing logs.
"""
method = method.encode("utf-8") # method is bytes on py3
method_bytes = method.encode("utf-8")

for path_pattern in path_patterns:
logger.debug("Registering for %s %s", method, path_pattern.pattern)
self.path_regexs.setdefault(method, []).append(
self.path_regexs.setdefault(method_bytes, []).append(
_PathEntry(path_pattern, callback, servlet_classname)
)

Expand Down Expand Up @@ -427,7 +443,7 @@ def _get_handler_for_request(
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
return _unrecognised_request_handler, "unrecognised_request_handler", {}

async def _async_render(self, request):
async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]:
callback, servlet_classname, group_dict = self._get_handler_for_request(request)

# Make sure we have an appropriate name for this handler in prometheus
Expand Down Expand Up @@ -468,7 +484,7 @@ def _send_response(
request: SynapseRequest,
code: int,
response_object: Any,
):
) -> None:
"""Implements _AsyncResource._send_response"""
# We expect to get bytes for us to write
assert isinstance(response_object, bytes)
Expand All @@ -492,35 +508,35 @@ class StaticResource(File):
Differs from the File resource by adding clickjacking protection.
"""

def render_GET(self, request: Request):
def render_GET(self, request: Request) -> bytes:
set_clickjacking_protection_headers(request)
return super().render_GET(request)


def _unrecognised_request_handler(request):
def _unrecognised_request_handler(request: Request) -> NoReturn:
"""Request handler for unrecognised requests
This is a request handler suitable for return from
_get_handler_for_request. It actually just raises an
UnrecognizedRequestError.
Args:
request (twisted.web.http.Request):
request: Unused, but passed in to match the signature of ServletCallback.
"""
raise UnrecognizedRequestError()


class RootRedirect(resource.Resource):
"""Redirects the root '/' path to another path."""

def __init__(self, path):
def __init__(self, path: str):
resource.Resource.__init__(self)
self.url = path

def render_GET(self, request):
def render_GET(self, request: Request) -> bytes:
return redirectTo(self.url.encode("ascii"), request)

def getChild(self, name, request):
def getChild(self, name: str, request: Request) -> resource.Resource:
if len(name) == 0:
return self # select ourselves as the child to render
return resource.Resource.getChild(self, name, request)
Expand All @@ -529,15 +545,15 @@ def getChild(self, name, request):
class OptionsResource(resource.Resource):
"""Responds to OPTION requests for itself and all children."""

def render_OPTIONS(self, request):
def render_OPTIONS(self, request: Request) -> bytes:
request.setResponseCode(204)
request.setHeader(b"Content-Length", b"0")

set_cors_headers(request)

return b""

def getChildWithDefault(self, path, request):
def getChildWithDefault(self, path: str, request: Request) -> resource.Resource:
if request.method == b"OPTIONS":
return self # select ourselves as the child to render
return resource.Resource.getChildWithDefault(self, path, request)
Expand Down Expand Up @@ -649,7 +665,7 @@ def respond_with_json(
json_object: Any,
send_cors: bool = False,
canonical_json: bool = True,
):
) -> Optional[int]:
"""Sends encoded JSON in response to the given request.
Args:
Expand Down Expand Up @@ -696,7 +712,7 @@ def respond_with_json_bytes(
code: int,
json_bytes: bytes,
send_cors: bool = False,
):
) -> Optional[int]:
"""Sends encoded JSON in response to the given request.
Args:
Expand All @@ -713,7 +729,7 @@ def respond_with_json_bytes(
logger.warning(
"Not sending response to request %s, already disconnected.", request
)
return
return None

request.setResponseCode(code)
request.setHeader(b"Content-Type", b"application/json")
Expand All @@ -731,7 +747,7 @@ async def _async_write_json_to_request_in_thread(
request: SynapseRequest,
json_encoder: Callable[[Any], bytes],
json_object: Any,
):
) -> None:
"""Encodes the given JSON object on a thread and then writes it to the
request.
Expand Down Expand Up @@ -773,7 +789,7 @@ def _write_bytes_to_request(request: Request, bytes_to_write: bytes) -> None:
_ByteProducer(request, bytes_generator)


def set_cors_headers(request: Request):
def set_cors_headers(request: Request) -> None:
"""Set the CORS headers so that javascript running in a web browsers can
use this API
Expand All @@ -790,14 +806,14 @@ def set_cors_headers(request: Request):
)


def respond_with_html(request: Request, code: int, html: str):
def respond_with_html(request: Request, code: int, html: str) -> None:
"""
Wraps `respond_with_html_bytes` by first encoding HTML from a str to UTF-8 bytes.
"""
respond_with_html_bytes(request, code, html.encode("utf-8"))


def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes):
def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes) -> None:
"""
Sends HTML (encoded as UTF-8 bytes) as the response to the given request.
Expand All @@ -815,7 +831,7 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes):
logger.warning(
"Not sending response to request %s, already disconnected.", request
)
return
return None

request.setResponseCode(code)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
Expand All @@ -828,7 +844,7 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes):
finish_request(request)


def set_clickjacking_protection_headers(request: Request):
def set_clickjacking_protection_headers(request: Request) -> None:
"""
Set headers to guard against clickjacking of embedded content.
Expand All @@ -850,7 +866,7 @@ def respond_with_redirect(request: Request, url: bytes) -> None:
finish_request(request)


def finish_request(request: Request):
def finish_request(request: Request) -> None:
"""Finish writing the response to the request.
Twisted throws a RuntimeException if the connection closed before the
Expand Down
3 changes: 2 additions & 1 deletion synapse/http/servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from twisted.web.server import Request

from synapse.api.errors import Codes, SynapseError
from synapse.http.server import HttpServer
from synapse.types import JsonDict, RoomAlias, RoomID
from synapse.util import json_decoder

Expand Down Expand Up @@ -726,7 +727,7 @@ class attribute containing a pre-compiled regular expression. The automatic
into the appropriate HTTP response.
"""

def register(self, http_server):
def register(self, http_server: HttpServer) -> None:
"""Register this servlet with the given HTTP server."""
patterns = getattr(self, "PATTERNS", None)
if patterns:
Expand Down
Loading

0 comments on commit 33abbc3

Please sign in to comment.