Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Type annotations for LruCache #8562

Merged
merged 4 commits into from
Oct 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8562.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type annotations for `LruCache`.
4 changes: 3 additions & 1 deletion synapse/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ def __init__(self, hs):
self.store = hs.get_datastore()
self.state = hs.get_state_handler()

self.token_cache = LruCache(10000, "token_cache")
self.token_cache = LruCache(
10000, "token_cache"
) # type: LruCache[str, Tuple[str, bool]]

self._auth_blocking = AuthBlocking(self.hs)

Expand Down
16 changes: 9 additions & 7 deletions synapse/push/push_rule_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import logging
import re
from typing import Any, Dict, List, Optional, Pattern, Union
from typing import Any, Dict, List, Optional, Pattern, Tuple, Union

from synapse.events import EventBase
from synapse.types import UserID
Expand Down Expand Up @@ -173,19 +173,21 @@ def _contains_display_name(self, display_name: str) -> bool:
# Similar to _glob_matches, but do not treat display_name as a glob.
r = regex_cache.get((display_name, False, True), None)
if not r:
r = re.escape(display_name)
r = _re_word_boundary(r)
r = re.compile(r, flags=re.IGNORECASE)
r1 = re.escape(display_name)
r1 = _re_word_boundary(r1)
r = re.compile(r1, flags=re.IGNORECASE)
regex_cache[(display_name, False, True)] = r

return r.search(body)
return bool(r.search(body))

def _get_value(self, dotted_key: str) -> Optional[str]:
return self._value_cache.get(dotted_key, None)


# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
regex_cache = LruCache(50000, "regex_push_cache")
regex_cache = LruCache(
50000, "regex_push_cache"
) # type: LruCache[Tuple[str, bool, bool], Pattern]


def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
Expand All @@ -203,7 +205,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
if not r:
r = _glob_to_re(glob, word_boundary)
regex_cache[(glob, True, word_boundary)] = r
return r.search(value)
return bool(r.search(value))
except re.error:
logger.warning("Failed to parse glob to regex: %r", glob)
return False
Expand Down
5 changes: 3 additions & 2 deletions synapse/util/caches/deferred_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def metrics_cb():
size_callback=(lambda d: len(d)) if iterable else None,
metrics_collection_callback=metrics_cb,
apply_cache_factor_from_config=apply_cache_factor_from_config,
)
) # type: LruCache[KT, VT]

self.thread = None # type: Optional[threading.Thread]

Expand Down Expand Up @@ -240,11 +240,12 @@ def invalidate_many(self, key: KT):
self.check_thread()
if not isinstance(key, tuple):
raise TypeError("The cache key must be a tuple not %r" % (type(key),))
key = cast(KT, key)
self.cache.del_multi(key)

# if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, as above
entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None)
entry_dict = self._pending_deferred_cache.pop(key, None)
if entry_dict is not None:
for entry in iterate_tree_cache_entry(entry_dict):
entry.invalidate()
Expand Down
22 changes: 13 additions & 9 deletions synapse/util/caches/dictionary_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
import logging
import threading
from collections import namedtuple
from typing import Any

from synapse.util.caches.lrucache import LruCache

Expand All @@ -38,23 +39,26 @@ def __len__(self):
return len(self.value)


class _Sentinel(enum.Enum):
    # A unique "cache miss" marker, distinct from any real cached value.
    # Defining the sentinel as an Enum member (rather than a bare object())
    # lets mypy model the result of a dictionary/cache lookup precisely:
    # after an `entry is not _Sentinel.sentinel` check the lookup result is
    # correctly narrowed to the cached value type.
    sentinel = object()


class DictionaryCache:
"""Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
fetching a subset of dictionary keys for a particular key.
"""

def __init__(self, name, max_entries=1000):
self.cache = LruCache(max_size=max_entries, cache_name=name, size_callback=len)
self.cache = LruCache(
max_size=max_entries, cache_name=name, size_callback=len
) # type: LruCache[Any, DictionaryEntry]

self.name = name
self.sequence = 0
self.thread = None

class Sentinel:
__slots__ = []

self.sentinel = Sentinel()

