diff --git a/redis/_cache.py b/redis/_cache.py
index f8cc7d21bf..7acfdde3e7 100644
--- a/redis/_cache.py
+++ b/redis/_cache.py
@@ -1,6 +1,7 @@
 import copy
 import random
 import time
+from abc import ABC, abstractmethod
 from collections import OrderedDict, defaultdict
 from enum import Enum
 from typing import List
@@ -160,7 +161,38 @@ class EvictionPolicy(Enum):
     RANDOM = "random"


-class _LocalCache:
+class AbstractCache(ABC):
+    """
+    An abstract base class for client caching implementations.
+    If you want to implement your own cache you must support these methods.
+    """
+
+    @abstractmethod
+    def set(self, command: str, response: ResponseT, keys_in_command: List[KeyT]):
+        pass
+
+    @abstractmethod
+    def get(self, command: str) -> ResponseT:
+        pass
+
+    @abstractmethod
+    def delete_command(self, command: str):
+        pass
+
+    @abstractmethod
+    def delete_many(self, commands):
+        pass
+
+    @abstractmethod
+    def flush(self):
+        pass
+
+    @abstractmethod
+    def invalidate_key(self, key: KeyT):
+        pass
+
+
+class _LocalCache(AbstractCache):
     """
     A caching mechanism for storing redis commands and their responses.

@@ -180,7 +212,7 @@ class _LocalCache:

     def __init__(
         self,
-        max_size: int = 100,
+        max_size: int = 10000,
         ttl: int = 0,
         eviction_policy: EvictionPolicy = DEFAULT_EVICTION_POLICY,
         **kwargs,
@@ -224,12 +256,12 @@ def get(self, command: str) -> ResponseT:
         """
         if command in self.cache:
             if self._is_expired(command):
-                self.delete(command)
+                self.delete_command(command)
                 return
             self._update_access(command)
             return copy.deepcopy(self.cache[command]["response"])

-    def delete(self, command: str):
+    def delete_command(self, command: str):
         """
         Delete a redis command and its metadata from the cache.

@@ -285,7 +317,7 @@ def _update_access(self, command: str):
     def _evict(self):
         """Evict a redis command from the cache based on the eviction policy."""
         if self._is_expired(self.commands_ttl_list[0]):
-            self.delete(self.commands_ttl_list[0])
+            self.delete_command(self.commands_ttl_list[0])
         elif self.eviction_policy == EvictionPolicy.LRU.value:
             self.cache.popitem(last=False)
         elif self.eviction_policy == EvictionPolicy.LFU.value:
@@ -319,7 +351,7 @@ def _del_key_commands_map(self, keys: List[KeyT], command: str):
         for key in keys:
             self.key_commands_map[key].remove(command)

-    def invalidate(self, key: KeyT):
+    def invalidate_key(self, key: KeyT):
         """
         Invalidate (delete) all redis commands associated with a specific key.

@@ -330,4 +362,4 @@ def invalidate(self, key: KeyT):
             return
         commands = list(self.key_commands_map[key])
         for command in commands:
-            self.delete(command)
+            self.delete_command(command)
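
The AbstractCache interface above is what a user-supplied cache has to satisfy. For reference, a minimal sketch of a custom implementation; the DictCache name and its dict-backed storage are illustrative assumptions, not part of this change, and the import assumes the patched redis._cache module:

    from collections import defaultdict

    from redis._cache import AbstractCache


    class DictCache(AbstractCache):
        """Toy cache backed by plain dicts; no size limit, TTL, or eviction."""

        def __init__(self):
            self.responses = {}                   # command -> cached response
            self.key_commands = defaultdict(set)  # key -> commands that touch it

        def set(self, command, response, keys_in_command):
            self.responses[command] = response
            for key in keys_in_command:
                self.key_commands[key].add(command)

        def get(self, command):
            return self.responses.get(command)

        def delete_command(self, command):
            self.responses.pop(command, None)
            for commands in self.key_commands.values():
                commands.discard(command)

        def delete_many(self, commands):
            for command in commands:
                self.delete_command(command)

        def flush(self):
            self.responses.clear()
            self.key_commands.clear()

        def invalidate_key(self, key):
            # Drop every cached command that touched this key.
            for command in list(self.key_commands.get(key, ())):
                self.delete_command(command)
            self.key_commands.pop(key, None)
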
diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py
index 143d997757..88de893f5b 100644
--- a/redis/asyncio/client.py
+++ b/redis/asyncio/client.py
@@ -29,7 +29,7 @@
     DEFAULT_BLACKLIST,
     DEFAULT_EVICTION_POLICY,
     DEFAULT_WHITELIST,
-    _LocalCache,
+    AbstractCache,
 )
 from redis._parsers.helpers import (
     _RedisCallbacks,
@@ -238,11 +238,11 @@ def __init__(
         redis_connect_func=None,
         credential_provider: Optional[CredentialProvider] = None,
         protocol: Optional[int] = 2,
-        cache_enable: bool = False,
-        client_cache: Optional[_LocalCache] = None,
+        cache_enabled: bool = False,
+        client_cache: Optional[AbstractCache] = None,
         cache_max_size: int = 100,
         cache_ttl: int = 0,
-        cache_eviction_policy: str = DEFAULT_EVICTION_POLICY,
+        cache_policy: str = DEFAULT_EVICTION_POLICY,
         cache_blacklist: List[str] = DEFAULT_BLACKLIST,
         cache_whitelist: List[str] = DEFAULT_WHITELIST,
     ):
@@ -294,11 +294,11 @@ def __init__(
             "lib_version": lib_version,
             "redis_connect_func": redis_connect_func,
             "protocol": protocol,
-            "cache_enable": cache_enable,
+            "cache_enabled": cache_enabled,
             "client_cache": client_cache,
             "cache_max_size": cache_max_size,
             "cache_ttl": cache_ttl,
-            "cache_eviction_policy": cache_eviction_policy,
+            "cache_policy": cache_policy,
             "cache_blacklist": cache_blacklist,
             "cache_whitelist": cache_whitelist,
         }
@@ -671,6 +671,33 @@ async def parse_response(
             return await retval if inspect.isawaitable(retval) else retval
         return response

+    def flush_cache(self):
+        try:
+            if self.connection:
+                self.connection.client_cache.flush()
+            else:
+                self.connection_pool.flush_cache()
+        except AttributeError:
+            pass
+
+    def delete_command_from_cache(self, command):
+        try:
+            if self.connection:
+                self.connection.client_cache.delete_command(command)
+            else:
+                self.connection_pool.delete_command_from_cache(command)
+        except AttributeError:
+            pass
+
+    def invalidate_key_from_cache(self, key):
+        try:
+            if self.connection:
+                self.connection.client_cache.invalidate_key(key)
+            else:
+                self.connection_pool.invalidate_key_from_cache(key)
+        except AttributeError:
+            pass
+

 StrictRedis = Redis
diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py
index 486053e1cc..337c7bbdcc 100644
--- a/redis/asyncio/cluster.py
+++ b/redis/asyncio/cluster.py
@@ -22,7 +22,7 @@
     DEFAULT_BLACKLIST,
     DEFAULT_EVICTION_POLICY,
     DEFAULT_WHITELIST,
-    _LocalCache,
+    AbstractCache,
 )
 from redis._parsers import AsyncCommandsParser, Encoder
 from redis._parsers.helpers import (
@@ -273,11 +273,11 @@ def __init__(
         ssl_keyfile: Optional[str] = None,
         protocol: Optional[int] = 2,
         address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
-        cache_enable: bool = False,
-        client_cache: Optional[_LocalCache] = None,
+        cache_enabled: bool = False,
+        client_cache: Optional[AbstractCache] = None,
         cache_max_size: int = 100,
         cache_ttl: int = 0,
-        cache_eviction_policy: str = DEFAULT_EVICTION_POLICY,
+        cache_policy: str = DEFAULT_EVICTION_POLICY,
         cache_blacklist: List[str] = DEFAULT_BLACKLIST,
         cache_whitelist: List[str] = DEFAULT_WHITELIST,
     ) -> None:
@@ -324,11 +324,11 @@ def __init__(
             "retry": retry,
             "protocol": protocol,
             # Client cache related kwargs
-            "cache_enable": cache_enable,
+            "cache_enabled": cache_enabled,
             "client_cache": client_cache,
             "cache_max_size": cache_max_size,
             "cache_ttl": cache_ttl,
-            "cache_eviction_policy": cache_eviction_policy,
+            "cache_policy": cache_policy,
             "cache_blacklist": cache_blacklist,
             "cache_whitelist": cache_whitelist,
         }
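
The asyncio client and cluster constructors now take the renamed keyword arguments (cache_enabled and cache_policy instead of cache_enable and cache_eviction_policy). A usage sketch, assuming a local server, a RESP3 connection for the invalidation push messages, and GET being on the default whitelist of cacheable commands:

    import asyncio

    from redis.asyncio import Redis


    async def main():
        r = Redis(protocol=3, cache_enabled=True, cache_ttl=30)
        await r.set("foo", "bar")
        print(await r.get("foo"))  # served by the server, response cached locally
        print(await r.get("foo"))  # may be answered from the local cache
        await r.aclose()


    asyncio.run(main())
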
diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py
index 05a27879a6..77aa21f034 100644
--- a/redis/asyncio/connection.py
+++ b/redis/asyncio/connection.py
@@ -54,6 +54,7 @@
     DEFAULT_BLACKLIST,
     DEFAULT_EVICTION_POLICY,
     DEFAULT_WHITELIST,
+    AbstractCache,
     _LocalCache,
 )
 from .._parsers import (
@@ -157,11 +158,11 @@ def __init__(
         encoder_class: Type[Encoder] = Encoder,
         credential_provider: Optional[CredentialProvider] = None,
         protocol: Optional[int] = 2,
-        cache_enable: bool = False,
-        client_cache: Optional[_LocalCache] = None,
-        cache_max_size: int = 100,
+        cache_enabled: bool = False,
+        client_cache: Optional[AbstractCache] = None,
+        cache_max_size: int = 10000,
         cache_ttl: int = 0,
-        cache_eviction_policy: str = DEFAULT_EVICTION_POLICY,
+        cache_policy: str = DEFAULT_EVICTION_POLICY,
         cache_blacklist: List[str] = DEFAULT_BLACKLIST,
         cache_whitelist: List[str] = DEFAULT_WHITELIST,
     ):
@@ -221,8 +222,8 @@ def __init__(
             if p < 2 or p > 3:
                 raise ConnectionError("protocol must be either 2 or 3")
             self.protocol = protocol
-        if cache_enable:
-            _cache = _LocalCache(cache_max_size, cache_ttl, cache_eviction_policy)
+        if cache_enabled:
+            _cache = _LocalCache(cache_max_size, cache_ttl, cache_policy)
         else:
             _cache = None
         self.client_cache = client_cache if client_cache is not None else _cache
@@ -699,7 +700,7 @@ def _cache_invalidation_process(
             self.client_cache.flush()
         else:
             for key in data[1]:
-                self.client_cache.invalidate(str_if_bytes(key))
+                self.client_cache.invalidate_key(str_if_bytes(key))

     async def _get_from_local_cache(self, command: str):
         """
@@ -729,15 +730,6 @@ def _add_to_local_cache(
     ):
         self.client_cache.set(command, response, keys)

-    def delete_from_local_cache(self, command: str):
-        """
-        Delete the command from the local cache
-        """
-        try:
-            self.client_cache.delete(command)
-        except AttributeError:
-            pass
-

 class Connection(AbstractConnection):
     "Manages TCP communication to and from a Redis server"
@@ -1241,6 +1233,36 @@ def set_retry(self, retry: "Retry") -> None:
         for conn in self._in_use_connections:
             conn.retry = retry

+    def flush_cache(self):
+        connections = chain(self._available_connections, self._in_use_connections)
+
+        for connection in connections:
+            try:
+                connection.client_cache.flush()
+            except AttributeError:
+                # cache is not enabled
+                pass
+
+    def delete_command_from_cache(self, command: str):
+        connections = chain(self._available_connections, self._in_use_connections)
+
+        for connection in connections:
+            try:
+                connection.client_cache.delete_command(command)
+            except AttributeError:
+                # cache is not enabled
+                pass
+
+    def invalidate_key_from_cache(self, key: str):
+        connections = chain(self._available_connections, self._in_use_connections)
+
+        for connection in connections:
+            try:
+                connection.client_cache.invalidate_key(key)
+            except AttributeError:
+                # cache is not enabled
+                pass
+

 class BlockingConnectionPool(ConnectionPool):
     """
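
On the connection, client_cache is now typed as Optional[AbstractCache], and the built-in _LocalCache is only constructed when no cache instance is supplied, so a custom implementation can be injected directly. An illustrative sketch, reusing the hypothetical DictCache from above:

    from redis.asyncio import Redis

    # DictCache is the illustrative AbstractCache subclass sketched earlier.
    r = Redis(protocol=3, client_cache=DictCache())
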
diff --git a/redis/client.py b/redis/client.py
index d685145339..2d4c512699 100755
--- a/redis/client.py
+++ b/redis/client.py
@@ -10,7 +10,7 @@
     DEFAULT_BLACKLIST,
     DEFAULT_EVICTION_POLICY,
     DEFAULT_WHITELIST,
-    _LocalCache,
+    AbstractCache,
 )
 from redis._parsers.encoders import Encoder
 from redis._parsers.helpers import (
@@ -209,11 +209,11 @@ def __init__(
         redis_connect_func=None,
         credential_provider: Optional[CredentialProvider] = None,
         protocol: Optional[int] = 2,
-        cache_enable: bool = False,
-        client_cache: Optional[_LocalCache] = None,
-        cache_max_size: int = 100,
+        cache_enabled: bool = False,
+        client_cache: Optional[AbstractCache] = None,
+        cache_max_size: int = 10000,
         cache_ttl: int = 0,
-        cache_eviction_policy: str = DEFAULT_EVICTION_POLICY,
+        cache_policy: str = DEFAULT_EVICTION_POLICY,
         cache_blacklist: List[str] = DEFAULT_BLACKLIST,
         cache_whitelist: List[str] = DEFAULT_WHITELIST,
     ) -> None:
@@ -267,11 +267,11 @@ def __init__(
             "redis_connect_func": redis_connect_func,
             "credential_provider": credential_provider,
             "protocol": protocol,
-            "cache_enable": cache_enable,
+            "cache_enabled": cache_enabled,
             "client_cache": client_cache,
             "cache_max_size": cache_max_size,
             "cache_ttl": cache_ttl,
-            "cache_eviction_policy": cache_eviction_policy,
+            "cache_policy": cache_policy,
             "cache_blacklist": cache_blacklist,
             "cache_whitelist": cache_whitelist,
         }
@@ -592,6 +592,33 @@ def parse_response(self, connection, command_name, **options):
             return self.response_callbacks[command_name](response, **options)
         return response

+    def flush_cache(self):
+        try:
+            if self.connection:
+                self.connection.client_cache.flush()
+            else:
+                self.connection_pool.flush_cache()
+        except AttributeError:
+            pass
+
+    def delete_command_from_cache(self, command):
+        try:
+            if self.connection:
+                self.connection.client_cache.delete_command(command)
+            else:
+                self.connection_pool.delete_command_from_cache(command)
+        except AttributeError:
+            pass
+
+    def invalidate_key_from_cache(self, key):
+        try:
+            if self.connection:
+                self.connection.client_cache.invalidate_key(key)
+            else:
+                self.connection_pool.invalidate_key_from_cache(key)
+        except AttributeError:
+            pass
+

 StrictRedis = Redis
diff --git a/redis/cluster.py b/redis/cluster.py
index e558be1689..c36665eb5c 100644
--- a/redis/cluster.py
+++ b/redis/cluster.py
@@ -167,11 +167,11 @@ def parse_cluster_myshardid(resp, **options):
     "ssl_password",
     "unix_socket_path",
     "username",
-    "cache_enable",
+    "cache_enabled",
     "client_cache",
     "cache_max_size",
     "cache_ttl",
-    "cache_eviction_policy",
+    "cache_policy",
     "cache_blacklist",
     "cache_whitelist",
 )
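
The synchronous client exposes the same cache-management helpers, delegating either to the single connection's cache or to the pool. A hedged usage sketch, again assuming a local server, RESP3, and GET being on the default whitelist:

    from redis import Redis

    r = Redis(protocol=3, cache_enabled=True)
    r.set("user:1", "alice")
    r.get("user:1")                        # response may now be held in the local cache
    r.invalidate_key_from_cache("user:1")  # drop every cached command touching this key
    r.flush_cache()                        # or clear the whole client-side cache
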
diff --git a/redis/connection.py b/redis/connection.py
index a09fb3949c..1f46267146 100644
--- a/redis/connection.py
+++ b/redis/connection.py
@@ -17,6 +17,7 @@
     DEFAULT_BLACKLIST,
     DEFAULT_EVICTION_POLICY,
     DEFAULT_WHITELIST,
+    AbstractCache,
     _LocalCache,
 )
 from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser
@@ -157,11 +158,11 @@ def __init__(
         credential_provider: Optional[CredentialProvider] = None,
         protocol: Optional[int] = 2,
         command_packer: Optional[Callable[[], None]] = None,
-        cache_enable: bool = False,
-        client_cache: Optional[_LocalCache] = None,
-        cache_max_size: int = 100,
+        cache_enabled: bool = False,
+        client_cache: Optional[AbstractCache] = None,
+        cache_max_size: int = 10000,
         cache_ttl: int = 0,
-        cache_eviction_policy: str = DEFAULT_EVICTION_POLICY,
+        cache_policy: str = DEFAULT_EVICTION_POLICY,
         cache_blacklist: List[str] = DEFAULT_BLACKLIST,
         cache_whitelist: List[str] = DEFAULT_WHITELIST,
     ):
@@ -229,8 +230,8 @@ def __init__(
             # p = DEFAULT_RESP_VERSION
             self.protocol = p
         self._command_packer = self._construct_command_packer(command_packer)
-        if cache_enable:
-            _cache = _LocalCache(cache_max_size, cache_ttl, cache_eviction_policy)
+        if cache_enabled:
+            _cache = _LocalCache(cache_max_size, cache_ttl, cache_policy)
         else:
             _cache = None
         self.client_cache = client_cache if client_cache is not None else _cache
@@ -626,7 +627,7 @@ def _cache_invalidation_process(
             self.client_cache.flush()
         else:
             for key in data[1]:
-                self.client_cache.invalidate(str_if_bytes(key))
+                self.client_cache.invalidate_key(str_if_bytes(key))

     def _get_from_local_cache(self, command: str):
         """
@@ -656,15 +657,6 @@ def _add_to_local_cache(
     ):
         self.client_cache.set(command, response, keys)

-    def delete_from_local_cache(self, command: str):
-        """
-        Delete the command from the local cache
-        """
-        try:
-            self.client_cache.delete(command)
-        except AttributeError:
-            pass
-

 class Connection(AbstractConnection):
     "Manages TCP communication to and from a Redis server"
@@ -1281,6 +1273,42 @@ def set_retry(self, retry: "Retry") -> None:
         for conn in self._in_use_connections:
             conn.retry = retry

+    def flush_cache(self):
+        self._checkpid()
+        with self._lock:
+            connections = chain(self._available_connections, self._in_use_connections)
+
+            for connection in connections:
+                try:
+                    connection.client_cache.flush()
+                except AttributeError:
+                    # cache is not enabled
+                    pass
+
+    def delete_command_from_cache(self, command: str):
+        self._checkpid()
+        with self._lock:
+            connections = chain(self._available_connections, self._in_use_connections)
+
+            for connection in connections:
+                try:
+                    connection.client_cache.delete_command(command)
+                except AttributeError:
+                    # cache is not enabled
+                    pass
+
+    def invalidate_key_from_cache(self, key: str):
+        self._checkpid()
+        with self._lock:
+            connections = chain(self._available_connections, self._in_use_connections)
+
+            for connection in connections:
+                try:
+                    connection.client_cache.invalidate_key(key)
+                except AttributeError:
+                    # cache is not enabled
+                    pass
+

 class BlockingConnectionPool(ConnectionPool):
     """
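
Because every pooled connection keeps its own cache, the pool-level variants walk both the available and the in-use connections (the synchronous pool additionally re-checks the PID and holds its lock). An illustrative sketch of clearing a shared pool, under the same RESP3 and whitelist assumptions as above:

    from redis import ConnectionPool, Redis

    pool = ConnectionPool(protocol=3, cache_enabled=True)
    client = Redis(connection_pool=pool)
    client.set("user:1", "alice")
    client.get("user:1")                      # cached on the connection that served it
    pool.invalidate_key_from_cache("user:1")  # every pooled connection drops the key
    pool.flush_cache()                        # or clear all cached responses at once
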