Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Fix some error cases in the caching layer. #5749

Merged
merged 1 commit into from
Jul 25, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/5749.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix some error cases in the caching layer.
74 changes: 42 additions & 32 deletions synapse/util/caches/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@
import threading
from collections import namedtuple

import six
from six import itervalues, string_types
from six import itervalues

from twisted.internet import defer

Expand All @@ -30,7 +29,6 @@
from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
from synapse.util.stringutils import to_ascii

from . import register_cache

Expand Down Expand Up @@ -108,7 +106,7 @@ def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True):
update_metrics (bool): whether to update the cache hit rate metrics

Returns:
Either a Deferred or the raw result
Either an ObservableDeferred or the raw result
"""
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _CacheSentinel)
Expand All @@ -132,40 +130,63 @@ def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True):
return default

def set(self, key, value, callback=None):
if not isinstance(value, defer.Deferred):
raise TypeError("not a Deferred")

callbacks = [callback] if callback else []
self.check_thread()
entry = CacheEntry(deferred=value, callbacks=callbacks)
observable = ObservableDeferred(value, consumeErrors=True)
observer = defer.maybeDeferred(observable.observe)
entry = CacheEntry(deferred=observable, callbacks=callbacks)

existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
existing_entry.invalidate()

self._pending_deferred_cache[key] = entry

def shuffle(result):
def compare_and_pop():
"""Check if our entry is still the one in _pending_deferred_cache, and
if so, pop it.

