Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Client side caching invalidations (standalone) #3089

Merged
merged 13 commits into from
Jan 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions redis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from redis import asyncio # noqa
from redis.backoff import default_backoff
from redis.cache import _LocalChace
from redis.client import Redis, StrictRedis
from redis.cluster import RedisCluster
from redis.connection import (
Expand Down Expand Up @@ -62,7 +61,6 @@ def int_or_str(value):
VERSION = tuple([99, 99, 99])

__all__ = [
"_LocalChace",
"AuthenticationError",
"AuthenticationWrongNumberOfArgsError",
"BlockingConnectionPool",
Expand Down
66 changes: 44 additions & 22 deletions redis/_parsers/resp3.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,18 @@
from .base import _AsyncRESPBase, _RESPBase
from .socket import SERVER_CLOSED_CONNECTION_ERROR

_INVALIDATION_MESSAGE = [b"invalidate", "invalidate"]


class _RESP3Parser(_RESPBase):
"""RESP3 protocol implementation"""

def __init__(self, socket_read_size):
super().__init__(socket_read_size)
self.push_handler_func = self.handle_push_response
self.pubsub_push_handler_func = self.handle_pubsub_push_response
self.invalidations_push_handler_func = None

def handle_push_response(self, response):
def handle_pubsub_push_response(self, response):
logger = getLogger("push_response")
logger.info("Push response: " + str(response))
return response
Expand Down Expand Up @@ -114,30 +117,40 @@ def _read_response(self, disable_decoding=False, push_request=False):
)
for _ in range(int(response))
]
res = self.push_handler_func(response)
if not push_request:
return self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
return res
self.handle_push_response(response, disable_decoding, push_request)
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")

if isinstance(response, bytes) and disable_decoding is False:
response = self.encoder.decode(response)
return response

def set_push_handler(self, push_handler_func):
self.push_handler_func = push_handler_func
def handle_push_response(self, response, disable_decoding, push_request):
if response[0] in _INVALIDATION_MESSAGE:
res = self.invalidation_push_handler_func(response)
else:
res = self.pubsub_push_handler_func(response)
if not push_request:
return self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
return res
Comment on lines +128 to +138
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it be better to test response[0] for pubsub as well (think its "message"?) and error out if its another type not supported.

if willing to force python 3.10, can use a switch/case/default stateent?

don't understand what pust_request is yet.


def set_pubsub_push_handler(self, pubsub_push_handler_func):
chayim marked this conversation as resolved.
Show resolved Hide resolved
self.pubsub_push_handler_func = pubsub_push_handler_func

def set_invalidation_push_handler(self, invalidations_push_handler_func):
chayim marked this conversation as resolved.
Show resolved Hide resolved
self.invalidation_push_handler_func = invalidations_push_handler_func


class _AsyncRESP3Parser(_AsyncRESPBase):
def __init__(self, socket_read_size):
super().__init__(socket_read_size)
self.push_handler_func = self.handle_push_response
self.pubsub_push_handler_func = self.handle_pubsub_push_response
self.invalidations_push_handler_func = None

def handle_push_response(self, response):
def handle_pubsub_push_response(self, response):
logger = getLogger("push_response")
logger.info("Push response: " + str(response))
return response
Expand Down Expand Up @@ -246,19 +259,28 @@ async def _read_response(
)
for _ in range(int(response))
]
res = self.push_handler_func(response)
if not push_request:
return await self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
return res
await self.handle_push_response(response, disable_decoding, push_request)
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")

if isinstance(response, bytes) and disable_decoding is False:
response = self.encoder.decode(response)
return response

def set_push_handler(self, push_handler_func):
self.push_handler_func = push_handler_func
async def handle_push_response(self, response, disable_decoding, push_request):
if response[0] in _INVALIDATION_MESSAGE:
res = self.invalidation_push_handler_func(response)
else:
res = self.pubsub_push_handler_func(response)
if not push_request:
return await self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
return res
Comment on lines +270 to +280
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

confused why this logic seems to be duplicated from above?


def set_pubsub_push_handler(self, pubsub_push_handler_func):
chayim marked this conversation as resolved.
Show resolved Hide resolved
self.pubsub_push_handler_func = pubsub_push_handler_func

