From 1a8870fcbebdaf6c2e780618bab415f6680eb5db Mon Sep 17 00:00:00 2001 From: Vasyl Dizhak Date: Sun, 9 Jun 2024 20:03:38 +0100 Subject: [PATCH] Add support for the set functions from issue #597 Co-authored-by: Ali Rezaei --- changelog.d/730.feature | 1 + django_redis/cache.py | 68 ++++++++ django_redis/client/default.py | 282 +++++++++++++++++++++++++++++++- django_redis/client/sharded.py | 152 ++++++++++++++++- django_redis/compressors/lz4.py | 2 +- tests/test_backend.py | 157 ++++++++++++++++++ 6 files changed, 657 insertions(+), 5 deletions(-) create mode 100644 changelog.d/730.feature diff --git a/changelog.d/730.feature b/changelog.d/730.feature new file mode 100644 index 00000000..d41ae639 --- /dev/null +++ b/changelog.d/730.feature @@ -0,0 +1 @@ +Support for sets and support basic operations, sadd, scard, sdiff, sdiffstore, sinter, sinterstore, smismember, sismember, smembers, smove, spop, srandmember, srem, sscan, sscan_iter, sunion, sunionstore \ No newline at end of file diff --git a/django_redis/cache.py b/django_redis/cache.py index d26c33fa..f7b943a3 100644 --- a/django_redis/cache.py +++ b/django_redis/cache.py @@ -185,6 +185,74 @@ def close(self, **kwargs): def touch(self, *args, **kwargs): return self.client.touch(*args, **kwargs) + @omit_exception + def sadd(self, *args, **kwargs): + return self.client.sadd(*args, **kwargs) + + @omit_exception + def scard(self, *args, **kwargs): + return self.client.scard(*args, **kwargs) + + @omit_exception + def sdiff(self, *args, **kwargs): + return self.client.sdiff(*args, **kwargs) + + @omit_exception + def sdiffstore(self, *args, **kwargs): + return self.client.sdiffstore(*args, **kwargs) + + @omit_exception + def sinter(self, *args, **kwargs): + return self.client.sinter(*args, **kwargs) + + @omit_exception + def sinterstore(self, *args, **kwargs): + return self.client.sinterstore(*args, **kwargs) + + @omit_exception + def sismember(self, *args, **kwargs): + return self.client.sismember(*args, **kwargs) + + @omit_exception + def smembers(self, *args, **kwargs): + return self.client.smembers(*args, **kwargs) + + @omit_exception + def smove(self, *args, **kwargs): + return self.client.smove(*args, **kwargs) + + @omit_exception + def spop(self, *args, **kwargs): + return self.client.spop(*args, **kwargs) + + @omit_exception + def srandmember(self, *args, **kwargs): + return self.client.srandmember(*args, **kwargs) + + @omit_exception + def srem(self, *args, **kwargs): + return self.client.srem(*args, **kwargs) + + @omit_exception + def sscan(self, *args, **kwargs): + return self.client.sscan(*args, **kwargs) + + @omit_exception + def sscan_iter(self, *args, **kwargs): + return self.client.sscan_iter(*args, **kwargs) + + @omit_exception + def smismember(self, *args, **kwargs): + return self.client.smismember(*args, **kwargs) + + @omit_exception + def sunion(self, *args, **kwargs): + return self.client.sunion(*args, **kwargs) + + @omit_exception + def sunionstore(self, *args, **kwargs): + return self.client.sunionstore(*args, **kwargs) + @omit_exception def hset(self, *args, **kwargs): return self.client.hset(*args, **kwargs) diff --git a/django_redis/client/default.py b/django_redis/client/default.py index b9a5c1b0..6f51a4cd 100644 --- a/django_redis/client/default.py +++ b/django_redis/client/default.py @@ -3,7 +3,18 @@ import socket from collections import OrderedDict from contextlib import suppress -from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union +from typing import ( + Any, + Dict, + Iterable, + Iterator, + List, + Optional, + Set, + Tuple, + Union, + cast, +) from django.conf import settings from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache, get_key_func @@ -11,7 +22,7 @@ from django.utils.module_loading import import_string from redis import Redis from redis.exceptions import ConnectionError, ResponseError, TimeoutError -from redis.typing import AbsExpiryT, EncodableT, ExpiryT, KeyT +from redis.typing import AbsExpiryT, EncodableT, ExpiryT, KeyT, PatternT from django_redis import pool from django_redis.exceptions import CompressorError, ConnectionInterrupted @@ -66,6 +77,14 @@ def __init__(self, server, params: Dict[str, Any], backend: BaseCache) -> None: def __contains__(self, key: KeyT) -> bool: return self.has_key(key) + def _has_compression_enabled(self) -> bool: + return ( + self._options.get( + "COMPRESSOR", "django_redis.compressors.identity.IdentityCompressor" + ) + != "django_redis.compressors.identity.IdentityCompressor" + ) + def get_next_client_index( self, write: bool = True, tried: Optional[List[int]] = None ) -> int: @@ -778,6 +797,265 @@ def make_pattern( return CacheKey(self._backend.key_func(pattern, prefix, version_str)) + def sadd( + self, + key: KeyT, + *values: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + client = self.get_client(write=True) + + key = self.make_key(key, version=version) + encoded_values = [self.encode(value) for value in values] + return int(client.sadd(key, *encoded_values)) + + def scard( + self, + key: KeyT, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + client = self.get_client(write=False) + + key = self.make_key(key, version=version) + return int(client.scard(key)) + + def sdiff( + self, + *keys: KeyT, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Set: + if client is None: + client = self.get_client(write=False) + + nkeys = [self.make_key(key, version=version) for key in keys] + return {self.decode(value) for value in client.sdiff(*nkeys)} + + def sdiffstore( + self, + dest: KeyT, + *keys: KeyT, + version_dest: Optional[int] = None, + version_keys: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + client = self.get_client(write=True) + + dest = self.make_key(dest, version=version_dest) + nkeys = [self.make_key(key, version=version_keys) for key in keys] + return int(client.sdiffstore(dest, *nkeys)) + + def sinter( + self, + *keys: KeyT, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Set: + if client is None: + client = self.get_client(write=False) + + nkeys = [self.make_key(key, version=version) for key in keys] + return {self.decode(value) for value in client.sinter(*nkeys)} + + def sinterstore( + self, + dest: KeyT, + *keys: KeyT, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + client = self.get_client(write=True) + + dest = self.make_key(dest, version=version) + nkeys = [self.make_key(key, version=version) for key in keys] + return int(client.sinterstore(dest, *nkeys)) + + def smismember( + self, + key: KeyT, + *members, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> List[bool]: + if client is None: + client = self.get_client(write=False) + + key = self.make_key(key, version=version) + encoded_members = [self.encode(member) for member in members] + + return [bool(value) for value in client.smismember(key, *encoded_members)] + + def sismember( + self, + key: KeyT, + member: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> bool: + if client is None: + client = self.get_client(write=False) + + key = self.make_key(key, version=version) + member = self.encode(member) + return bool(client.sismember(key, member)) + + def smembers( + self, + key: KeyT, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Set: + if client is None: + client = self.get_client(write=False) + + key = self.make_key(key, version=version) + return {self.decode(value) for value in client.smembers(key)} + + def smove( + self, + source: KeyT, + destination: KeyT, + member: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> bool: + if client is None: + client = self.get_client(write=True) + + source = self.make_key(source, version=version) + destination = self.make_key(destination) + member = self.encode(member) + return bool(client.smove(source, destination, member)) + + def spop( + self, + key: KeyT, + count: Optional[int] = None, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Union[Set, Any]: + if client is None: + client = self.get_client(write=True) + + nkey = self.make_key(key, version=version) + result = client.spop(nkey, count) + if result is None: + return None + if isinstance(result, list): + return {self.decode(value) for value in result} + return self.decode(result) + + def srandmember( + self, + key: KeyT, + count: Optional[int] = None, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Union[List, Any]: + if client is None: + client = self.get_client(write=False) + + key = self.make_key(key, version=version) + result = client.srandmember(key, count) + if result is None: + return None + if isinstance(result, list): + return [self.decode(value) for value in result] + return self.decode(result) + + def srem( + self, + key: KeyT, + *members: EncodableT, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + client = self.get_client(write=True) + + key = self.make_key(key, version=version) + nmembers = [self.encode(member) for member in members] + return int(client.srem(key, *nmembers)) + + def sscan( + self, + key: KeyT, + match: Optional[str] = None, + count: Optional[int] = 10, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Set[Any]: + if self._has_compression_enabled() and match: + err_msg = "Using match with compression is not supported." + raise ValueError(err_msg) + + if client is None: + client = self.get_client(write=False) + + key = self.make_key(key, version=version) + + cursor, result = client.sscan( + key, + match=cast(PatternT, self.encode(match)) if match else None, + count=count, + ) + return {self.decode(value) for value in result} + + def sscan_iter( + self, + key: KeyT, + match: Optional[str] = None, + count: Optional[int] = 10, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Iterator[Any]: + if self._has_compression_enabled() and match: + err_msg = "Using match with compression is not supported." + raise ValueError(err_msg) + + if client is None: + client = self.get_client(write=False) + + key = self.make_key(key, version=version) + for value in client.sscan_iter( + key, + match=cast(PatternT, self.encode(match)) if match else None, + count=count, + ): + yield self.decode(value) + + def sunion( + self, + *keys: KeyT, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Set: + if client is None: + client = self.get_client(write=False) + + nkeys = [self.make_key(key, version=version) for key in keys] + return {self.decode(value) for value in client.sunion(*nkeys)} + + def sunionstore( + self, + destination: Any, + *keys: KeyT, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + client = self.get_client(write=True) + + destination = self.make_key(destination, version=version) + encoded_keys = [self.make_key(key, version=version) for key in keys] + return int(client.sunionstore(destination, *encoded_keys)) + def close(self) -> None: close_flag = self._options.get( "CLOSE_CONNECTION", diff --git a/django_redis/client/sharded.py b/django_redis/client/sharded.py index dbb1d200..6178dc94 100644 --- a/django_redis/client/sharded.py +++ b/django_redis/client/sharded.py @@ -1,9 +1,11 @@ import re from collections import OrderedDict from datetime import datetime -from typing import Union +from typing import Any, Iterator, Optional, Set, Union +from redis import Redis from redis.exceptions import ConnectionError +from redis.typing import KeyT from django_redis.client.default import DEFAULT_TIMEOUT, DefaultClient from django_redis.exceptions import ConnectionInterrupted @@ -258,7 +260,8 @@ def incr_version(self, key, delta=1, version=None, client=None): raise ConnectionInterrupted(connection=client) from e if value is None: - raise ValueError("Key '%s' not found" % key) + err_msg = f"Key '{key}' not found" + raise ValueError(err_msg) if isinstance(key, CacheKey): new_key = self.make_key(key.original_key(), version=version + delta) @@ -335,3 +338,148 @@ def touch(self, key, timeout=DEFAULT_TIMEOUT, version=None, client=None): def clear(self, client=None): for connection in self._serverdict.values(): connection.flushdb() + + def sadd( + self, + key: KeyT, + *values: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + key = self.make_key(key, version=version) + client = self.get_server(key) + return super().sadd(key, *values, version=version, client=client) + + def scard( + self, + key: KeyT, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + key = self.make_key(key, version=version) + client = self.get_server(key) + return super().scard(key=key, version=version, client=client) + + def smembers( + self, + key: KeyT, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Set: + if client is None: + key = self.make_key(key, version=version) + client = self.get_server(key) + return super().smembers(key=key, version=version, client=client) + + def smove( + self, + source: KeyT, + destination: KeyT, + member: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ): + if client is None: + source = self.make_key(source, version=version) + client = self.get_server(source) + destination = self.make_key(destination, version=version) + + return super().smove( + source=source, + destination=destination, + member=member, + version=version, + client=client, + ) + + def srem( + self, + key: KeyT, + *members, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + key = self.make_key(key, version=version) + client = self.get_server(key) + return super().srem(key, *members, version=version, client=client) + + def sscan( + self, + key: KeyT, + match: Optional[str] = None, + count: Optional[int] = 10, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Set[Any]: + if client is None: + key = self.make_key(key, version=version) + client = self.get_server(key) + return super().sscan( + key=key, match=match, count=count, version=version, client=client + ) + + def sscan_iter( + self, + key: KeyT, + match: Optional[str] = None, + count: Optional[int] = 10, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Iterator[Any]: + if client is None: + key = self.make_key(key, version=version) + client = self.get_server(key) + return super().sscan_iter( + key=key, match=match, count=count, version=version, client=client + ) + + def srandmember( + self, + key: KeyT, + count: Optional[int] = None, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Union[Set, Any]: + if client is None: + key = self.make_key(key, version=version) + client = self.get_server(key) + return super().srandmember(key=key, count=count, version=version, client=client) + + def sismember( + self, + key: KeyT, + member: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> bool: + if client is None: + key = self.make_key(key, version=version) + client = self.get_server(key) + return super().sismember(key, member, version=version, client=client) + + def spop( + self, + key: KeyT, + count: Optional[int] = None, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Union[Set, Any]: + if client is None: + key = self.make_key(key, version=version) + client = self.get_server(key) + return super().spop(key=key, count=count, version=version, client=client) + + def smismember( + self, + key: KeyT, + *members, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> list[bool]: + if client is None: + key = self.make_key(key, version=version) + client = self.get_server(key) + return super().smismember(key, *members, version=version, client=client) diff --git a/django_redis/compressors/lz4.py b/django_redis/compressors/lz4.py index 32183321..940c96d5 100644 --- a/django_redis/compressors/lz4.py +++ b/django_redis/compressors/lz4.py @@ -16,5 +16,5 @@ def compress(self, value: bytes) -> bytes: def decompress(self, value: bytes) -> bytes: try: return _decompress(value) - except Exception as e: # noqa: BLE001 + except Exception as e: raise CompressorError from e diff --git a/tests/test_backend.py b/tests/test_backend.py index 4ff60983..8619931e 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -856,3 +856,160 @@ def test_hexists(self, cache: RedisCache): cache.hset("foo_hash5", "foo1", "bar1") assert cache.hexists("foo_hash5", "foo1") assert not cache.hexists("foo_hash5", "foo") + + def test_sadd(self, cache: RedisCache): + assert cache.sadd("foo", "bar") == 1 + assert cache.smembers("foo") == {"bar"} + + def test_scard(self, cache: RedisCache): + cache.sadd("foo", "bar", "bar2") + assert cache.scard("foo") == 2 + + def test_sdiff(self, cache: RedisCache): + if isinstance(cache.client, ShardClient): + pytest.skip("ShardClient doesn't support get_client") + + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.sdiff("foo1", "foo2") == {"bar1"} + + def test_sdiffstore(self, cache: RedisCache): + if isinstance(cache.client, ShardClient): + pytest.skip("ShardClient doesn't support get_client") + + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.sdiffstore("foo3", "foo1", "foo2") == 1 + assert cache.smembers("foo3") == {"bar1"} + + def test_sdiffstore_with_keys_version(self, cache: RedisCache): + if isinstance(cache.client, ShardClient): + pytest.skip("ShardClient doesn't support get_client") + + cache.sadd("foo1", "bar1", "bar2", version=2) + cache.sadd("foo2", "bar2", "bar3", version=2) + assert cache.sdiffstore("foo3", "foo1", "foo2", version_keys=2) == 1 + assert cache.smembers("foo3") == {"bar1"} + + def test_sdiffstore_with_different_keys_versions_without_initial_set_in_version( + self, cache: RedisCache + ): + if isinstance(cache.client, ShardClient): + pytest.skip("ShardClient doesn't support get_client") + + cache.sadd("foo1", "bar1", "bar2", version=1) + cache.sadd("foo2", "bar2", "bar3", version=2) + assert cache.sdiffstore("foo3", "foo1", "foo2", version_keys=2) == 0 + + def test_sdiffstore_with_different_keys_versions_with_initial_set_in_version( + self, cache: RedisCache + ): + if isinstance(cache.client, ShardClient): + pytest.skip("ShardClient doesn't support get_client") + + cache.sadd("foo1", "bar1", "bar2", version=2) + cache.sadd("foo2", "bar2", "bar3", version=1) + assert cache.sdiffstore("foo3", "foo1", "foo2", version_keys=2) == 2 + + def test_sinter(self, cache: RedisCache): + if isinstance(cache.client, ShardClient): + pytest.skip("ShardClient doesn't support get_client") + + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.sinter("foo1", "foo2") == {"bar2"} + + def test_interstore(self, cache: RedisCache): + if isinstance(cache.client, ShardClient): + pytest.skip("ShardClient doesn't support get_client") + + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.sinterstore("foo3", "foo1", "foo2") == 1 + assert cache.smembers("foo3") == {"bar2"} + + def test_sismember(self, cache: RedisCache): + cache.sadd("foo", "bar") + assert cache.sismember("foo", "bar") is True + assert cache.sismember("foo", "bar2") is False + + def test_smove(self, cache: RedisCache): + if isinstance(cache.client, ShardClient): + pytest.skip("ShardClient doesn't support get_client") + + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.smove("foo1", "foo2", "bar1") is True + assert cache.smove("foo1", "foo2", "bar4") is False + assert cache.smembers("foo1") == {"bar2"} + assert cache.smembers("foo2") == {"bar1", "bar2", "bar3"} + + def test_spop_default_count(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2") + assert cache.spop("foo") in {"bar1", "bar2"} + assert cache.smembers("foo") in [{"bar1"}, {"bar2"}] + + def test_spop(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2") + assert cache.spop("foo", 1) in [{"bar1"}, {"bar2"}] + assert cache.smembers("foo") in [{"bar1"}, {"bar2"}] + + def test_srandmember_default_count(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2") + assert cache.srandmember("foo") in {"bar1", "bar2"} + + def test_srandmember(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2") + assert cache.srandmember("foo", 1) in [["bar1"], ["bar2"]] + + def test_srem(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2") + assert cache.srem("foo", "bar1") == 1 + assert cache.srem("foo", "bar3") == 0 + + def test_sscan(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2") + items = cache.sscan("foo") + assert items == {"bar1", "bar2"} + + def test_sscan_with_match(self, cache: RedisCache): + if cache.client._has_compression_enabled(): + pytest.skip("Compression is enabled, sscan with match is not supported") + cache.sadd("foo", "bar1", "bar2", "zoo") + items = cache.sscan("foo", match="zoo") + assert items == {"zoo"} + + def test_sscan_iter(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2") + items = cache.sscan_iter("foo") + assert set(items) == {"bar1", "bar2"} + + def test_sscan_iter_with_match(self, cache: RedisCache): + if cache.client._has_compression_enabled(): + pytest.skip( + "Compression is enabled, sscan_iter with match is not supported" + ) + cache.sadd("foo", "bar1", "bar2", "zoo") + items = cache.sscan_iter("foo", match="bar*") + assert set(items) == {"bar1", "bar2"} + + def test_smismember(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2", "bar3") + assert cache.smismember("foo", "bar1", "bar2", "xyz") == [True, True, False] + + def test_sunion(self, cache: RedisCache): + if isinstance(cache.client, ShardClient): + pytest.skip("ShardClient doesn't support get_client") + + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.sunion("foo1", "foo2") == {"bar1", "bar2", "bar3"} + + def test_sunionstore(self, cache: RedisCache): + if isinstance(cache.client, ShardClient): + pytest.skip("ShardClient doesn't support get_client") + + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.sunionstore("foo3", "foo1", "foo2") == 3 + assert cache.smembers("foo3") == {"bar1", "bar2", "bar3"}