Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints to response cache. #8507

Merged
merged 3 commits into from
Oct 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8507.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to various parts of the code base.
1 change: 1 addition & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ files =
synapse/types.py,
synapse/util/async_helpers.py,
synapse/util/caches/descriptors.py,
synapse/util/caches/response_cache.py,
synapse/util/caches/stream_change_cache.py,
synapse/util/metrics.py,
tests/replication,
Expand Down
4 changes: 2 additions & 2 deletions synapse/appservice/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
import logging
import urllib
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Tuple

from prometheus_client import Counter

Expand Down Expand Up @@ -93,7 +93,7 @@ def __init__(self, hs):

self.protocol_meta_cache = ResponseCache(
hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS
)
) # type: ResponseCache[Tuple[str, str]]

async def query_user(self, service, user_id):
if service.url is None:
Expand Down
8 changes: 5 additions & 3 deletions synapse/federation/federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,18 +111,20 @@ def __init__(self, hs):
# We cache results for transaction with the same ID
self._transaction_resp_cache = ResponseCache(
hs, "fed_txn_handler", timeout_ms=30000
)
) # type: ResponseCache[Tuple[str, str]]

self.transaction_actions = TransactionActions(self.store)

self.registry = hs.get_federation_registry()

# We cache responses to state queries, as they take a while and often
# come in waves.
self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000)
self._state_resp_cache = ResponseCache(
hs, "state_resp", timeout_ms=30000
) # type: ResponseCache[Tuple[str, str]]
self._state_ids_resp_cache = ResponseCache(
hs, "state_ids_resp", timeout_ms=30000
)
) # type: ResponseCache[Tuple[str, str]]

self._federation_metrics_domains = (
hs.get_config().federation.federation_metrics_domains
Expand Down
10 changes: 6 additions & 4 deletions synapse/handlers/initial_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional, Tuple

from twisted.internet import defer

Expand Down Expand Up @@ -47,12 +47,14 @@ def __init__(self, hs: "HomeServer"):
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
self.validator = EventValidator()
self.snapshot_cache = ResponseCache(hs, "initial_sync_cache")
self.snapshot_cache = ResponseCache(
hs, "initial_sync_cache"
) # type: ResponseCache[Tuple[str, Optional[StreamToken], Optional[StreamToken], str, Optional[int], bool, bool]]
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
self.state_store = self.storage.state

def snapshot_all_rooms(
async def snapshot_all_rooms(
self,
user_id: str,
pagin_config: PaginationConfig,
Expand Down Expand Up @@ -84,7 +86,7 @@ def snapshot_all_rooms(
include_archived,
)

return self.snapshot_cache.wrap(
return await self.snapshot_cache.wrap(
key,
self._snapshot_all_rooms,
user_id,
Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def __init__(self, hs: "HomeServer"):
# subsequent requests
self._upgrade_response_cache = ResponseCache(
hs, "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
)
) # type: ResponseCache[Tuple[str, str]]
self._server_notices_mxid = hs.config.server_notices_mxid

self.third_party_event_rules = hs.get_third_party_event_rules()
Expand Down
4 changes: 3 additions & 1 deletion synapse/handlers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,9 @@ def __init__(self, hs: "HomeServer"):
self.presence_handler = hs.get_presence_handler()
self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock()
self.response_cache = ResponseCache(hs, "sync")
self.response_cache = ResponseCache(
hs, "sync"
) # type: ResponseCache[Tuple[Any, ...]]
self.state = hs.get_state_handler()
self.auth = hs.get_auth()
self.storage = hs.get_storage()
Expand Down
2 changes: 1 addition & 1 deletion synapse/replication/http/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __init__(self, hs):
if self.CACHE:
self.response_cache = ResponseCache(
hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
)
) # type: ResponseCache[str]

# We reserve `instance_name` as a parameter to sending requests, so we
# assert here that sub classes don't try and use the name.
Expand Down
50 changes: 28 additions & 22 deletions synapse/util/caches/response_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,40 +13,47 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar

from twisted.internet import defer

from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import register_cache

if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer

logger = logging.getLogger(__name__)

T = TypeVar("T")


class ResponseCache:
class ResponseCache(Generic[T]):
"""
This caches a deferred response. Until the deferred completes it will be
returned from the cache. This means that if the client retries the request
while the response is still being computed, that original response will be
used rather than trying to compute a new response.
"""

def __init__(self, hs, name, timeout_ms=0):
self.pending_result_cache = {} # Requests that haven't finished yet.
def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0):
# Requests that haven't finished yet.
self.pending_result_cache = {} # type: Dict[T, ObservableDeferred]

self.clock = hs.get_clock()
self.timeout_sec = timeout_ms / 1000.0

self._name = name
self._metrics = register_cache("response_cache", name, self, resizable=False)

def size(self):
def size(self) -> int:
return len(self.pending_result_cache)

def __len__(self):
def __len__(self) -> int:
return self.size()

def get(self, key):
def get(self, key: T) -> Optional[defer.Deferred]:
"""Look up the given key.

Can return either a new Deferred (which also doesn't follow the synapse
Expand All @@ -58,12 +65,11 @@ def get(self, key):
from an absent cache entry.

Args:
key (hashable):
key: key to get/set in the cache

Returns:
twisted.internet.defer.Deferred|None|E: None if there is no entry
for this key; otherwise either a deferred result or the result
itself.
None if there is no entry for this key; otherwise a deferred which
resolves to the result.
"""
result = self.pending_result_cache.get(key)
if result is not None:
Expand All @@ -73,7 +79,7 @@ def get(self, key):
self._metrics.inc_misses()
return None

def set(self, key, deferred):
def set(self, key: T, deferred: defer.Deferred) -> defer.Deferred:
"""Set the entry for the given key to the given deferred.

*deferred* should run its callbacks in the sentinel logcontext (ie,
Expand All @@ -85,12 +91,11 @@ def set(self, key, deferred):
result. You will probably want to make_deferred_yieldable the result.

Args:
key (hashable):
deferred (twisted.internet.defer.Deferred[T):
key: key to get/set in the cache
deferred: The deferred which resolves to the result.

Returns:
twisted.internet.defer.Deferred[T]|T: a new deferred, or the actual
result.
A new deferred which resolves to the actual result.
"""
result = ObservableDeferred(deferred, consumeErrors=True)
self.pending_result_cache[key] = result
Expand All @@ -107,7 +112,9 @@ def remove(r):
result.addBoth(remove)
return result.observe()

def wrap(self, key, callback, *args, **kwargs):
def wrap(
self, key: T, callback: "Callable[..., Any]", *args: Any, **kwargs: Any
) -> defer.Deferred:
"""Wrap together a *get* and *set* call, taking care of logcontexts

First looks up the key in the cache, and if it is present makes it
Expand All @@ -118,29 +125,28 @@ def wrap(self, key, callback, *args, **kwargs):

Example usage:

@defer.inlineCallbacks
def handle_request(request):
async def handle_request(request):
# etc
return result

result = yield response_cache.wrap(
result = await response_cache.wrap(
key,
handle_request,
request,
)

Args:
key (hashable): key to get/set in the cache
key: key to get/set in the cache

callback (callable): function to call if the key is not found in
callback: function to call if the key is not found in
the cache

*args: positional parameters to pass to the callback, if it is used

**kwargs: named parameters to pass to the callback, if it is used

Returns:
twisted.internet.defer.Deferred: yieldable result
Deferred which resolves to the result
"""
result = self.get(key)
if not result:
Expand Down