Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Cancel the processing of key query requests when they time out. #13680

Merged
merged 8 commits into from
Sep 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/13680.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Cancel the processing of key query requests when they time out.
5 changes: 5 additions & 0 deletions synapse/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
trace,
)
from synapse.types import Requester, create_requester
from synapse.util.cancellation import cancellable

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -118,6 +119,7 @@ async def check_user_in_room(
errcode=Codes.NOT_JOINED,
)

@cancellable
async def get_user_by_req(
self,
request: SynapseRequest,
Expand Down Expand Up @@ -166,6 +168,7 @@ async def get_user_by_req(
parent_span.set_tag("appservice_id", requester.app_service.id)
return requester

@cancellable
async def _wrapped_get_user_by_req(
self,
request: SynapseRequest,
Expand Down Expand Up @@ -281,6 +284,7 @@ async def validate_appservice_can_control_user_id(
403, "Application service has not registered this user (%s)" % user_id
)

@cancellable
async def _get_appservice_user(self, request: Request) -> Optional[Requester]:
"""
Given a request, reads the request parameters to determine:
Expand Down Expand Up @@ -523,6 +527,7 @@ def has_access_token(request: Request) -> bool:
return bool(query_params) or bool(auth_headers)

@staticmethod
@cancellable
def get_access_token_from_request(request: Request) -> str:
"""Extracts the access_token from the request.

Expand Down
3 changes: 3 additions & 0 deletions synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from synapse.util import stringutils
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.cancellation import cancellable
from synapse.util.metrics import measure_func
from synapse.util.retryutils import NotRetryingDestination

Expand Down Expand Up @@ -124,6 +125,7 @@ async def get_device(self, user_id: str, device_id: str) -> JsonDict:

return device

@cancellable
async def get_device_changes_in_shared_rooms(
self, user_id: str, room_ids: Collection[str], from_token: StreamToken
) -> Collection[str]:
Expand Down Expand Up @@ -163,6 +165,7 @@ async def get_device_changes_in_shared_rooms(

@trace
@measure_func("device.get_user_ids_changed")
@cancellable
async def get_user_ids_changed(
self, user_id: str, from_token: StreamToken
) -> JsonDict:
Expand Down
40 changes: 24 additions & 16 deletions synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@
get_verify_key_from_cross_signing_key,
)
from synapse.util import json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer
from synapse.util.async_helpers import Linearizer, delay_cancellation
from synapse.util.cancellation import cancellable
from synapse.util.retryutils import NotRetryingDestination

if TYPE_CHECKING:
Expand Down Expand Up @@ -91,6 +92,7 @@ def __init__(self, hs: "HomeServer"):
)

@trace
@cancellable
async def query_devices(
self,
query_body: JsonDict,
Expand Down Expand Up @@ -208,22 +210,26 @@ async def query_devices(
r[user_id] = remote_queries[user_id]

# Now fetch any devices that we don't have in our cache
# TODO It might make sense to propagate cancellations into the
# deferreds which are querying remote homeservers.
await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self._query_devices_for_destination,
results,
cross_signing_keys,
failures,
destination,
queries,
timeout,
)
for destination, queries in remote_queries_not_in_cache.items()
],
consumeErrors=True,
).addErrback(unwrapFirstError)
delay_cancellation(
defer.gatherResults(
[
run_in_background(
self._query_devices_for_destination,
results,
cross_signing_keys,
failures,
destination,
queries,
timeout,
)
for destination, queries in remote_queries_not_in_cache.items()
],
consumeErrors=True,
).addErrback(unwrapFirstError)
)
)

ret = {"device_keys": results, "failures": failures}
Expand Down Expand Up @@ -347,6 +353,7 @@ async def _query_devices_for_destination(

return

@cancellable
async def get_cross_signing_keys_from_cache(
self, query: Iterable[str], from_user_id: Optional[str]
) -> Dict[str, Dict[str, dict]]:
Expand Down Expand Up @@ -393,6 +400,7 @@ async def get_cross_signing_keys_from_cache(
}

@trace
@cancellable
async def query_local_devices(
self, query: Mapping[str, Optional[List[str]]]
) -> Dict[str, Dict[str, dict]]:
Expand Down
6 changes: 4 additions & 2 deletions synapse/rest/client/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@
)
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import log_kv, set_tag
from synapse.rest.client._base import client_patterns, interactive_auth_handler
from synapse.types import JsonDict, StreamToken

from ._base import client_patterns, interactive_auth_handler
from synapse.util.cancellation import cancellable

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -156,6 +156,7 @@ def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()

@cancellable
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string()
Expand Down Expand Up @@ -199,6 +200,7 @@ def __init__(self, hs: "HomeServer"):
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastores().main

@cancellable
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)