def set_invalidation_push_handler(self, invalidations_push_handler_func):
chayim marked this conversation as resolved.
Show resolved Hide resolved
self.invalidation_push_handler_func = invalidations_push_handler_func
131 changes: 113 additions & 18 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@
)
from redis.asyncio.lock import Lock
from redis.asyncio.retry import Retry
from redis.cache import (
DEFAULT_BLACKLIST,
DEFAULT_EVICTION_POLICY,
DEFAULT_WHITELIST,
_LocalCache,
)
from redis.client import (
EMPTY_RESPONSE,
NEVER_DECODE,
Expand All @@ -60,7 +66,7 @@
TimeoutError,
WatchError,
)
from redis.typing import ChannelT, EncodableT, KeyT
from redis.typing import ChannelT, EncodableT, KeysT, KeyT, ResponseT
from redis.utils import (
HIREDIS_AVAILABLE,
_set_info_logger,
Expand Down Expand Up @@ -231,6 +237,13 @@ def __init__(
redis_connect_func=None,
credential_provider: Optional[CredentialProvider] = None,
protocol: Optional[int] = 2,
cache_enable: bool = False,
client_cache: Optional[_LocalCache] = None,
cache_max_size: int = 100,
cache_ttl: int = 0,
cache_eviction_policy: str = DEFAULT_EVICTION_POLICY,
cache_blacklist: List[str] = DEFAULT_BLACKLIST,
cache_whitelist: List[str] = DEFAULT_WHITELIST,
):
"""
Initialize a new Redis client.
Expand Down Expand Up @@ -336,6 +349,16 @@ def __init__(
# on a set of redis commands
self._single_conn_lock = asyncio.Lock()

self.client_cache = client_cache
if cache_enable:
self.client_cache = _LocalCache(
cache_max_size, cache_ttl, cache_eviction_policy
)
if self.client_cache is not None:
self.cache_blacklist = cache_blacklist
self.cache_whitelist = cache_whitelist
self.client_cache_initialized = False

def __repr__(self):
return (
f"<{self.__class__.__module__}.{self.__class__.__name__}"
Expand All @@ -350,6 +373,10 @@ async def initialize(self: _RedisT) -> _RedisT:
async with self._single_conn_lock:
if self.connection is None:
self.connection = await self.connection_pool.get_connection("_")
if self.client_cache is not None:
self.connection._parser.set_invalidation_push_handler(
self._cache_invalidation_process
)
return self

def set_response_callback(self, command: str, callback: ResponseCallbackT):
Expand Down Expand Up @@ -568,6 +595,8 @@ async def aclose(self, close_connection_pool: Optional[bool] = None) -> None:
close_connection_pool is None and self.auto_close_connection_pool
):
await self.connection_pool.disconnect()
if self.client_cache:
self.client_cache.flush()

@deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
async def close(self, close_connection_pool: Optional[bool] = None) -> None:
Expand Down Expand Up @@ -596,29 +625,95 @@ async def _disconnect_raise(self, conn: Connection, error: Exception):
):
raise error

def _cache_invalidation_process(
self, data: List[Union[str, Optional[List[str]]]]
) -> None:
"""
Invalidate (delete) all redis commands associated with a specific key.
`data` is a list of strings, where the first string is the invalidation message
and the second string is the list of keys to invalidate.
(if the list of keys is None, then all keys are invalidated)
"""
if data[1] is not None:
for key in data[1]:
self.client_cache.invalidate(str_if_bytes(key))
else:
self.client_cache.flush()

async def _get_from_local_cache(self, command: str):
"""
If the command is in the local cache, return the response
"""
if (
self.client_cache is None
or command[0] in self.cache_blacklist
or command[0] not in self.cache_whitelist
):
return None
while not self.connection._is_socket_empty():
await self.connection.read_response(push_request=True)
return self.client_cache.get(command)

def _add_to_local_cache(
self, command: Tuple[str], response: ResponseT, keys: List[KeysT]
):
"""
Add the command and response to the local cache if the command
is allowed to be cached
"""
if (
self.client_cache is not None
and (self.cache_blacklist == [] or command[0] not in self.cache_blacklist)
and (self.cache_whitelist == [] or command[0] in self.cache_whitelist)
):
self.client_cache.set(command, response, keys)

def delete_from_local_cache(self, command: str):
"""
Delete the command from the local cache
"""
try:
self.client_cache.delete(command)
except AttributeError:
pass

# COMMAND EXECUTION AND PROTOCOL PARSING
async def execute_command(self, *args, **options):
"""Execute a command and return a parsed response"""
await self.initialize()
options.pop("keys", None) # the keys are used only for client side caching
pool = self.connection_pool
command_name = args[0]
conn = self.connection or await pool.get_connection(command_name, **options)
keys = options.pop("keys", None) # keys are used only for client side caching
response_from_cache = await self._get_from_local_cache(args)
if response_from_cache is not None:
return response_from_cache
else:
pool = self.connection_pool
conn = self.connection or await pool.get_connection(command_name, **options)

if self.single_connection_client:
await self._single_conn_lock.acquire()
try:
return await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda error: self._disconnect_raise(conn, error),
)
finally:
if self.single_connection_client:
self._single_conn_lock.release()
if not self.connection:
await pool.release(conn)
await self._single_conn_lock.acquire()
try:
if self.client_cache is not None and not self.client_cache_initialized:
await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, "CLIENT", *("CLIENT", "TRACKING", "ON")
),
lambda error: self._disconnect_raise(conn, error),
)
self.client_cache_initialized = True
response = await conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
),
lambda error: self._disconnect_raise(conn, error),
)
self._add_to_local_cache(args, response, keys)
return response
finally:
if self.single_connection_client:
self._single_conn_lock.release()
if not self.connection:
await pool.release(conn)

async def parse_response(
self, connection: Connection, command_name: Union[str, bytes], **options
Expand Down Expand Up @@ -866,7 +961,7 @@ async def connect(self):
else:
await self.connection.connect()
if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
self.connection._parser.set_push_handler(self.push_handler_func)
self.connection._parser.set_pubsub_push_handler(self.push_handler_func)

async def _disconnect_raise_connect(self, conn, error):
"""
Expand Down
4 changes: 4 additions & 0 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,10 @@ def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]
output.append(SYM_EMPTY.join(pieces))
return output