def check_thread(self):
expected_thread = self.thread
if expected_thread is None:
Expand All @@ -76,8 +80,8 @@ def get(self, key, dict_keys=None):
Returns:
DictionaryEntry
"""
entry = self.cache.get(key, self.sentinel)
if entry is not self.sentinel:
entry = self.cache.get(key, _Sentinel.sentinel)
if entry is not _Sentinel.sentinel:
if dict_keys is None:
return DictionaryEntry(
entry.full, entry.known_absent, dict(entry.value)
Expand Down
78 changes: 66 additions & 12 deletions synapse/util/caches/lrucache.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,35 @@

import threading
from functools import wraps
from typing import Callable, Optional, Type, Union
from typing import (
Any,
Callable,
Generic,
Iterable,
Optional,
Type,
TypeVar,
Union,
cast,
overload,
)

from typing_extensions import Literal

from synapse.config import cache as cache_config
from synapse.util.caches import CacheMetric, register_cache
from synapse.util.caches.treecache import TreeCache

# Function type: the type used for invalidation callbacks
FT = TypeVar("FT", bound=Callable[..., Any])

# Key and Value type for the cache
KT = TypeVar("KT")
VT = TypeVar("VT")

# a general type var, distinct from either KT or VT
T = TypeVar("T")


def enumerate_leaves(node, depth):
if depth == 0:
Expand All @@ -42,7 +65,7 @@ def __init__(self, prev_node, next_node, key, value, callbacks=set()):
self.callbacks = callbacks


class LruCache:
class LruCache(Generic[KT, VT]):
"""
Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.

Expand Down Expand Up @@ -128,13 +151,13 @@ def evict():
if metrics:
metrics.inc_evictions(evicted_len)

def synchronized(f):
def synchronized(f: FT) -> FT:
@wraps(f)
def inner(*args, **kwargs):
with lock:
return f(*args, **kwargs)

return inner
return cast(FT, inner)

cached_cache_len = [0]
if size_callback is not None:
Expand Down Expand Up @@ -188,8 +211,31 @@ def delete_node(node):
node.callbacks.clear()
return deleted_len

@overload
def cache_get(
key: KT,
default: Literal[None] = None,
callbacks: Iterable[Callable[[], None]] = ...,
update_metrics: bool = ...,
Comment on lines +218 to +219
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not realize you could use ... in this situation, that's helpful. 👍

) -> Optional[VT]:
...

@overload
def cache_get(
key: KT,
default: T,
callbacks: Iterable[Callable[[], None]] = ...,
update_metrics: bool = ...,
) -> Union[T, VT]:
...

@synchronized
def cache_get(key, default=None, callbacks=[], update_metrics=True):
def cache_get(
key: KT,
default: Optional[T] = None,
callbacks: Iterable[Callable[[], None]] = [],
update_metrics: bool = True,
):
node = cache.get(key, None)
if node is not None:
move_node_to_front(node)
Expand All @@ -203,7 +249,7 @@ def cache_get(key, default=None, callbacks=[], update_metrics=True):
return default

@synchronized
def cache_set(key, value, callbacks=[]):
def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = []):
node = cache.get(key, None)
if node is not None:
# We sometimes store large objects, e.g. dicts, which cause
Expand Down Expand Up @@ -232,7 +278,7 @@ def cache_set(key, value, callbacks=[]):
evict()

@synchronized
def cache_set_default(key, value):
def cache_set_default(key: KT, value: VT) -> VT:
node = cache.get(key, None)
if node is not None:
return node.value
Expand All @@ -241,8 +287,16 @@ def cache_set_default(key, value):
evict()
return value

@overload
def cache_pop(key: KT, default: Literal[None] = None) -> Optional[VT]:
...

@overload
def cache_pop(key: KT, default: T) -> Union[T, VT]:
...

@synchronized
def cache_pop(key, default=None):
def cache_pop(key: KT, default: Optional[T] = None):
node = cache.get(key, None)
if node:
delete_node(node)
Expand All @@ -252,18 +306,18 @@ def cache_pop(key, default=None):
return default

@synchronized
def cache_del_multi(key):
def cache_del_multi(key: KT) -> None:
"""
This will only work if constructed with cache_type=TreeCache
"""
popped = cache.pop(key)
if popped is None:
return
for leaf in enumerate_leaves(popped, keylen - len(key)):
for leaf in enumerate_leaves(popped, keylen - len(cast(tuple, key))):
delete_node(leaf)

@synchronized
def cache_clear():
def cache_clear() -> None:
list_root.next_node = list_root
list_root.prev_node = list_root
for node in cache.values():
Expand All @@ -274,7 +328,7 @@ def cache_clear():
cached_cache_len[0] = 0

@synchronized
def cache_contains(key):
def cache_contains(key: KT) -> bool:
return key in cache

self.sentinel = object()
Expand Down