Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type annotations to SimpleHttpClient #8372

Merged
merged 3 commits into from
Sep 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8372.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type annotations to `SimpleHttpClient`.
2 changes: 1 addition & 1 deletion synapse/appservice/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ async def _get() -> Optional[JsonDict]:
urllib.parse.quote(protocol),
)
try:
info = await self.get_json(uri, {})
info = await self.get_json(uri)

if not _is_valid_3pe_metadata(info):
logger.warning(
Expand Down
187 changes: 131 additions & 56 deletions synapse/http/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,18 @@
import logging
import urllib
from io import BytesIO
from typing import (
Any,
BinaryIO,
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
Union,
)

import treq
from canonicaljson import encode_canonical_json
Expand All @@ -37,6 +49,7 @@
from twisted.web.client import Agent, HTTPConnectionPool, readBody
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
from twisted.web.iweb import IResponse

from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import (
Expand All @@ -57,6 +70,19 @@
"synapse_http_client_responses", "", ["method", "code"]
)

# the type of the headers list, to be passed to the t.w.h.Headers.
# Actually we can mix str and bytes keys, but Mapping treats 'key' as invariant so
# we simplify.
RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValue"]]

# the value actually has to be a List, but List is invariant so we can't specify that
# the entries can either be Lists or bytes.
RawHeaderValue = Sequence[Union[str, bytes]]

# the type of the query params, to be passed into `urlencode`
QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]]
QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]]


def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist):
"""
Expand Down Expand Up @@ -285,13 +311,26 @@ def __getattr__(_self, attr):
ip_blacklist=self._ip_blacklist,
)

async def request(self, method, uri, data=None, headers=None):
async def request(
self,
method: str,
uri: str,
data: Optional[bytes] = None,
headers: Optional[Headers] = None,
) -> IResponse:
"""
Args:
method (str): HTTP method to use.
uri (str): URI to query.
data (bytes): Data to send in the request body, if applicable.
headers (t.w.http_headers.Headers): Request headers.
method: HTTP method to use.
uri: URI to query.
data: Data to send in the request body, if applicable.
headers: Request headers.

Returns:
Response object, once the headers have been read.

Raises:
RequestTimedOutError if the request times out before the headers are read

"""
# A small wrapper around self.agent.request() so we can easily attach
# counters to it
Expand Down Expand Up @@ -324,6 +363,8 @@ async def request(self, method, uri, data=None, headers=None):
headers=headers,
**self._extra_treq_args
)
# we use our own timeout mechanism rather than treq's as a workaround
# for https://twistedmatrix.com/trac/ticket/9534.
request_deferred = timeout_deferred(
request_deferred,
60,
Expand Down Expand Up @@ -353,18 +394,26 @@ async def request(self, method, uri, data=None, headers=None):
set_tag("error_reason", e.args[0])
raise

async def post_urlencoded_get_json(self, uri, args={}, headers=None):
async def post_urlencoded_get_json(
self,
uri: str,
args: Mapping[str, Union[str, List[str]]] = {},
headers: Optional[RawHeaders] = None,
) -> Any:
"""
Args:
uri (str):
args (dict[str, str|List[str]]): query params
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
uri: uri to query
args: parameters to be url-encoded in the body
headers: a map from header name to a list of values for that header

Returns:
object: parsed json
parsed json

Raises:
RequestTimedOutException: if there is a timeout before the response headers
are received. Note there is currently no timeout on reading the response
body.

HttpResponseException: On a non-2xx HTTP response.

ValueError: if the response was not JSON
Expand Down Expand Up @@ -398,19 +447,24 @@ async def post_urlencoded_get_json(self, uri, args={}, headers=None):
response.code, response.phrase.decode("ascii", errors="replace"), body
)