Expand Down
4 changes: 4 additions & 0 deletions synapse/storage/controllers/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
PartialStateEventsTracker,
)
from synapse.types import MutableStateMap, StateMap
from synapse.util.cancellation import cancellable

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -229,6 +230,7 @@ async def get_state_for_events(

@trace
@tag_args
@cancellable
async def get_state_ids_for_events(
self,
event_ids: Collection[str],
Expand Down Expand Up @@ -350,6 +352,7 @@ def get_state_for_groups(

@trace
@tag_args
@cancellable
async def get_state_group_for_events(
self,
event_ids: Collection[str],
Expand Down Expand Up @@ -398,6 +401,7 @@ async def store_state_group(
event_id, room_id, prev_group, delta_ids, current_state_ids
)

@cancellable
async def get_current_state_ids(
self,
room_id: str,
Expand Down
4 changes: 4 additions & 0 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr

Expand Down Expand Up @@ -668,6 +669,7 @@ def get_device_stream_token(self) -> int:
...

@trace
@cancellable
async def get_user_devices_from_cache(
self, query_list: List[Tuple[str, Optional[str]]]
) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
Expand Down Expand Up @@ -743,6 +745,7 @@ def get_cached_device_list_changes(

return self._device_list_stream_cache.get_all_entities_changed(from_key)

@cancellable
async def get_users_whose_devices_changed(
self,
from_key: int,
Expand Down Expand Up @@ -1221,6 +1224,7 @@ async def _get_min_device_lists_changes_in_room(self) -> int:
desc="get_min_device_lists_changes_in_room",
)

@cancellable
async def get_device_list_changes_in_rooms(
self, room_ids: Collection[str], from_id: int
) -> Optional[Set[str]]:
Expand Down
5 changes: 4 additions & 1 deletion synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter

if TYPE_CHECKING:
Expand Down Expand Up @@ -135,6 +136,7 @@ async def get_e2e_device_keys_for_federation_query(
return now_stream_id, []

@trace
@cancellable
async def get_e2e_device_keys_for_cs_api(
self, query_list: List[Tuple[str, Optional[str]]]
) -> Dict[str, Dict[str, JsonDict]]:
Expand Down Expand Up @@ -197,6 +199,7 @@ async def get_e2e_device_keys_and_signatures(
...

@trace
@cancellable
async def get_e2e_device_keys_and_signatures(
self,
query_list: Collection[Tuple[str, Optional[str]]],
Expand Down Expand Up @@ -887,6 +890,7 @@ def _get_e2e_cross_signing_signatures_txn(

return keys

@cancellable
async def get_e2e_cross_signing_keys_bulk(
self, user_ids: List[str], from_user_id: Optional[str] = None
) -> Dict[str, Optional[Dict[str, JsonDict]]]:
Expand All @@ -902,7 +906,6 @@ async def get_e2e_cross_signing_keys_bulk(
keys were not found, either their user ID will not be in the dict,
or their user ID will map to None.
"""

result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids)

if from_user_id:
Expand Down
2 changes: 2 additions & 0 deletions synapse/storage/databases/main/event_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter

if TYPE_CHECKING:
Expand Down Expand Up @@ -976,6 +977,7 @@ def _get_min_depth_interaction(

return int(min_depth) if min_depth is not None else None

@cancellable
async def get_forward_extremities_for_room_at_stream_ordering(
self, room_id: str, stream_ordering: int
) -> List[str]:
Expand Down
4 changes: 4 additions & 0 deletions synapse/storage/databases/main/events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import AsyncLruCache
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure

Expand Down Expand Up @@ -339,6 +340,7 @@ async def get_event(
) -> Optional[EventBase]:
...

@cancellable
async def get_event(
self,
event_id: str,
Expand Down Expand Up @@ -433,6 +435,7 @@ async def get_events(

@trace
@tag_args
@cancellable
async def get_events_as_list(
self,
event_ids: Collection[str],
Expand Down Expand Up @@ -584,6 +587,7 @@ async def get_events_as_list(

return events

@cancellable
async def _get_events_from_cache_or_db(
self, event_ids: Iterable[str], allow_rejected: bool = False
) -> Dict[str, EventCacheEntry]:
Expand Down
2 changes: 2 additions & 0 deletions synapse/storage/databases/main/roommember.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure

Expand Down Expand Up @@ -771,6 +772,7 @@ def _get_users_server_still_shares_room_with_txn(
_get_users_server_still_shares_room_with_txn,
)

@cancellable
async def get_rooms_for_user(
self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None
) -> FrozenSet[str]:
Expand Down
2 changes: 2 additions & 0 deletions synapse/storage/databases/main/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from synapse.types import JsonDict, JsonMapping, StateMap
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter

if TYPE_CHECKING:
Expand Down Expand Up @@ -281,6 +282,7 @@ def _get_current_state_ids_txn(txn: LoggingTransaction) -> StateMap[str]:
)

# FIXME: how should this be cached?
@cancellable
async def get_partial_filtered_current_state_ids(
self, room_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[str]:
Expand Down
2 changes: 2 additions & 0 deletions synapse/storage/databases/main/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
from synapse.types import PersistedEventPosition, RoomStreamToken
from synapse.util.caches.descriptors import cached
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.cancellation import cancellable

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -597,6 +598,7 @@ def f(txn: LoggingTransaction) -> List[_EventDictReturn]:

return ret, key

@cancellable
async def get_membership_changes_for_user(
self,
user_id: str,
Expand Down
3 changes: 3 additions & 0 deletions synapse/storage/databases/state/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from synapse.types import MutableStateMap, StateKey, StateMap
from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.cancellation import cancellable

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -156,6 +157,7 @@ def _get_state_group_delta_txn(txn: LoggingTransaction) -> _GetStateGroupDelta:
"get_state_group_delta", _get_state_group_delta_txn
)

@cancellable
async def _get_state_groups_from_groups(
self, groups: List[int], state_filter: StateFilter
) -> Dict[int, StateMap[str]]:
Expand Down Expand Up @@ -235,6 +237,7 @@ def _get_state_for_group_using_cache(

return state_filter.filter_state(state_dict_ids), not missing_types

@cancellable
async def _get_state_for_groups(
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
) -> Dict[int, MutableStateMap[str]]:
Expand Down
Loading