diff --git a/docs/about/changelog.md b/docs/about/changelog.md index 2fa8a988..de2f9bc9 100644 --- a/docs/about/changelog.md +++ b/docs/about/changelog.md @@ -1,6 +1,10 @@ --- title: Change log -description: Change log of all fakeredis releases +description: Changelog of all fakeredis releases +tags: + - changelog + - release-notes +toc_depth: 2 --- ## v2.25.0 @@ -10,6 +14,9 @@ description: Change log of all fakeredis releases - Implement support for hash expiration related commands @j00bar #328 - `HEXPIRE`, `HEXPIREAT`, `HEXPIRETIME`, `HPERSIST`, `HPEXPIRE`, `HPEXPIREAT`, `HPEXPIRETIME`, `HPTTL`, `HTTL`, - Implement support for `SORT_RO` #325, `EXPIRETIME` #323, and `PEXPIRETIME` #324 +- Support for creating a tcp server listening to multiple clients +- Testing against valkey 8.0 #333 +- Improve documentation #332 ## v2.24.1 diff --git a/docs/index.md b/docs/index.md index 7aa9d847..4fb46470 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,4 +1,11 @@ -## fakeredis: A python implementation of redis server +--- +toc: + toc_depth: 3 +--- + +fakeredis: A python implementation of redis server +================================================= + FakeRedis is a pure-Python implementation of the Redis key-value store. @@ -30,6 +37,27 @@ pip install fakeredis[probabilistic,json] ## Support for RedisJSON and BloomFil ## How to Use +### Start a server on a thread + +It is possible to start a server on a thread and use it as a connect to it as you would a real redis server. + +```python +from threading import Thread +from fakeredis import TcpFakeServer + +server_address = ("127.0.0.1", 6379) +server = TcpFakeServer(server_address) +t = Thread(target=server.serve_forever, daemon=True) +t.start() + +import redis + +r = redis.Redis(host=server_address[0], port=server_address[1]) +r.set("foo", "bar") +assert r.get("foo") == b"bar" + +``` + ### Use as a pytest fixture ```python @@ -196,11 +224,13 @@ from fastapi import Depends, FastAPI app = FastAPI() + async def get_redis() -> AsyncIterator[redis.Redis]: # Code to handle creating a redis connection goes here, for example async with redis.from_url("redis://localhost:6379") as client: # type: ignore[no-untyped-call] yield client + @app.get("/") async def root(redis_client: Annotated[redis.Redis, Depends(get_redis)]) -> Any: # Code that does something with redis goes here, for example: @@ -223,20 +253,24 @@ from redis import asyncio as redis from main import app, get_redis + @pytest_asyncio.fixture async def redis_client() -> AsyncIterator[redis.Redis]: async with fakeredis.FakeAsyncRedis() as client: yield client + @pytest_asyncio.fixture async def app_client(redis_client: redis.Redis) -> AsyncIterator[httpx.AsyncClient]: async def get_redis_override() -> redis.Redis: return redis_client + transport = httpx.ASGITransport(app=app) # type: ignore[arg-type] # https://github.com/encode/httpx/issues/3111 async with httpx.AsyncClient(transport=transport, base_url="http://test") as app_client: with mock.patch.dict(app.dependency_overrides, {get_redis: get_redis_override}): yield app_client + @pytest.mark.asyncio async def test_app(app_client: httpx.AsyncClient) -> None: response = await app_client.get("/") diff --git a/docs/overrides/partials/toc-item.html b/docs/overrides/partials/toc-item.html new file mode 100644 index 00000000..ca70d15f --- /dev/null +++ b/docs/overrides/partials/toc-item.html @@ -0,0 +1,20 @@ +
  • + + + {{ toc_item.title }} + + + + + {% if toc_item.children %} + + {% endif %} +
  • \ No newline at end of file diff --git a/fakeredis/__init__.py b/fakeredis/__init__.py index b395186c..4b4ce507 100644 --- a/fakeredis/__init__.py +++ b/fakeredis/__init__.py @@ -1,3 +1,5 @@ +import sys + from ._connection import ( FakeRedis, FakeStrictRedis, @@ -9,6 +11,15 @@ FakeConnection as FakeAsyncConnection, ) +if sys.version_info >= (3, 11): + from ._tcp_server import TcpFakeServer +else: + + class TcpFakeServer: + def __init__(self, *args, **kwargs): + raise NotImplementedError("TcpFakeServer is only available in Python 3.11+") + + try: from importlib import metadata except ImportError: # for Python < 3.8 @@ -28,4 +39,5 @@ "FakeConnection", "FakeAsyncRedis", "FakeAsyncConnection", + "TcpFakeServer", ] diff --git a/fakeredis/_basefakesocket.py b/fakeredis/_basefakesocket.py index ea070778..599717c4 100644 --- a/fakeredis/_basefakesocket.py +++ b/fakeredis/_basefakesocket.py @@ -83,13 +83,20 @@ def __init__(self, server: "FakeServer", db: int, *args: Any, **kwargs: Any) -> self._paused = False self._parser = self._parse_commands() self._parser.send(None) - self.version = server.version # Assigned elsewhere self._transaction: Optional[List[Any]] self._in_transaction: bool self._pubsub: int self._transaction_failed: bool + @property + def version(self) -> Tuple[int, ...]: + return self._server.version + + @property + def server_type(self) -> str: + return self._server.server_type + def put_response(self, msg: Any) -> None: """Put a response message into the queue of responses. diff --git a/fakeredis/_connection.py b/fakeredis/_connection.py index 4d6325ca..06369026 100644 --- a/fakeredis/_connection.py +++ b/fakeredis/_connection.py @@ -85,6 +85,7 @@ def __init__( *args: Any, server: Optional[FakeServer] = None, version: VersionType = (7,), + server_type: str = "redis", lua_modules: Optional[Set[str]] = None, **kwargs: Any, ) -> None: @@ -130,6 +131,7 @@ def __init__( "connection_class": FakeConnection, "server": server, "version": version, + "server_type": server_type, "lua_modules": lua_modules, } connection_kwargs.update({arg: kwds[arg] for arg in conn_pool_args if arg in kwds}) @@ -137,11 +139,14 @@ def __init__( kwds.pop("server", None) kwds.pop("connected", None) kwds.pop("version", None) + kwds.pop("server_type", None) kwds.pop("lua_modules", None) super().__init__(**kwds) @classmethod def from_url(cls, *args: Any, **kwargs: Any) -> Self: + kwargs.setdefault("version", "7.4") + kwargs.setdefault("server_type", "redis") pool = redis.ConnectionPool.from_url(*args, **kwargs) # Now override how it creates connections pool.connection_class = FakeConnection diff --git a/fakeredis/_msgs.py b/fakeredis/_msgs.py index 0484edc0..c6984a3d 100644 --- a/fakeredis/_msgs.py +++ b/fakeredis/_msgs.py @@ -51,6 +51,7 @@ LOG_INVALID_DEBUG_LEVEL_MSG = "ERR Invalid debug level." LUA_COMMAND_ARG_MSG6 = "ERR Lua redis() command arguments must be strings or integers" LUA_COMMAND_ARG_MSG = "ERR Lua redis lib command arguments must be strings or integers" +VALKEY_LUA_COMMAND_ARG_MSG = "Command arguments must be strings or integers script: {}" LUA_WRONG_NUMBER_ARGS_MSG = "ERR wrong number or type of arguments" SCRIPT_ERROR_MSG = "ERR Error running script (call to f_{}): @user_script:?: {}" RESTORE_KEY_EXISTS = "BUSYKEY Target key name already exists." @@ -100,9 +101,11 @@ TIMESERIES_KEY_EXISTS = "TSDB: key already exists" TIMESERIES_INVALID_DUPLICATE_POLICY = "TSDB: Unknown DUPLICATE_POLICY" TIMESERIES_KEY_DOES_NOT_EXIST = "TSDB: the key does not exist" +TIMESERIES_RULE_DOES_NOT_EXIST = "TSDB: compaction rule does not exist" TIMESERIES_RULE_EXISTS = "TSDB: the destination key already has a src rule" TIMESERIES_BAD_AGGREGATION_TYPE = "TSDB: Unknown aggregation type" TIMESERIES_INVALID_TIMESTAMP = "TSDB: invalid timestamp" +TIMESERIES_BAD_TIMESTAMP = "TSDB: Couldn't parse alignTimestamp" TIMESERIES_TIMESTAMP_OLDER_THAN_RETENTION = "TSDB: Timestamp is older than retention" TIMESERIES_TIMESTAMP_LOWER_THAN_MAX_V7 = ( "TSDB: timestamp must be equal to or higher than the maximum existing timestamp" diff --git a/fakeredis/_server.py b/fakeredis/_server.py index 609fa7cc..db32c32a 100644 --- a/fakeredis/_server.py +++ b/fakeredis/_server.py @@ -32,7 +32,7 @@ def _create_version(v: VersionType) -> Tuple[int, ...]: class FakeServer: _servers_map: Dict[str, "FakeServer"] = dict() - def __init__(self, version: VersionType = (7,)): + def __init__(self, version: VersionType = (7,), server_type: str = "redis") -> None: self.lock = threading.Lock() self.dbs: Dict[int, Database] = defaultdict(lambda: Database(self.lock)) # Maps channel/pattern to a weak set of sockets @@ -43,15 +43,18 @@ def __init__(self, version: VersionType = (7,)): self.connected = True # List of weakrefs to sockets that are being closed lazily self.closed_sockets: List[Any] = [] - self.version = _create_version(version) + self.version: Tuple[int, ...] = _create_version(version) + if server_type not in ("redis", "dragonfly", "valkey"): + raise ValueError(f"Unsupported server type: {server_type}") + self.server_type: str = server_type @staticmethod - def get_server(key: str, version: VersionType) -> "FakeServer": - return FakeServer._servers_map.setdefault(key, FakeServer(version=version)) + def get_server(key: str, version: VersionType, server_type: str) -> "FakeServer": + return FakeServer._servers_map.setdefault(key, FakeServer(version=version, server_type=server_type)) class FakeBaseConnectionMixin(object): - def __init__(self, *args: Any, version: VersionType = (7, 0), **kwargs: Any) -> None: + def __init__(self, *args: Any, version: VersionType, server_type: str, **kwargs: Any) -> None: self.client_name: Optional[str] = None self.server_key: str self._sock = None @@ -66,7 +69,7 @@ def __init__(self, *args: Any, version: VersionType = (7, 0), **kwargs: Any) -> else: host, port = kwargs.get("host"), kwargs.get("port") self.server_key = f"{host}:{port}" - self.server_key += f":v{version}" - self._server = FakeServer.get_server(self.server_key, version=version) + self.server_key += f":{server_type}:v{version}" + self._server = FakeServer.get_server(self.server_key, server_type=server_type, version=version) self._server.connected = connected super().__init__(*args, **kwargs) diff --git a/fakeredis/_tcp_server.py b/fakeredis/_tcp_server.py new file mode 100644 index 00000000..3743b81d --- /dev/null +++ b/fakeredis/_tcp_server.py @@ -0,0 +1,127 @@ +import logging +from dataclasses import dataclass +from itertools import count +from socketserver import ThreadingTCPServer, StreamRequestHandler +from typing import BinaryIO, Dict, Tuple + +from fakeredis import FakeRedis +from fakeredis import FakeServer + +LOGGER = logging.getLogger("fakeredis") + + +def to_bytes(value) -> bytes: + if isinstance(value, bytes): + return value + return str(value).encode() + + +@dataclass +class Client: + connection: FakeRedis + client_address: int + + +@dataclass +class Reader: + reader: BinaryIO + + def load_array(self, length: int): + array = [None] * length + for i in range(length): + array[i] = self.load() + return array + + def load(self): + line = self.reader.readline().strip() + match line[0:1], line[1:]: + case b"*", length: + return self.load_array(int(length)) + case b"$", length: + bulk_string = self.reader.read(int(length) + 2).strip() + if len(bulk_string) != int(length): + raise ValueError() + return bulk_string + case b":", value: + return int(value) + case b"+", value: + return value + case b"-", value: + return Exception(value) + case _: + return None + + +@dataclass +class Writer: + writer: BinaryIO + + def dump(self, value, dump_bulk=False): + if isinstance(value, int): + self.writer.write(f":{value}\r\n".encode()) + elif isinstance(value, (str, bytes)): + value = to_bytes(value) + if dump_bulk or b"\r" in value or b"\n" in value: + self.writer.write(b"$" + str(len(value)).encode() + b"\r\n" + value + b"\r\n") + else: + self.writer.write(b"+" + value + b"\r\n") + elif isinstance(value, (list, set)): + self.writer.write(f"*{len(value)}\r\n".encode()) + for item in value: + self.dump(item, dump_bulk=True) + elif value is None: + self.writer.write("$-1\r\n".encode()) + elif isinstance(value, Exception): + self.writer.write(f"-{value.args[0]}\r\n".encode()) + + +class TCPFakeRequestHandler(StreamRequestHandler): + + def setup(self) -> None: + super().setup() + if self.client_address in self.server.clients: + self.current_client = self.server.clients[self.client_address] + else: + self.current_client = Client( + connection=FakeRedis(server=self.server.fake_server), + client_address=self.client_address, + ) + self.reader = Reader(self.rfile) + self.writer = Writer(self.wfile) + self.server.clients[self.client_address] = self.current_client + + def handle(self): + while True: + try: + self.data = self.reader.load() + LOGGER.debug(f">>> {self.client_address[0]}: {self.data}") + res = self.current_client.connection.execute_command(*self.data) + LOGGER.debug(f"<<< {self.client_address[0]}: {res}") + self.writer.dump(res) + except Exception as e: + LOGGER.debug(f"!!! {self.client_address[0]}: {e}") + self.writer.dump(e) + break + + def finish(self) -> None: + del self.server.clients[self.current_client.client_address] + super().finish() + + +class TcpFakeServer(ThreadingTCPServer): + def __init__( + self, + server_address: Tuple[str | bytes | bytearray, int], + bind_and_activate: bool = True, + server_type: str = "redis", + server_version: Tuple[int, ...] = (7, 4), + ): + super().__init__(server_address, TCPFakeRequestHandler, bind_and_activate) + self.fake_server = FakeServer(server_type=server_type, version=server_version) + self.client_ids = count(0) + self.clients: Dict[int, FakeRedis] = dict() + + +if __name__ == "__main__": + server = TcpFakeServer(("localhost", 19000)) + server.serve_forever() diff --git a/fakeredis/aioredis.py b/fakeredis/aioredis.py index 3d2f7214..38b047b2 100644 --- a/fakeredis/aioredis.py +++ b/fakeredis/aioredis.py @@ -206,6 +206,7 @@ def __init__( server: Optional[_server.FakeServer] = None, connected: bool = True, version: VersionType = (7,), + server_type: str = "redis", lua_modules: Optional[Set[str]] = None, **kwargs: Any, ) -> None: @@ -230,6 +231,7 @@ def __init__( connection_class=FakeConnection, max_connections=max_connections, version=version, + server_type=server_type, lua_modules=lua_modules, ) connection_pool = redis_async.ConnectionPool(**connection_kwargs) # type:ignore @@ -256,6 +258,8 @@ def from_url(cls, url: str, **kwargs: Any) -> redis_async.Redis: self = super().from_url(url, **kwargs) pool = self.connection_pool # Now override how it creates connections pool.connection_class = FakeConnection + pool.connection_kwargs.setdefault("version", "7.4") + pool.connection_kwargs.setdefault("server_type", "redis") pool.connection_kwargs.pop("username", None) pool.connection_kwargs.pop("password", None) return self diff --git a/fakeredis/commands_mixins/scripting_mixin.py b/fakeredis/commands_mixins/scripting_mixin.py index 84cc9343..034a7cc8 100644 --- a/fakeredis/commands_mixins/scripting_mixin.py +++ b/fakeredis/commands_mixins/scripting_mixin.py @@ -79,6 +79,7 @@ class ScriptingCommandsMixin: def __init__(self, *args: Any, **kwargs: Any): self.script_cache: Dict[bytes, bytes] = dict() # Maps SHA1 to the script source + self.server_type: str self.version: Tuple[int] self.load_lua_modules = set() lua_modules_set: Set[str] = kwargs.pop("lua_modules", None) or set() @@ -100,9 +101,7 @@ def _convert_redis_arg(self, lua_runtime: LUA_MODULE.LuaRuntime, value: Any) -> elif type(value) in {int, float}: return "{:.17g}".format(value).encode() else: - # TODO: add the context - msg = msgs.LUA_COMMAND_ARG_MSG6 if self.version < (7,) else msgs.LUA_COMMAND_ARG_MSG - raise SimpleError(msg) + raise SimpleError(msgs.LUA_COMMAND_ARG_MSG) def _convert_redis_result(self, lua_runtime: LUA_MODULE.LuaRuntime, result: Any) -> Any: if isinstance(result, (bytes, int)): @@ -210,6 +209,13 @@ def eval(self, script: bytes, numkeys: int, *keys_and_args: bytes) -> Any: try: result = lua_runtime.execute(script) except SimpleError as ex: + if ex.value == msgs.LUA_COMMAND_ARG_MSG: + if self.version < (7,): + raise SimpleError(msgs.LUA_COMMAND_ARG_MSG6) + elif self.server_type == "valkey": + raise SimpleError(msgs.VALKEY_LUA_COMMAND_ARG_MSG.format(sha1.decode())) + else: + raise SimpleError(msgs.LUA_COMMAND_ARG_MSG) if self.version < (7,): raise SimpleError(msgs.SCRIPT_ERROR_MSG.format(sha1.decode(), ex)) raise SimpleError(ex.value) diff --git a/fakeredis/stack/_timeseries_mixin.py b/fakeredis/stack/_timeseries_mixin.py index 173fd1b2..1444f803 100644 --- a/fakeredis/stack/_timeseries_mixin.py +++ b/fakeredis/stack/_timeseries_mixin.py @@ -189,7 +189,7 @@ def ts_del(self, key: CommandItem, from_ts: int, to_ts: int) -> bytes: @command( name="TS.CREATERULE", fixed=(Key(TimeSeries), Key(TimeSeries), bytes, bytes, Int), - repeat=(Int,), + repeat=(bytes,), flags=msgs.FLAG_DO_NOT_CREATE, ) def ts_createrule( @@ -199,18 +199,18 @@ def ts_createrule( _: bytes, aggregator: bytes, bucket_duration: int, - *args: int, + *args: bytes, ) -> SimpleString: if source_key.value is None: raise SimpleError(msgs.TIMESERIES_KEY_DOES_NOT_EXIST) if dest_key.value is None: raise SimpleError(msgs.TIMESERIES_KEY_DOES_NOT_EXIST) if len(args) > 1: - raise SimpleError(msgs.WRONG_ARGS_MSG6) + raise SimpleError(msgs.WRONG_ARGS_MSG6.format("ts.createrule")) try: align_timestamp = int(args[0]) if len(args) == 1 else 0 except ValueError: - raise SimpleError(msgs.WRONG_ARGS_MSG6) + raise SimpleError(msgs.TIMESERIES_BAD_TIMESTAMP) existing_rule = source_key.value.get_rule(dest_key.key) if existing_rule is not None: raise SimpleError(msgs.TIMESERIES_RULE_EXISTS) @@ -231,7 +231,7 @@ def ts_deleterule(self, source_key: CommandItem, dest_key: CommandItem) -> bytes raise SimpleError(msgs.TIMESERIES_KEY_DOES_NOT_EXIST) res: Optional[TimeSeriesRule] = source_key.value.get_rule(dest_key.key) if res is None: - raise SimpleError(msgs.NOT_FOUND_MSG) + raise SimpleError(msgs.TIMESERIES_RULE_DOES_NOT_EXIST) source_key.value.delete_rule(res) return OK diff --git a/test/conftest.py b/test/conftest.py index 3807679f..b5b0282d 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -39,10 +39,11 @@ def real_redis_version() -> Tuple[str, Union[None, Tuple[int, ...]]]: @pytest_asyncio.fixture(name="fake_server") -def _fake_server(request) -> fakeredis.FakeServer: +def _fake_server(request, real_redis_version) -> fakeredis.FakeServer: + server_type, _ = real_redis_version min_server_marker = request.node.get_closest_marker("min_server") server_version = min_server_marker.args[0] if min_server_marker else "6.2" - server = fakeredis.FakeServer(version=server_version) + server = fakeredis.FakeServer(server_type=server_type, version=server_version) server.connected = request.node.get_closest_marker("disconnected") is None return server diff --git a/test/test_general.py b/test/test_general.py deleted file mode 100644 index 196e6b23..00000000 --- a/test/test_general.py +++ /dev/null @@ -1,16 +0,0 @@ -import pytest -import redis - -from test.testtools import raw_command - - -def test_asyncioio_is_used(): - """Redis 4.2+ has support for asyncio and should be preferred over aioredis""" - from fakeredis import aioredis - - assert not hasattr(aioredis, "__version__") - - -def test_unknown_command(r: redis.Redis): - with pytest.raises(redis.ResponseError): - raw_command(r, "0 3 3") diff --git a/test/test_hypothesis.py b/test/test_hypothesis.py index 7dd78aae..0e687ff9 100644 --- a/test/test_hypothesis.py +++ b/test/test_hypothesis.py @@ -361,8 +361,8 @@ class Machine(CommonMachine): self.command_strategy | self.command_strategy_redis7 if redis_ver >= (7,) else self.command_strategy ) - # hypothesis.settings.register_profile("debug", max_examples=10, verbosity=hypothesis.Verbosity.debug) - # hypothesis.settings.load_profile("debug") + hypothesis.settings.register_profile("debug", max_examples=10, verbosity=hypothesis.Verbosity.debug) + hypothesis.settings.load_profile("debug") hypothesis.stateful.run_state_machine_as_test(Machine) @@ -705,8 +705,3 @@ def mutated_commands(commands): | add_arg(x, args) | swap_args(x), ) - - -class TestFuzz(BaseTest): - command_strategy = mutated_commands(TestJoint.command_strategy) - command_strategy = command_strategy.filter(lambda command: command.testable) diff --git a/test/test_internals/__init__.py b/test/test_internals/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/test_extract_args.py b/test/test_internals/test_extract_args.py similarity index 100% rename from test/test_extract_args.py rename to test/test_internals/test_extract_args.py diff --git a/test/test_connection.py b/test/test_mixins/test_connection.py similarity index 99% rename from test/test_connection.py rename to test/test_mixins/test_connection.py index 7886b4e2..55299069 100644 --- a/test/test_connection.py +++ b/test/test_mixins/test_connection.py @@ -20,6 +20,11 @@ def test_echo(r: redis.Redis): assert r.echo("hello") == b"hello" +def test_unknown_command(r: redis.Redis): + with pytest.raises(redis.ResponseError): + raw_command(r, "0 3 3") + + @testtools.fake_only def test_time(r, mocker): fake_time = mocker.patch("time.time") diff --git a/test/test_scan.py b/test/test_mixins/test_scan.py similarity index 100% rename from test/test_scan.py rename to test/test_mixins/test_scan.py diff --git a/test/test_mixins/test_scripting.py b/test/test_mixins/test_scripting.py index d0298417..99738023 100644 --- a/test/test_mixins/test_scripting.py +++ b/test/test_mixins/test_scripting.py @@ -106,7 +106,8 @@ def test_script_help(r: redis.Redis): @pytest.mark.min_server("7.1") -def test_script_help71(r: redis.Redis): +@pytest.mark.unsupported_server_types("valkey") +def test_script_help73(r: redis.Redis): assert raw_command(r, "SCRIPT HELP") == [ b"SCRIPT [ [value] [opt] ...]. Subcommands are:", b"DEBUG (YES|SYNC|NO)", @@ -341,10 +342,21 @@ def test_eval_call_bool6(r: redis.Redis): @pytest.mark.min_server("7") -def test_eval_call_bool7(r: redis.Redis): +@pytest.mark.unsupported_server_types("valkey") +def test_eval_call_bool7_redis(r: redis.Redis): # Redis doesn't allow Lua bools to be passed to [p]call - with pytest.raises(redis.ResponseError, match=r"Lua redis lib command arguments must be strings or integers"): + with pytest.raises(redis.ResponseError) as exc_info: + r.eval('return redis.call("SET", KEYS[1], true)', 1, "testkey") + assert "Lua redis lib command arguments must be strings or integers" in str(exc_info.value) + + +@pytest.mark.min_server("7") +@pytest.mark.unsupported_server_types("redis") +def test_eval_call_bool7_valkey(r: redis.Redis): + # Redis doesn't allow Lua bools to be passed to [p]call + with pytest.raises(redis.ResponseError) as exc_info: r.eval('return redis.call("SET", KEYS[1], true)', 1, "testkey") + assert "Command arguments must be strings or integers script" in str(exc_info.value) def test_eval_return_error(r: redis.Redis): diff --git a/test/test_sortedset_commands.py b/test/test_mixins/test_sortedset_commands.py similarity index 100% rename from test/test_sortedset_commands.py rename to test/test_mixins/test_sortedset_commands.py diff --git a/test/test_zadd.py b/test/test_mixins/test_zadd.py similarity index 100% rename from test/test_zadd.py rename to test/test_mixins/test_zadd.py diff --git a/test/test_stack/test_timeseries.py b/test/test_stack/test_timeseries.py index e76ddea2..a7b0a89e 100644 --- a/test/test_stack/test_timeseries.py +++ b/test/test_stack/test_timeseries.py @@ -6,11 +6,12 @@ import redis from fakeredis import _msgs as msgs +from test.testtools import raw_command timeseries_tests = pytest.importorskip("probables") -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_add_ts_close(r: redis.Redis): ts1 = r.ts().add(5, "*", 1) time.sleep(0.001) @@ -18,7 +19,44 @@ def test_add_ts_close(r: redis.Redis): assert abs(ts2 - ts1) < 5 -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.min_server("7") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") +def test_createrule_errors(r: redis.Redis): + timeseries = r.ts() + with pytest.raises(redis.ResponseError) as e: + timeseries.createrule("t1", "t2", aggregation_type="sum", bucket_size_msec=10) + assert str(e.value) == msgs.TIMESERIES_KEY_DOES_NOT_EXIST + + timeseries.create("t1") + with pytest.raises(redis.ResponseError) as e: + timeseries.createrule("t1", "t2", aggregation_type="sum", bucket_size_msec=10) + assert str(e.value) == msgs.TIMESERIES_KEY_DOES_NOT_EXIST + + timeseries.create("t2") + with pytest.raises(redis.ResponseError) as e: + raw_command(r, "TS.CREATERULE", "t1", "t2", "AGGREGATION", "sum", 10, 1, 2) + assert str(e.value) in msgs.WRONG_ARGS_MSG6.format("ts.createrule") + + with pytest.raises(redis.ResponseError) as e: + raw_command(r, "TS.CREATERULE", "t1", "t2", "AGGREGATION", "sum", 10, "20c") + assert str(e.value) == msgs.TIMESERIES_BAD_TIMESTAMP + + +@pytest.mark.min_server("7") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") +def test_deleterule_errors(r: redis.Redis): + timeseries = r.ts() + with pytest.raises(redis.ResponseError) as e: + timeseries.deleterule("t1", "t2") + assert str(e.value) == msgs.TIMESERIES_KEY_DOES_NOT_EXIST + + timeseries.create("t1") + with pytest.raises(redis.ResponseError) as e: + timeseries.deleterule("t1", "t2") + assert str(e.value) == msgs.TIMESERIES_RULE_DOES_NOT_EXIST + + +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_create_key_exist(r: redis.Redis): assert r.ts().create(1) with pytest.raises(redis.ResponseError) as e: @@ -26,14 +64,14 @@ def test_create_key_exist(r: redis.Redis): assert str(e.value) == msgs.TIMESERIES_KEY_EXISTS -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_create_bad_duplicate_policy(r: redis.Redis): with pytest.raises(redis.ResponseError) as e: assert r.ts().create(1, duplicate_policy="bad") assert str(e.value) == msgs.TIMESERIES_INVALID_DUPLICATE_POLICY -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_create(r: redis.Redis): assert r.ts().create(1) assert r.ts().create(2, retention_msecs=5) @@ -49,7 +87,7 @@ def test_create(r: redis.Redis): assert 128 == info.get("chunk_size") -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_create_duplicate_policy(r: redis.Redis): # Test for duplicate policy for duplicate_policy in ["block", "last", "first", "min", "max"]: @@ -59,7 +97,7 @@ def test_create_duplicate_policy(r: redis.Redis): assert duplicate_policy == info.get("duplicate_policy") -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_alter(r: redis.Redis): assert r.ts().create(1) info = r.ts().info(1) @@ -80,7 +118,7 @@ def test_alter(r: redis.Redis): assert str(e.value) == "TSDB: CHUNK_SIZE value must be a multiple of 8 in the range [48 .. 1048576]" -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_alter_diplicate_policy(r: redis.Redis): assert r.ts().create(1) info = r.ts().info(1) @@ -91,7 +129,7 @@ def test_alter_diplicate_policy(r: redis.Redis): assert "min" == info.get("duplicate_policy") -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_add(r: redis.Redis): assert 1 == r.ts().add(1, 1, 1) assert 2 == r.ts().add(2, 2, 3, retention_msecs=10) @@ -108,7 +146,7 @@ def test_add(r: redis.Redis): assert 128 == info.get("chunk_size") -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_add_before_retention(r: redis.Redis): r.ts().create("time-serie-1", retention_msecs=1000) assert r.ts().add("time-serie-1", 10000, 10.0) @@ -117,7 +155,7 @@ def test_add_before_retention(r: redis.Redis): assert str(e.value) == msgs.TIMESERIES_TIMESTAMP_OLDER_THAN_RETENTION -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_add_before_last(r: redis.Redis): r.ts().create("time-serie-1", retention_msecs=1000) assert r.ts().add("time-serie-1", 100, 10.0) == 100 @@ -132,7 +170,7 @@ def test_add_before_last(r: redis.Redis): ) -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_add_duplicate_policy(r: redis.Redis): # Test for duplicate policy BLOCK assert 1 == r.ts().add("time-serie-add-ooo-block", 1, 5.0) @@ -161,13 +199,13 @@ def test_add_duplicate_policy(r: redis.Redis): assert 5.0 == r.ts().get("time-serie-add-ooo-min")[1] -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_madd(r: redis.Redis): r.ts().create("a") assert [1, 2, 3] == r.ts().madd([("a", 1, 5), ("a", 2, 10), ("a", 3, 15)]) -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_incrby_decrby(r: redis.Redis): for _ in range(100): assert r.ts().incrby(1, 1) @@ -197,7 +235,7 @@ def test_incrby_decrby(r: redis.Redis): assert 128 == info.get("chunk_size") -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_create_and_delete_rule(r: redis.Redis): # test rule creation time = 100 @@ -222,7 +260,7 @@ def test_create_and_delete_rule(r: redis.Redis): assert info2["source_key"] is None -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_del_range(r: redis.Redis): with pytest.raises(redis.ResponseError) as e: r.ts().delete("test", 0, 100) @@ -237,7 +275,7 @@ def test_del_range(r: redis.Redis): assert r.ts().delete(1, 60, 3) == 0 -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_range(r: redis.Redis): for i in range(100): r.ts().add(1, i, i % 7) @@ -254,7 +292,7 @@ def test_range(r: redis.Redis): # assert 20 == len(r.ts().range(1, 0, 500, aggregation_type="avg", bucket_size_msec=10)) TODO -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_range_advanced(r: redis.Redis): for i in range(100): r.ts().add(1, i, i % 7) @@ -277,7 +315,7 @@ def test_range_advanced(r: redis.Redis): assert res == [(0, pytest.approx(2.55, 0.1)), (10, 3.0)] -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_range_latest(r: redis.Redis): timeseries = r.ts() timeseries.create("t1") @@ -295,7 +333,7 @@ def test_range_latest(r: redis.Redis): assert timeseries.range("t2", 0, 9) == [(0, 4.0)] -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_range_bucket_timestamp(r: redis.Redis): timeseries = r.ts() timeseries.create("t1") @@ -320,7 +358,7 @@ def test_range_bucket_timestamp(r: redis.Redis): ) == [(20, 4.0), (60, 3.0), (80, 5.0)] -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_range_empty(r: redis.Redis): timeseries = r.ts() timeseries.create("t1") @@ -351,7 +389,7 @@ def test_range_empty(r: redis.Redis): assert res == resp2_expected -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_rev_range(r: redis.Redis): for i in range(100): r.ts().add(1, i, i % 7) @@ -379,8 +417,7 @@ def test_rev_range(r: redis.Redis): ] -@pytest.mark.onlynoncluster -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_revrange_latest(r: redis.Redis): timeseries = r.ts() timeseries.create("t1") @@ -396,7 +433,7 @@ def test_revrange_latest(r: redis.Redis): assert timeseries.revrange("t2", 0, 9, latest=True) == [(0, 4.0)] -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_revrange_bucket_timestamp(r: redis.Redis): timeseries = r.ts() timeseries.create("t1") @@ -421,7 +458,7 @@ def test_revrange_bucket_timestamp(r: redis.Redis): ) == [(20, 4.0), (60, 3.0), (80, 5.0)] -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_revrange_empty(r: redis.Redis): timeseries = r.ts() timeseries.create("t1") @@ -450,8 +487,7 @@ def test_revrange_empty(r: redis.Redis): ] -@pytest.mark.onlynoncluster -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_mrange(r: redis.Redis): r.ts().create(1, labels={"Test": "This", "team": "ny"}) r.ts().create(2, labels={"Test": "This", "Taste": "That", "team": "sf"}) @@ -479,7 +515,7 @@ def test_mrange(r: redis.Redis): assert {"Test": "This", "team": "ny"} == res[0]["1"][0] -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_multi_range_advanced(r: redis.Redis): r.ts().create(1, labels={"Test": "This", "team": "ny"}) r.ts().create(2, labels={"Test": "This", "Taste": "That", "team": "sf"}) @@ -535,8 +571,7 @@ def test_multi_range_advanced(r: redis.Redis): # assert [(0, 5.0), (5, 6.0)] == res[0]["1"][1] -@pytest.mark.onlynoncluster -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_mrange_latest(r: redis.Redis): timeseries = r.ts() timeseries.create("t1") @@ -560,8 +595,7 @@ def test_mrange_latest(r: redis.Redis): ] -@pytest.mark.onlynoncluster -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_multi_reverse_range(r: redis.Redis): r.ts().create(1, labels={"Test": "This", "team": "ny"}) r.ts().create(2, labels={"Test": "This", "Taste": "That", "team": "sf"}) @@ -635,8 +669,7 @@ def test_multi_reverse_range(r: redis.Redis): # assert [(1, 10.0), (0, 1.0)] == res[0]["1"][1] -@pytest.mark.onlynoncluster -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_mrevrange_latest(r: redis.Redis): timeseries = r.ts() timeseries.create("t1") @@ -660,7 +693,7 @@ def test_mrevrange_latest(r: redis.Redis): ] -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_get(r: redis.Redis): name = "test" r.ts().create(name) @@ -671,8 +704,7 @@ def test_get(r: redis.Redis): assert 4 == r.ts().get(name)[1] -@pytest.mark.onlynoncluster -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_get_latest(r: redis.Redis): timeseries = r.ts() timeseries.create("t1") @@ -686,7 +718,7 @@ def test_get_latest(r: redis.Redis): assert timeseries.get("t2", latest=True) == (0, 4.0) -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_mget_errors(r: redis.Redis): r.ts().create(1, labels={"Test": "This"}) r.ts().create(2, labels={"Test": "This", "Taste": "That"}) @@ -699,7 +731,7 @@ def test_mget_errors(r: redis.Redis): assert str(e.value) == "TSDB: failed parsing labels" -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_mget(r: redis.Redis): r.ts().create(1, labels={"Test": "This"}) r.ts().create(2, labels={"Test": "This", "Taste": "That"}) @@ -721,8 +753,7 @@ def test_mget(r: redis.Redis): assert {"Taste": "That", "Test": "This"} == res[0]["2"][0] -@pytest.mark.onlynoncluster -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_mget_latest(r: redis.Redis): timeseries = r.ts() timeseries.create("t1") @@ -738,7 +769,7 @@ def test_mget_latest(r: redis.Redis): assert res == [{"t2": [{}, 0, 4.0]}] -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_info(r: redis.Redis): r.ts().create(1, retention_msecs=5, labels={"currentLabel": "currentData"}) info = r.ts().info(1) @@ -746,7 +777,7 @@ def test_info(r: redis.Redis): assert info["labels"]["currentLabel"] == "currentData" -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def testInfoDuplicatePolicy(r: redis.Redis): r.ts().create(1, retention_msecs=5, labels={"currentLabel": "currentData"}) info = r.ts().info(1) @@ -757,8 +788,7 @@ def testInfoDuplicatePolicy(r: redis.Redis): assert info.get("duplicate_policy") == "min" -@pytest.mark.onlynoncluster -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_query_index(r: redis.Redis): r.ts().create(1, labels={"Test": "This"}) r.ts().create(2, labels={"Test": "This", "Taste": "That"}) @@ -767,7 +797,7 @@ def test_query_index(r: redis.Redis): assert r.ts().queryindex(["Taste=That"]) == [2] -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_pipeline(r: redis.Redis): pipeline = r.ts().pipeline() pipeline.create("with_pipeline") @@ -782,7 +812,7 @@ def test_pipeline(r: redis.Redis): assert r.ts().get("with_pipeline")[1] == 99 * 1.1 -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_uncompressed(r: redis.Redis): r.ts().create("compressed") r.ts().create("uncompressed", uncompressed=True) @@ -792,7 +822,7 @@ def test_uncompressed(r: redis.Redis): assert compressed_info["memory_usage"] != uncompressed_info["memory_usage"] -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_create_rule_green(r: redis.Redis): r.ts().create(1) r.ts().create(2) @@ -808,7 +838,7 @@ def test_create_rule_green(r: redis.Redis): assert info["source_key"] == b"1" -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_create_rule_bad_aggregator(r: redis.Redis): r.ts().create(1) r.ts().create(2) @@ -817,14 +847,14 @@ def test_create_rule_bad_aggregator(r: redis.Redis): assert str(e.value) == msgs.TIMESERIES_BAD_AGGREGATION_TYPE -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_create_rule_key_not_exist(r: redis.Redis): with pytest.raises(redis.ResponseError) as e: r.ts().createrule(1, 2, "avg", 100) assert str(e.value) == msgs.TIMESERIES_KEY_DOES_NOT_EXIST -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_create_rule_with_rule_to_dest_key_exists(r: redis.Redis): r.ts().create(1) r.ts().create(2) diff --git a/test/test_stack/test_topk.py b/test/test_stack/test_topk.py index 26df13dd..9581df60 100644 --- a/test/test_stack/test_topk.py +++ b/test/test_stack/test_topk.py @@ -4,7 +4,7 @@ topk_tests = pytest.importorskip("probables") -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_topk_incrby(r: redis.Redis): assert r.topk().reserve("topk", 3, 10, 3, 1) assert [None, None, None] == r.topk().incrby("topk", ["bar", "baz", "42"], [3, 6, 2]) @@ -13,7 +13,7 @@ def test_topk_incrby(r: redis.Redis): assert [3, 6, 10, 4, 0] == r.topk().count("topk", "bar", "baz", "42", "xyzzy", 4) -@pytest.mark.unsupported_server_types("dragonfly") +@pytest.mark.unsupported_server_types("dragonfly", "valkey") def test_topk(r: redis.Redis): # test list with empty buckets assert r.topk().reserve("topk", 3, 50, 4, 0.9) diff --git a/test/test_tcp_server/__init__.py b/test/test_tcp_server/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/test_tcp_server/test_connectivity.py b/test/test_tcp_server/test_connectivity.py new file mode 100644 index 00000000..414ed587 --- /dev/null +++ b/test/test_tcp_server/test_connectivity.py @@ -0,0 +1,23 @@ +import sys +import time +from threading import Thread + +import pytest +import redis + +from fakeredis import TcpFakeServer + +if sys.version_info < (3, 11): + pytest.skip("TcpFakeServer is only available in Python 3.11+", allow_module_level=True) + + +def test_tcp_server_started(): + server_address = ("127.0.0.1", 19000) + server = TcpFakeServer(server_address) + t = Thread(target=server.serve_forever, daemon=True) + t.start() + time.sleep(0.1) + r = redis.Redis(host=server_address[0], port=server_address[1]) + r.set("foo", "bar") + assert r.get("foo") == b"bar" + server.shutdown()