async def post_json_get_json(self, uri, post_json, headers=None):
async def post_json_get_json(
self, uri: str, post_json: Any, headers: Optional[RawHeaders] = None
) -> Any:
"""

Args:
uri (str):
post_json (object):
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
uri: URI to query.
post_json: request body, to be encoded as json
headers: a map from header name to a list of values for that header

Returns:
object: parsed json
parsed json

Raises:
RequestTimedOutException: if there is a timeout before the response headers
are received. Note there is currently no timeout on reading the response
body.

HttpResponseException: On a non-2xx HTTP response.

ValueError: if the response was not JSON
Expand Down Expand Up @@ -440,21 +494,22 @@ async def post_json_get_json(self, uri, post_json, headers=None):
response.code, response.phrase.decode("ascii", errors="replace"), body
)

async def get_json(self, uri, args={}, headers=None):
""" Gets some json from the given URI.
async def get_json(
self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None,
) -> Any:
"""Gets some json from the given URI.

Args:
uri (str): The URI to request, not including query parameters
args (dict): A dictionary used to create query strings, defaults to
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
uri: The URI to request, not including query parameters
args: A dictionary used to create query string
headers: a map from header name to a list of values for that header
Returns:
Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON.
Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON.
Raises:
RequestTimedOutException: if there is a timeout before the response headers
are received. Note there is currently no timeout on reading the response
body.

HttpResponseException On a non-2xx HTTP response.

ValueError: if the response was not JSON
Expand All @@ -466,22 +521,27 @@ async def get_json(self, uri, args={}, headers=None):
body = await self.get_raw(uri, args, headers=headers)
return json_decoder.decode(body.decode("utf-8"))

async def put_json(self, uri, json_body, args={}, headers=None):
""" Puts some json to the given URI.
async def put_json(
self,
uri: str,
json_body: Any,
args: QueryParams = {},
headers: RawHeaders = None,
) -> Any:
"""Puts some json to the given URI.

Args:
uri (str): The URI to request, not including query parameters
json_body (dict): The JSON to put in the HTTP body,
args (dict): A dictionary used to create query strings, defaults to
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
uri: The URI to request, not including query parameters
json_body: The JSON to put in the HTTP body,
args: A dictionary used to create query strings
headers: a map from header name to a list of values for that header
Returns:
Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON.
Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON.
Raises:
RequestTimedOutException: if there is a timeout before the response headers
are received. Note there is currently no timeout on reading the response
body.

HttpResponseException On a non-2xx HTTP response.

ValueError: if the response was not JSON
Expand Down Expand Up @@ -513,21 +573,23 @@ async def put_json(self, uri, json_body, args={}, headers=None):
response.code, response.phrase.decode("ascii", errors="replace"), body
)

async def get_raw(self, uri, args={}, headers=None):
""" Gets raw text from the given URI.
async def get_raw(
self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None
) -> bytes:
"""Gets raw text from the given URI.

Args:
uri (str): The URI to request, not including query parameters
args (dict): A dictionary used to create query strings, defaults to
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
uri: The URI to request, not including query parameters
args: A dictionary used to create query strings
headers: a map from header name to a list of values for that header
Returns:
Succeeds when we get *any* 2xx HTTP response, with the
Succeeds when we get a 2xx HTTP response, with the
HTTP body as bytes.
Raises:
RequestTimedOutException: if there is a timeout before the response headers
are received. Note there is currently no timeout on reading the response
body.

HttpResponseException on a non-2xx HTTP response.
"""
if len(args):
Expand All @@ -552,16 +614,29 @@ async def get_raw(self, uri, args={}, headers=None):
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out.

async def get_file(self, url, output_stream, max_size=None, headers=None):
async def get_file(
self,
url: str,
output_stream: BinaryIO,
max_size: Optional[int] = None,
headers: Optional[RawHeaders] = None,
) -> Tuple[int, Dict[bytes, List[bytes]], str, int]:
"""GETs a file from a given URL
Args:
url (str): The URL to GET
output_stream (file): File to write the response body to.
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
url: The URL to GET
output_stream: File to write the response body to.
headers: A map from header name to a list of values for that header
Returns:
A (int,dict,string,int) tuple of the file length, dict of the response
A tuple of the file length, dict of the response
headers, absolute URI of the response and HTTP response code.

Raises:
RequestTimedOutException: if there is a timeout before the response headers
are received. Note there is currently no timeout on reading the response
body.

SynapseError: if the response is not a 2xx, the remote file is too large, or
another exception happens during the download.
"""

actual_headers = {b"User-Agent": [self.user_agent]}
Expand Down
14 changes: 10 additions & 4 deletions synapse/rest/media/v1/preview_url_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ async def _get_oembed_content(self, endpoint: str, url: str) -> OEmbedResult:
logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
raise OEmbedError() from e

async def _download_url(self, url, user):
async def _download_url(self, url: str, user):
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot?
Expand All @@ -460,7 +460,7 @@ async def _download_url(self, url, user):
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)

# If this URL can be accessed via oEmbed, use that instead.
url_to_download = url
url_to_download = url # type: Optional[str]
oembed_url = self._get_oembed_url(url)
if oembed_url:
# The result might be a new URL to download, or it might be HTML content.
Expand Down Expand Up @@ -520,9 +520,15 @@ async def _download_url(self, url, user):
# FIXME: we should calculate a proper expiration based on the
# Cache-Control and Expire headers. But for now, assume 1 hour.
expires = ONE_HOUR
etag = headers["ETag"][0] if "ETag" in headers else None
etag = (
headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None
)
Comment on lines +523 to +525
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is nominally a bugfix ("Etag" would never be in the headers dict), but since we don't do anything with the resulting value other than stick it in the database, I don't think it's worth calling out separately.

else:
html_bytes = oembed_result.html.encode("utf-8") # type: ignore
# we can only get here if we did an oembed request and have an oembed_result.html
assert oembed_result.html is not None
assert oembed_url is not None

html_bytes = oembed_result.html.encode("utf-8")
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
f.write(html_bytes)
await finish()
Expand Down