def _is_socket_empty(self):
"""Check if the socket is empty"""
return not self._reader.at_eof()


class Connection(AbstractConnection):
"Manages TCP communication to and from a Redis server"
Expand Down
18 changes: 10 additions & 8 deletions redis/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ class EvictionPolicy(Enum):
RANDOM = "random"


class _LocalChace:
class _LocalCache:
"""
A caching mechanism for storing redis commands and their responses.

Expand Down Expand Up @@ -220,6 +220,7 @@ def get(self, command: str) -> ResponseT:
if command in self.cache:
if self._is_expired(command):
self.delete(command)
return
self._update_access(command)
return self.cache[command]["response"]

Expand Down Expand Up @@ -266,28 +267,28 @@ def _update_access(self, command: str):
Args:
command (str): The redis command.
"""
if self.eviction_policy == EvictionPolicy.LRU:
if self.eviction_policy == EvictionPolicy.LRU.value:
dvora-h marked this conversation as resolved.
Show resolved Hide resolved
self.cache.move_to_end(command)
elif self.eviction_policy == EvictionPolicy.LFU:
elif self.eviction_policy == EvictionPolicy.LFU.value:
self.cache[command]["access_count"] = (
self.cache.get(command, {}).get("access_count", 0) + 1
)
self.cache.move_to_end(command)
elif self.eviction_policy == EvictionPolicy.RANDOM:
elif self.eviction_policy == EvictionPolicy.RANDOM.value:
pass # Random eviction doesn't require updates

def _evict(self):
"""Evict a redis command from the cache based on the eviction policy."""
if self._is_expired(self.commands_ttl_list[0]):
self.delete(self.commands_ttl_list[0])
elif self.eviction_policy == EvictionPolicy.LRU:
elif self.eviction_policy == EvictionPolicy.LRU.value:
self.cache.popitem(last=False)
elif self.eviction_policy == EvictionPolicy.LFU:
elif self.eviction_policy == EvictionPolicy.LFU.value:
min_access_command = min(
self.cache, key=lambda k: self.cache[k].get("access_count", 0)
)
self.cache.pop(min_access_command)
elif self.eviction_policy == EvictionPolicy.RANDOM:
elif self.eviction_policy == EvictionPolicy.RANDOM.value:
random_command = random.choice(list(self.cache.keys()))
self.cache.pop(random_command)

Expand Down Expand Up @@ -322,5 +323,6 @@ def invalidate(self, key: KeyT):
"""
if key not in self.key_commands_map:
return
for command in self.key_commands_map[key]:
commands = list(self.key_commands_map[key])
for command in commands:
self.delete(command)
Loading
Loading