Returns true if the entries matched.
"""
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry is entry:
return True

# oops, the _pending_deferred_cache has been updated since
# we started our query, so we are out of date.
#
# Better put back whatever we took out. (We do it this way
# round, rather than peeking into the _pending_deferred_cache
# and then removing on a match, to make the common case faster)
if existing_entry is not None:
self._pending_deferred_cache[key] = existing_entry

return False

def cb(result):
if compare_and_pop():
self.cache.set(key, result, entry.callbacks)
else:
# oops, the _pending_deferred_cache has been updated since
# we started our query, so we are out of date.
#
# Better put back whatever we took out. (We do it this way
# round, rather than peeking into the _pending_deferred_cache
# and then removing on a match, to make the common case faster)
if existing_entry is not None:
self._pending_deferred_cache[key] = existing_entry

# we're not going to put this entry into the cache, so need
# to make sure that the invalidation callbacks are called.
# That was probably done when _pending_deferred_cache was
# updated, but it's possible that `set` was called without
# `invalidate` being previously called, in which case it may
# not have been. Either way, let's double-check now.
entry.invalidate()
return result

entry.deferred.addCallback(shuffle)
def eb(_fail):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did we do this before? I'm a bit confused where this is coming from?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, we didn't, and that was part of the problem. We would end up with a failed lookup stuck in _pending_deferred_cache, though that was mitigated by the call to cache.invalidate from the onErr callback in wrapped, which would clear it, but only if the exception was thrown asynchronously.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah right, ok, thanks

compare_and_pop()
entry.invalidate()

# once the deferred completes, we can move the entry from the
# _pending_deferred_cache to the real cache.
#
observer.addCallbacks(cb, eb)
return observable

def prefill(self, key, value, callback=None):
callbacks = [callback] if callback else []
Expand Down Expand Up @@ -398,20 +419,10 @@ def onErr(f):

ret.addErrback(onErr)

# If our cache_key is a string on py2, try to convert to ascii
# to save a bit of space in large caches. Py3 does this
# internally automatically.
if six.PY2 and isinstance(cache_key, string_types):
cache_key = to_ascii(cache_key)

result_d = ObservableDeferred(ret, consumeErrors=True)
cache.set(cache_key, result_d, callback=invalidate_callback)
result_d = cache.set(cache_key, ret, callback=invalidate_callback)
observer = result_d.observe()

if isinstance(observer, defer.Deferred):
return make_deferred_yieldable(observer)
else:
return observer
return make_deferred_yieldable(observer)

if self.num_args == 1:
wrapped.invalidate = lambda key: cache.invalidate(key[0])
Expand Down Expand Up @@ -527,16 +538,15 @@ def arg_to_cache_key(arg):
missing.add(arg)

if missing:
# we need an observable deferred for each entry in the list,
# we need a deferred for each entry in the list,
# which we put in the cache. Each deferred resolves with the
# relevant result for that key.
deferreds_map = {}
for arg in missing:
deferred = defer.Deferred()
deferreds_map[arg] = deferred
key = arg_to_cache_key(arg)
observable = ObservableDeferred(deferred)
cache.set(key, observable, callback=invalidate_callback)
cache.set(key, deferred, callback=invalidate_callback)

def complete_all(res):
# the wrapped function has completed. It returns a
Expand Down
90 changes: 87 additions & 3 deletions tests/util/caches/test_descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
make_deferred_yieldable,
)
from synapse.util.caches import descriptors
from synapse.util.caches.descriptors import cached

from tests import unittest

Expand Down Expand Up @@ -55,12 +56,15 @@ def record_callback(idx):
d2 = defer.Deferred()
cache.set("key2", d2, partial(record_callback, 1))

# lookup should return the deferreds
self.assertIs(cache.get("key1"), d1)
self.assertIs(cache.get("key2"), d2)
# lookup should return observable deferreds
self.assertFalse(cache.get("key1").has_called())
self.assertFalse(cache.get("key2").has_called())

# let one of the lookups complete
d2.callback("result2")

# for now at least, the cache will return real results rather than an
# ObservableDeferred
self.assertEqual(cache.get("key2"), "result2")

# now do the invalidation
Expand Down Expand Up @@ -146,6 +150,28 @@ def fn(self, arg1, arg2):
self.assertEqual(r, "chips")
obj.mock.assert_not_called()

def test_cache_with_sync_exception(self):
    """A synchronous raise inside the wrapped function must not poison the cache."""

    class Thrower(object):
        @cached()
        def fn(self, arg1):
            raise SynapseError(100, "mai spoon iz too big!!1")

    instance = Thrower()

    # the failure should surface straight away as a failed deferred
    self.failureResultOf(instance.fn(1), SynapseError)

    # nothing should have been left behind in the cache
    self.assertEqual(len(instance.fn.cache.cache), 0)

    # a repeat call must re-run the function and fail again, rather than
    # returning a cached failure
    self.failureResultOf(instance.fn(1), SynapseError)

def test_cache_logcontexts(self):
"""Check that logcontexts are set and restored correctly when
using the cache."""
Expand Down Expand Up @@ -222,6 +248,9 @@ def do_lookup():

self.assertEqual(LoggingContext.current_context(), c1)

# the cache should now be empty
self.assertEqual(len(obj.fn.cache.cache), 0)

obj = Cls()

# set off a deferred which will do a cache lookup
Expand Down Expand Up @@ -268,6 +297,61 @@ def fn(self, arg1, arg2=2, arg3=3):
self.assertEqual(r, "chips")
obj.mock.assert_not_called()

def test_cache_iterable(self):
    """Basic hit/miss behaviour of a cache declared with iterable=True."""

    class Fixture(object):
        def __init__(self):
            self.mock = mock.Mock()

        @descriptors.cached(iterable=True)
        def fn(self, arg1, arg2):
            return self.mock(arg1, arg2)

    target = Fixture()

    # first lookup: a miss, so the underlying function is invoked once
    target.mock.return_value = ["spam", "eggs"]
    result = target.fn(1, 2)
    self.assertEqual(result, ["spam", "eggs"])
    target.mock.assert_called_once_with(1, 2)
    target.mock.reset_mock()

    # different arguments: another miss, another invocation
    target.mock.return_value = ["chips"]
    result = target.fn(1, 3)
    self.assertEqual(result, ["chips"])
    target.mock.assert_called_once_with(1, 3)
    target.mock.reset_mock()

    # both results should now be cached. NOTE(review): len() here appears to
    # count individual elements for iterable caches (2 + 1 = 3) rather than
    # the number of entries — confirm against the cache implementation.
    self.assertEqual(len(target.fn.cache.cache), 3)

    # repeat lookups: hits, so the mock must not be touched again
    result = target.fn(1, 2)
    self.assertEqual(result, ["spam", "eggs"])
    result = target.fn(1, 3)
    self.assertEqual(result, ["chips"])
    target.mock.assert_not_called()

def test_cache_iterable_with_sync_exception(self):
    """A synchronous raise must not poison an iterable cache either."""

    class Thrower(object):
        @descriptors.cached(iterable=True)
        def fn(self, arg1):
            raise SynapseError(100, "mai spoon iz too big!!1")

    instance = Thrower()

    # the failure should surface straight away as a failed deferred
    self.failureResultOf(instance.fn(1), SynapseError)

    # nothing should have been left behind in the cache
    self.assertEqual(len(instance.fn.cache.cache), 0)

    # a repeat call must re-run the function and fail again, rather than
    # returning a cached failure
    self.failureResultOf(instance.fn(1), SynapseError)


class CachedListDescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
Expand Down