diff --git a/docs/about/changelog.md b/docs/about/changelog.md
index 2fa8a988..de2f9bc9 100644
--- a/docs/about/changelog.md
+++ b/docs/about/changelog.md
@@ -1,6 +1,10 @@
---
title: Change log
-description: Change log of all fakeredis releases
+description: Changelog of all fakeredis releases
+tags:
+ - changelog
+ - release-notes
+toc_depth: 2
---
## v2.25.0
@@ -10,6 +14,9 @@ description: Change log of all fakeredis releases
- Implement support for hash expiration related commands @j00bar #328
- `HEXPIRE`, `HEXPIREAT`, `HEXPIRETIME`, `HPERSIST`, `HPEXPIRE`, `HPEXPIREAT`, `HPEXPIRETIME`, `HPTTL`, `HTTL`,
- Implement support for `SORT_RO` #325, `EXPIRETIME` #323, and `PEXPIRETIME` #324
+- Support for creating a TCP server listening to multiple clients
+- Testing against valkey 8.0 #333
+- Improve documentation #332
## v2.24.1
diff --git a/docs/index.md b/docs/index.md
index 7aa9d847..4fb46470 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -1,4 +1,11 @@
-## fakeredis: A python implementation of redis server
+---
+toc:
+ toc_depth: 3
+---
+
+fakeredis: A python implementation of redis server
+=================================================
+
FakeRedis is a pure-Python implementation of the Redis key-value store.
@@ -30,6 +37,27 @@ pip install fakeredis[probabilistic,json] ## Support for RedisJSON and BloomFil
## How to Use
+### Start a server on a thread
+
+It is possible to start a server on a thread and connect to it as you would to a real redis server.
+
+```python
+from threading import Thread
+from fakeredis import TcpFakeServer
+
+server_address = ("127.0.0.1", 6379)
+server = TcpFakeServer(server_address)
+t = Thread(target=server.serve_forever, daemon=True)
+t.start()
+
+import redis
+
+r = redis.Redis(host=server_address[0], port=server_address[1])
+r.set("foo", "bar")
+assert r.get("foo") == b"bar"
+
+```
+
### Use as a pytest fixture
```python
@@ -196,11 +224,13 @@ from fastapi import Depends, FastAPI
app = FastAPI()
+
async def get_redis() -> AsyncIterator[redis.Redis]:
# Code to handle creating a redis connection goes here, for example
async with redis.from_url("redis://localhost:6379") as client: # type: ignore[no-untyped-call]
yield client
+
@app.get("/")
async def root(redis_client: Annotated[redis.Redis, Depends(get_redis)]) -> Any:
# Code that does something with redis goes here, for example:
@@ -223,20 +253,24 @@ from redis import asyncio as redis
from main import app, get_redis
+
@pytest_asyncio.fixture
async def redis_client() -> AsyncIterator[redis.Redis]:
async with fakeredis.FakeAsyncRedis() as client:
yield client
+
@pytest_asyncio.fixture
async def app_client(redis_client: redis.Redis) -> AsyncIterator[httpx.AsyncClient]:
async def get_redis_override() -> redis.Redis:
return redis_client
+
transport = httpx.ASGITransport(app=app) # type: ignore[arg-type] # https://github.com/encode/httpx/issues/3111
async with httpx.AsyncClient(transport=transport, base_url="http://test") as app_client:
with mock.patch.dict(app.dependency_overrides, {get_redis: get_redis_override}):
yield app_client
+
@pytest.mark.asyncio
async def test_app(app_client: httpx.AsyncClient) -> None:
response = await app_client.get("/")
diff --git a/docs/overrides/partials/toc-item.html b/docs/overrides/partials/toc-item.html
new file mode 100644
index 00000000..ca70d15f
--- /dev/null
+++ b/docs/overrides/partials/toc-item.html
@@ -0,0 +1,20 @@
+
+
+
+ {{ toc_item.title }}
+
+
+
+
+ {% if toc_item.children %}
+
+
+ {% for toc_item in toc_item.children %}
+ {% if not page.meta.toc_depth or toc_item.level <= page.meta.toc_depth %}
+ {% include "partials/toc-item.html" %}
+ {% endif %}
+ {% endfor %}
+
+
+ {% endif %}
+
\ No newline at end of file
diff --git a/fakeredis/__init__.py b/fakeredis/__init__.py
index b395186c..4b4ce507 100644
--- a/fakeredis/__init__.py
+++ b/fakeredis/__init__.py
@@ -1,3 +1,5 @@
+import sys
+
from ._connection import (
FakeRedis,
FakeStrictRedis,
@@ -9,6 +11,15 @@
FakeConnection as FakeAsyncConnection,
)
+if sys.version_info >= (3, 11):
+ from ._tcp_server import TcpFakeServer
+else:
+
+ class TcpFakeServer:
+ def __init__(self, *args, **kwargs):
+ raise NotImplementedError("TcpFakeServer is only available in Python 3.11+")
+
+
try:
from importlib import metadata
except ImportError: # for Python < 3.8
@@ -28,4 +39,5 @@
"FakeConnection",
"FakeAsyncRedis",
"FakeAsyncConnection",
+ "TcpFakeServer",
]
diff --git a/fakeredis/_basefakesocket.py b/fakeredis/_basefakesocket.py
index ea070778..599717c4 100644
--- a/fakeredis/_basefakesocket.py
+++ b/fakeredis/_basefakesocket.py
@@ -83,13 +83,20 @@ def __init__(self, server: "FakeServer", db: int, *args: Any, **kwargs: Any) ->
self._paused = False
self._parser = self._parse_commands()
self._parser.send(None)
- self.version = server.version
# Assigned elsewhere
self._transaction: Optional[List[Any]]
self._in_transaction: bool
self._pubsub: int
self._transaction_failed: bool
+ @property
+ def version(self) -> Tuple[int, ...]:
+ return self._server.version
+
+ @property
+ def server_type(self) -> str:
+ return self._server.server_type
+
def put_response(self, msg: Any) -> None:
"""Put a response message into the queue of responses.
diff --git a/fakeredis/_connection.py b/fakeredis/_connection.py
index 4d6325ca..06369026 100644
--- a/fakeredis/_connection.py
+++ b/fakeredis/_connection.py
@@ -85,6 +85,7 @@ def __init__(
*args: Any,
server: Optional[FakeServer] = None,
version: VersionType = (7,),
+ server_type: str = "redis",
lua_modules: Optional[Set[str]] = None,
**kwargs: Any,
) -> None:
@@ -130,6 +131,7 @@ def __init__(
"connection_class": FakeConnection,
"server": server,
"version": version,
+ "server_type": server_type,
"lua_modules": lua_modules,
}
connection_kwargs.update({arg: kwds[arg] for arg in conn_pool_args if arg in kwds})
@@ -137,11 +139,14 @@ def __init__(
kwds.pop("server", None)
kwds.pop("connected", None)
kwds.pop("version", None)
+ kwds.pop("server_type", None)
kwds.pop("lua_modules", None)
super().__init__(**kwds)
@classmethod
def from_url(cls, *args: Any, **kwargs: Any) -> Self:
+ kwargs.setdefault("version", "7.4")
+ kwargs.setdefault("server_type", "redis")
pool = redis.ConnectionPool.from_url(*args, **kwargs)
# Now override how it creates connections
pool.connection_class = FakeConnection
diff --git a/fakeredis/_msgs.py b/fakeredis/_msgs.py
index 0484edc0..c6984a3d 100644
--- a/fakeredis/_msgs.py
+++ b/fakeredis/_msgs.py
@@ -51,6 +51,7 @@
LOG_INVALID_DEBUG_LEVEL_MSG = "ERR Invalid debug level."
LUA_COMMAND_ARG_MSG6 = "ERR Lua redis() command arguments must be strings or integers"
LUA_COMMAND_ARG_MSG = "ERR Lua redis lib command arguments must be strings or integers"
+VALKEY_LUA_COMMAND_ARG_MSG = "Command arguments must be strings or integers script: {}"
LUA_WRONG_NUMBER_ARGS_MSG = "ERR wrong number or type of arguments"
SCRIPT_ERROR_MSG = "ERR Error running script (call to f_{}): @user_script:?: {}"
RESTORE_KEY_EXISTS = "BUSYKEY Target key name already exists."
@@ -100,9 +101,11 @@
TIMESERIES_KEY_EXISTS = "TSDB: key already exists"
TIMESERIES_INVALID_DUPLICATE_POLICY = "TSDB: Unknown DUPLICATE_POLICY"
TIMESERIES_KEY_DOES_NOT_EXIST = "TSDB: the key does not exist"
+TIMESERIES_RULE_DOES_NOT_EXIST = "TSDB: compaction rule does not exist"
TIMESERIES_RULE_EXISTS = "TSDB: the destination key already has a src rule"
TIMESERIES_BAD_AGGREGATION_TYPE = "TSDB: Unknown aggregation type"
TIMESERIES_INVALID_TIMESTAMP = "TSDB: invalid timestamp"
+TIMESERIES_BAD_TIMESTAMP = "TSDB: Couldn't parse alignTimestamp"
TIMESERIES_TIMESTAMP_OLDER_THAN_RETENTION = "TSDB: Timestamp is older than retention"
TIMESERIES_TIMESTAMP_LOWER_THAN_MAX_V7 = (
"TSDB: timestamp must be equal to or higher than the maximum existing timestamp"
diff --git a/fakeredis/_server.py b/fakeredis/_server.py
index 609fa7cc..db32c32a 100644
--- a/fakeredis/_server.py
+++ b/fakeredis/_server.py
@@ -32,7 +32,7 @@ def _create_version(v: VersionType) -> Tuple[int, ...]:
class FakeServer:
_servers_map: Dict[str, "FakeServer"] = dict()
- def __init__(self, version: VersionType = (7,)):
+ def __init__(self, version: VersionType = (7,), server_type: str = "redis") -> None:
self.lock = threading.Lock()
self.dbs: Dict[int, Database] = defaultdict(lambda: Database(self.lock))
# Maps channel/pattern to a weak set of sockets
@@ -43,15 +43,18 @@ def __init__(self, version: VersionType = (7,)):
self.connected = True
# List of weakrefs to sockets that are being closed lazily
self.closed_sockets: List[Any] = []
- self.version = _create_version(version)
+ self.version: Tuple[int, ...] = _create_version(version)
+ if server_type not in ("redis", "dragonfly", "valkey"):
+ raise ValueError(f"Unsupported server type: {server_type}")
+ self.server_type: str = server_type
@staticmethod
- def get_server(key: str, version: VersionType) -> "FakeServer":
- return FakeServer._servers_map.setdefault(key, FakeServer(version=version))
+ def get_server(key: str, version: VersionType, server_type: str) -> "FakeServer":
+ return FakeServer._servers_map.setdefault(key, FakeServer(version=version, server_type=server_type))
class FakeBaseConnectionMixin(object):
- def __init__(self, *args: Any, version: VersionType = (7, 0), **kwargs: Any) -> None:
+ def __init__(self, *args: Any, version: VersionType, server_type: str, **kwargs: Any) -> None:
self.client_name: Optional[str] = None
self.server_key: str
self._sock = None
@@ -66,7 +69,7 @@ def __init__(self, *args: Any, version: VersionType = (7, 0), **kwargs: Any) ->
else:
host, port = kwargs.get("host"), kwargs.get("port")
self.server_key = f"{host}:{port}"
- self.server_key += f":v{version}"
- self._server = FakeServer.get_server(self.server_key, version=version)
+ self.server_key += f":{server_type}:v{version}"
+ self._server = FakeServer.get_server(self.server_key, server_type=server_type, version=version)
self._server.connected = connected
super().__init__(*args, **kwargs)
diff --git a/fakeredis/_tcp_server.py b/fakeredis/_tcp_server.py
new file mode 100644
index 00000000..3743b81d
--- /dev/null
+++ b/fakeredis/_tcp_server.py
@@ -0,0 +1,127 @@
+import logging
+from dataclasses import dataclass
+from itertools import count
+from socketserver import ThreadingTCPServer, StreamRequestHandler
+from typing import BinaryIO, Dict, Tuple
+
+from fakeredis import FakeRedis
+from fakeredis import FakeServer
+
+LOGGER = logging.getLogger("fakeredis")
+
+
+def to_bytes(value) -> bytes:
+ if isinstance(value, bytes):
+ return value
+ return str(value).encode()
+
+
+@dataclass
+class Client:
+ connection: FakeRedis
+ client_address: Tuple[str, int]
+
+
+@dataclass
+class Reader:
+ reader: BinaryIO
+
+ def load_array(self, length: int):
+ array = [None] * length
+ for i in range(length):
+ array[i] = self.load()
+ return array
+
+ def load(self):
+ line = self.reader.readline().strip()
+ match line[0:1], line[1:]:
+ case b"*", length:
+ return self.load_array(int(length))
+ case b"$", length:
+ bulk_string = self.reader.read(int(length) + 2).strip()
+ if len(bulk_string) != int(length):
+ raise ValueError()
+ return bulk_string
+ case b":", value:
+ return int(value)
+ case b"+", value:
+ return value
+ case b"-", value:
+ return Exception(value)
+ case _:
+ return None
+
+
+@dataclass
+class Writer:
+ writer: BinaryIO
+
+ def dump(self, value, dump_bulk=False):
+ if isinstance(value, int):
+ self.writer.write(f":{value}\r\n".encode())
+ elif isinstance(value, (str, bytes)):
+ value = to_bytes(value)
+ if dump_bulk or b"\r" in value or b"\n" in value:
+ self.writer.write(b"$" + str(len(value)).encode() + b"\r\n" + value + b"\r\n")
+ else:
+ self.writer.write(b"+" + value + b"\r\n")
+ elif isinstance(value, (list, set)):
+ self.writer.write(f"*{len(value)}\r\n".encode())
+ for item in value:
+ self.dump(item, dump_bulk=True)
+ elif value is None:
+ self.writer.write("$-1\r\n".encode())
+ elif isinstance(value, Exception):
+ self.writer.write(f"-{value.args[0]}\r\n".encode())
+
+
+class TCPFakeRequestHandler(StreamRequestHandler):
+
+ def setup(self) -> None:
+ super().setup()
+ if self.client_address in self.server.clients:
+ self.current_client = self.server.clients[self.client_address]
+ else:
+ self.current_client = Client(
+ connection=FakeRedis(server=self.server.fake_server),
+ client_address=self.client_address,
+ )
+ self.reader = Reader(self.rfile)
+ self.writer = Writer(self.wfile)
+ self.server.clients[self.client_address] = self.current_client
+
+ def handle(self):
+ while True:
+ try:
+ self.data = self.reader.load()
+ LOGGER.debug(f">>> {self.client_address[0]}: {self.data}")
+ res = self.current_client.connection.execute_command(*self.data)
+ LOGGER.debug(f"<<< {self.client_address[0]}: {res}")
+ self.writer.dump(res)
+ except Exception as e:
+ LOGGER.debug(f"!!! {self.client_address[0]}: {e}")
+ self.writer.dump(e)
+ break
+
+ def finish(self) -> None:
+ del self.server.clients[self.current_client.client_address]
+ super().finish()
+
+
+class TcpFakeServer(ThreadingTCPServer):
+ def __init__(
+ self,
+ server_address: Tuple[str | bytes | bytearray, int],
+ bind_and_activate: bool = True,
+ server_type: str = "redis",
+ server_version: Tuple[int, ...] = (7, 4),
+ ):
+ super().__init__(server_address, TCPFakeRequestHandler, bind_and_activate)
+ self.fake_server = FakeServer(server_type=server_type, version=server_version)
+ self.client_ids = count(0)
+ self.clients: Dict[Tuple[str, int], Client] = dict()
+
+
+if __name__ == "__main__":
+ server = TcpFakeServer(("localhost", 19000))
+ server.serve_forever()
diff --git a/fakeredis/aioredis.py b/fakeredis/aioredis.py
index 3d2f7214..38b047b2 100644
--- a/fakeredis/aioredis.py
+++ b/fakeredis/aioredis.py
@@ -206,6 +206,7 @@ def __init__(
server: Optional[_server.FakeServer] = None,
connected: bool = True,
version: VersionType = (7,),
+ server_type: str = "redis",
lua_modules: Optional[Set[str]] = None,
**kwargs: Any,
) -> None:
@@ -230,6 +231,7 @@ def __init__(
connection_class=FakeConnection,
max_connections=max_connections,
version=version,
+ server_type=server_type,
lua_modules=lua_modules,
)
connection_pool = redis_async.ConnectionPool(**connection_kwargs) # type:ignore
@@ -256,6 +258,8 @@ def from_url(cls, url: str, **kwargs: Any) -> redis_async.Redis:
self = super().from_url(url, **kwargs)
pool = self.connection_pool # Now override how it creates connections
pool.connection_class = FakeConnection
+ pool.connection_kwargs.setdefault("version", "7.4")
+ pool.connection_kwargs.setdefault("server_type", "redis")
pool.connection_kwargs.pop("username", None)
pool.connection_kwargs.pop("password", None)
return self
diff --git a/fakeredis/commands_mixins/scripting_mixin.py b/fakeredis/commands_mixins/scripting_mixin.py
index 84cc9343..034a7cc8 100644
--- a/fakeredis/commands_mixins/scripting_mixin.py
+++ b/fakeredis/commands_mixins/scripting_mixin.py
@@ -79,6 +79,7 @@ class ScriptingCommandsMixin:
def __init__(self, *args: Any, **kwargs: Any):
self.script_cache: Dict[bytes, bytes] = dict() # Maps SHA1 to the script source
+ self.server_type: str
self.version: Tuple[int]
self.load_lua_modules = set()
lua_modules_set: Set[str] = kwargs.pop("lua_modules", None) or set()
@@ -100,9 +101,7 @@ def _convert_redis_arg(self, lua_runtime: LUA_MODULE.LuaRuntime, value: Any) ->
elif type(value) in {int, float}:
return "{:.17g}".format(value).encode()
else:
- # TODO: add the context
- msg = msgs.LUA_COMMAND_ARG_MSG6 if self.version < (7,) else msgs.LUA_COMMAND_ARG_MSG
- raise SimpleError(msg)
+ raise SimpleError(msgs.LUA_COMMAND_ARG_MSG)
def _convert_redis_result(self, lua_runtime: LUA_MODULE.LuaRuntime, result: Any) -> Any:
if isinstance(result, (bytes, int)):
@@ -210,6 +209,13 @@ def eval(self, script: bytes, numkeys: int, *keys_and_args: bytes) -> Any:
try:
result = lua_runtime.execute(script)
except SimpleError as ex:
+ if ex.value == msgs.LUA_COMMAND_ARG_MSG:
+ if self.version < (7,):
+ raise SimpleError(msgs.LUA_COMMAND_ARG_MSG6)
+ elif self.server_type == "valkey":
+ raise SimpleError(msgs.VALKEY_LUA_COMMAND_ARG_MSG.format(sha1.decode()))
+ else:
+ raise SimpleError(msgs.LUA_COMMAND_ARG_MSG)
if self.version < (7,):
raise SimpleError(msgs.SCRIPT_ERROR_MSG.format(sha1.decode(), ex))
raise SimpleError(ex.value)
diff --git a/fakeredis/stack/_timeseries_mixin.py b/fakeredis/stack/_timeseries_mixin.py
index 173fd1b2..1444f803 100644
--- a/fakeredis/stack/_timeseries_mixin.py
+++ b/fakeredis/stack/_timeseries_mixin.py
@@ -189,7 +189,7 @@ def ts_del(self, key: CommandItem, from_ts: int, to_ts: int) -> bytes:
@command(
name="TS.CREATERULE",
fixed=(Key(TimeSeries), Key(TimeSeries), bytes, bytes, Int),
- repeat=(Int,),
+ repeat=(bytes,),
flags=msgs.FLAG_DO_NOT_CREATE,
)
def ts_createrule(
@@ -199,18 +199,18 @@ def ts_createrule(
_: bytes,
aggregator: bytes,
bucket_duration: int,
- *args: int,
+ *args: bytes,
) -> SimpleString:
if source_key.value is None:
raise SimpleError(msgs.TIMESERIES_KEY_DOES_NOT_EXIST)
if dest_key.value is None:
raise SimpleError(msgs.TIMESERIES_KEY_DOES_NOT_EXIST)
if len(args) > 1:
- raise SimpleError(msgs.WRONG_ARGS_MSG6)
+ raise SimpleError(msgs.WRONG_ARGS_MSG6.format("ts.createrule"))
try:
align_timestamp = int(args[0]) if len(args) == 1 else 0
except ValueError:
- raise SimpleError(msgs.WRONG_ARGS_MSG6)
+ raise SimpleError(msgs.TIMESERIES_BAD_TIMESTAMP)
existing_rule = source_key.value.get_rule(dest_key.key)
if existing_rule is not None:
raise SimpleError(msgs.TIMESERIES_RULE_EXISTS)
@@ -231,7 +231,7 @@ def ts_deleterule(self, source_key: CommandItem, dest_key: CommandItem) -> bytes
raise SimpleError(msgs.TIMESERIES_KEY_DOES_NOT_EXIST)
res: Optional[TimeSeriesRule] = source_key.value.get_rule(dest_key.key)
if res is None:
- raise SimpleError(msgs.NOT_FOUND_MSG)
+ raise SimpleError(msgs.TIMESERIES_RULE_DOES_NOT_EXIST)
source_key.value.delete_rule(res)
return OK
diff --git a/test/conftest.py b/test/conftest.py
index 3807679f..b5b0282d 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -39,10 +39,11 @@ def real_redis_version() -> Tuple[str, Union[None, Tuple[int, ...]]]:
@pytest_asyncio.fixture(name="fake_server")
-def _fake_server(request) -> fakeredis.FakeServer:
+def _fake_server(request, real_redis_version) -> fakeredis.FakeServer:
+ server_type, _ = real_redis_version
min_server_marker = request.node.get_closest_marker("min_server")
server_version = min_server_marker.args[0] if min_server_marker else "6.2"
- server = fakeredis.FakeServer(version=server_version)
+ server = fakeredis.FakeServer(server_type=server_type, version=server_version)
server.connected = request.node.get_closest_marker("disconnected") is None
return server
diff --git a/test/test_general.py b/test/test_general.py
deleted file mode 100644
index 196e6b23..00000000
--- a/test/test_general.py
+++ /dev/null
@@ -1,16 +0,0 @@
-import pytest
-import redis
-
-from test.testtools import raw_command
-
-
-def test_asyncioio_is_used():
- """Redis 4.2+ has support for asyncio and should be preferred over aioredis"""
- from fakeredis import aioredis
-
- assert not hasattr(aioredis, "__version__")
-
-
-def test_unknown_command(r: redis.Redis):
- with pytest.raises(redis.ResponseError):
- raw_command(r, "0 3 3")
diff --git a/test/test_hypothesis.py b/test/test_hypothesis.py
index 7dd78aae..0e687ff9 100644
--- a/test/test_hypothesis.py
+++ b/test/test_hypothesis.py
@@ -361,8 +361,8 @@ class Machine(CommonMachine):
self.command_strategy | self.command_strategy_redis7 if redis_ver >= (7,) else self.command_strategy
)
- # hypothesis.settings.register_profile("debug", max_examples=10, verbosity=hypothesis.Verbosity.debug)
- # hypothesis.settings.load_profile("debug")
+ # hypothesis.settings.register_profile("debug", max_examples=10, verbosity=hypothesis.Verbosity.debug)
+ # hypothesis.settings.load_profile("debug")
hypothesis.stateful.run_state_machine_as_test(Machine)
@@ -705,8 +705,3 @@ def mutated_commands(commands):
| add_arg(x, args)
| swap_args(x),
)
-
-
-class TestFuzz(BaseTest):
- command_strategy = mutated_commands(TestJoint.command_strategy)
- command_strategy = command_strategy.filter(lambda command: command.testable)
diff --git a/test/test_internals/__init__.py b/test/test_internals/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/test/test_extract_args.py b/test/test_internals/test_extract_args.py
similarity index 100%
rename from test/test_extract_args.py
rename to test/test_internals/test_extract_args.py
diff --git a/test/test_connection.py b/test/test_mixins/test_connection.py
similarity index 99%
rename from test/test_connection.py
rename to test/test_mixins/test_connection.py
index 7886b4e2..55299069 100644
--- a/test/test_connection.py
+++ b/test/test_mixins/test_connection.py
@@ -20,6 +20,11 @@ def test_echo(r: redis.Redis):
assert r.echo("hello") == b"hello"
+def test_unknown_command(r: redis.Redis):
+ with pytest.raises(redis.ResponseError):
+ raw_command(r, "0 3 3")
+
+
@testtools.fake_only
def test_time(r, mocker):
fake_time = mocker.patch("time.time")
diff --git a/test/test_scan.py b/test/test_mixins/test_scan.py
similarity index 100%
rename from test/test_scan.py
rename to test/test_mixins/test_scan.py
diff --git a/test/test_mixins/test_scripting.py b/test/test_mixins/test_scripting.py
index d0298417..99738023 100644
--- a/test/test_mixins/test_scripting.py
+++ b/test/test_mixins/test_scripting.py
@@ -106,7 +106,8 @@ def test_script_help(r: redis.Redis):
@pytest.mark.min_server("7.1")
-def test_script_help71(r: redis.Redis):
+@pytest.mark.unsupported_server_types("valkey")
+def test_script_help73(r: redis.Redis):
assert raw_command(r, "SCRIPT HELP") == [
b"SCRIPT [ [value] [opt] ...]. Subcommands are:",
b"DEBUG (YES|SYNC|NO)",
@@ -341,10 +342,21 @@ def test_eval_call_bool6(r: redis.Redis):
@pytest.mark.min_server("7")
-def test_eval_call_bool7(r: redis.Redis):
+@pytest.mark.unsupported_server_types("valkey")
+def test_eval_call_bool7_redis(r: redis.Redis):
# Redis doesn't allow Lua bools to be passed to [p]call
- with pytest.raises(redis.ResponseError, match=r"Lua redis lib command arguments must be strings or integers"):
+ with pytest.raises(redis.ResponseError) as exc_info:
+ r.eval('return redis.call("SET", KEYS[1], true)', 1, "testkey")
+ assert "Lua redis lib command arguments must be strings or integers" in str(exc_info.value)
+
+
+@pytest.mark.min_server("7")
+@pytest.mark.unsupported_server_types("redis")
+def test_eval_call_bool7_valkey(r: redis.Redis):
+ # Valkey doesn't allow Lua bools to be passed to [p]call
+ with pytest.raises(redis.ResponseError) as exc_info:
r.eval('return redis.call("SET", KEYS[1], true)', 1, "testkey")
+ assert "Command arguments must be strings or integers script" in str(exc_info.value)
def test_eval_return_error(r: redis.Redis):
diff --git a/test/test_sortedset_commands.py b/test/test_mixins/test_sortedset_commands.py
similarity index 100%
rename from test/test_sortedset_commands.py
rename to test/test_mixins/test_sortedset_commands.py
diff --git a/test/test_zadd.py b/test/test_mixins/test_zadd.py
similarity index 100%
rename from test/test_zadd.py
rename to test/test_mixins/test_zadd.py
diff --git a/test/test_stack/test_timeseries.py b/test/test_stack/test_timeseries.py
index e76ddea2..a7b0a89e 100644
--- a/test/test_stack/test_timeseries.py
+++ b/test/test_stack/test_timeseries.py
@@ -6,11 +6,12 @@
import redis
from fakeredis import _msgs as msgs
+from test.testtools import raw_command
timeseries_tests = pytest.importorskip("probables")
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_add_ts_close(r: redis.Redis):
ts1 = r.ts().add(5, "*", 1)
time.sleep(0.001)
@@ -18,7 +19,44 @@ def test_add_ts_close(r: redis.Redis):
assert abs(ts2 - ts1) < 5
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.min_server("7")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
+def test_createrule_errors(r: redis.Redis):
+ timeseries = r.ts()
+ with pytest.raises(redis.ResponseError) as e:
+ timeseries.createrule("t1", "t2", aggregation_type="sum", bucket_size_msec=10)
+ assert str(e.value) == msgs.TIMESERIES_KEY_DOES_NOT_EXIST
+
+ timeseries.create("t1")
+ with pytest.raises(redis.ResponseError) as e:
+ timeseries.createrule("t1", "t2", aggregation_type="sum", bucket_size_msec=10)
+ assert str(e.value) == msgs.TIMESERIES_KEY_DOES_NOT_EXIST
+
+ timeseries.create("t2")
+ with pytest.raises(redis.ResponseError) as e:
+ raw_command(r, "TS.CREATERULE", "t1", "t2", "AGGREGATION", "sum", 10, 1, 2)
+ assert str(e.value) in msgs.WRONG_ARGS_MSG6.format("ts.createrule")
+
+ with pytest.raises(redis.ResponseError) as e:
+ raw_command(r, "TS.CREATERULE", "t1", "t2", "AGGREGATION", "sum", 10, "20c")
+ assert str(e.value) == msgs.TIMESERIES_BAD_TIMESTAMP
+
+
+@pytest.mark.min_server("7")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
+def test_deleterule_errors(r: redis.Redis):
+ timeseries = r.ts()
+ with pytest.raises(redis.ResponseError) as e:
+ timeseries.deleterule("t1", "t2")
+ assert str(e.value) == msgs.TIMESERIES_KEY_DOES_NOT_EXIST
+
+ timeseries.create("t1")
+ with pytest.raises(redis.ResponseError) as e:
+ timeseries.deleterule("t1", "t2")
+ assert str(e.value) == msgs.TIMESERIES_RULE_DOES_NOT_EXIST
+
+
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_create_key_exist(r: redis.Redis):
assert r.ts().create(1)
with pytest.raises(redis.ResponseError) as e:
@@ -26,14 +64,14 @@ def test_create_key_exist(r: redis.Redis):
assert str(e.value) == msgs.TIMESERIES_KEY_EXISTS
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_create_bad_duplicate_policy(r: redis.Redis):
with pytest.raises(redis.ResponseError) as e:
assert r.ts().create(1, duplicate_policy="bad")
assert str(e.value) == msgs.TIMESERIES_INVALID_DUPLICATE_POLICY
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_create(r: redis.Redis):
assert r.ts().create(1)
assert r.ts().create(2, retention_msecs=5)
@@ -49,7 +87,7 @@ def test_create(r: redis.Redis):
assert 128 == info.get("chunk_size")
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_create_duplicate_policy(r: redis.Redis):
# Test for duplicate policy
for duplicate_policy in ["block", "last", "first", "min", "max"]:
@@ -59,7 +97,7 @@ def test_create_duplicate_policy(r: redis.Redis):
assert duplicate_policy == info.get("duplicate_policy")
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_alter(r: redis.Redis):
assert r.ts().create(1)
info = r.ts().info(1)
@@ -80,7 +118,7 @@ def test_alter(r: redis.Redis):
assert str(e.value) == "TSDB: CHUNK_SIZE value must be a multiple of 8 in the range [48 .. 1048576]"
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_alter_diplicate_policy(r: redis.Redis):
assert r.ts().create(1)
info = r.ts().info(1)
@@ -91,7 +129,7 @@ def test_alter_diplicate_policy(r: redis.Redis):
assert "min" == info.get("duplicate_policy")
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_add(r: redis.Redis):
assert 1 == r.ts().add(1, 1, 1)
assert 2 == r.ts().add(2, 2, 3, retention_msecs=10)
@@ -108,7 +146,7 @@ def test_add(r: redis.Redis):
assert 128 == info.get("chunk_size")
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_add_before_retention(r: redis.Redis):
r.ts().create("time-serie-1", retention_msecs=1000)
assert r.ts().add("time-serie-1", 10000, 10.0)
@@ -117,7 +155,7 @@ def test_add_before_retention(r: redis.Redis):
assert str(e.value) == msgs.TIMESERIES_TIMESTAMP_OLDER_THAN_RETENTION
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_add_before_last(r: redis.Redis):
r.ts().create("time-serie-1", retention_msecs=1000)
assert r.ts().add("time-serie-1", 100, 10.0) == 100
@@ -132,7 +170,7 @@ def test_add_before_last(r: redis.Redis):
)
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_add_duplicate_policy(r: redis.Redis):
# Test for duplicate policy BLOCK
assert 1 == r.ts().add("time-serie-add-ooo-block", 1, 5.0)
@@ -161,13 +199,13 @@ def test_add_duplicate_policy(r: redis.Redis):
assert 5.0 == r.ts().get("time-serie-add-ooo-min")[1]
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_madd(r: redis.Redis):
r.ts().create("a")
assert [1, 2, 3] == r.ts().madd([("a", 1, 5), ("a", 2, 10), ("a", 3, 15)])
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_incrby_decrby(r: redis.Redis):
for _ in range(100):
assert r.ts().incrby(1, 1)
@@ -197,7 +235,7 @@ def test_incrby_decrby(r: redis.Redis):
assert 128 == info.get("chunk_size")
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_create_and_delete_rule(r: redis.Redis):
# test rule creation
time = 100
@@ -222,7 +260,7 @@ def test_create_and_delete_rule(r: redis.Redis):
assert info2["source_key"] is None
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_del_range(r: redis.Redis):
with pytest.raises(redis.ResponseError) as e:
r.ts().delete("test", 0, 100)
@@ -237,7 +275,7 @@ def test_del_range(r: redis.Redis):
assert r.ts().delete(1, 60, 3) == 0
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_range(r: redis.Redis):
for i in range(100):
r.ts().add(1, i, i % 7)
@@ -254,7 +292,7 @@ def test_range(r: redis.Redis):
# assert 20 == len(r.ts().range(1, 0, 500, aggregation_type="avg", bucket_size_msec=10)) TODO
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_range_advanced(r: redis.Redis):
for i in range(100):
r.ts().add(1, i, i % 7)
@@ -277,7 +315,7 @@ def test_range_advanced(r: redis.Redis):
assert res == [(0, pytest.approx(2.55, 0.1)), (10, 3.0)]
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_range_latest(r: redis.Redis):
timeseries = r.ts()
timeseries.create("t1")
@@ -295,7 +333,7 @@ def test_range_latest(r: redis.Redis):
assert timeseries.range("t2", 0, 9) == [(0, 4.0)]
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_range_bucket_timestamp(r: redis.Redis):
timeseries = r.ts()
timeseries.create("t1")
@@ -320,7 +358,7 @@ def test_range_bucket_timestamp(r: redis.Redis):
) == [(20, 4.0), (60, 3.0), (80, 5.0)]
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_range_empty(r: redis.Redis):
timeseries = r.ts()
timeseries.create("t1")
@@ -351,7 +389,7 @@ def test_range_empty(r: redis.Redis):
assert res == resp2_expected
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_rev_range(r: redis.Redis):
for i in range(100):
r.ts().add(1, i, i % 7)
@@ -379,8 +417,7 @@ def test_rev_range(r: redis.Redis):
]
-@pytest.mark.onlynoncluster
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_revrange_latest(r: redis.Redis):
timeseries = r.ts()
timeseries.create("t1")
@@ -396,7 +433,7 @@ def test_revrange_latest(r: redis.Redis):
assert timeseries.revrange("t2", 0, 9, latest=True) == [(0, 4.0)]
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_revrange_bucket_timestamp(r: redis.Redis):
timeseries = r.ts()
timeseries.create("t1")
@@ -421,7 +458,7 @@ def test_revrange_bucket_timestamp(r: redis.Redis):
) == [(20, 4.0), (60, 3.0), (80, 5.0)]
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_revrange_empty(r: redis.Redis):
timeseries = r.ts()
timeseries.create("t1")
@@ -450,8 +487,7 @@ def test_revrange_empty(r: redis.Redis):
]
-@pytest.mark.onlynoncluster
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_mrange(r: redis.Redis):
r.ts().create(1, labels={"Test": "This", "team": "ny"})
r.ts().create(2, labels={"Test": "This", "Taste": "That", "team": "sf"})
@@ -479,7 +515,7 @@ def test_mrange(r: redis.Redis):
assert {"Test": "This", "team": "ny"} == res[0]["1"][0]
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_multi_range_advanced(r: redis.Redis):
r.ts().create(1, labels={"Test": "This", "team": "ny"})
r.ts().create(2, labels={"Test": "This", "Taste": "That", "team": "sf"})
@@ -535,8 +571,7 @@ def test_multi_range_advanced(r: redis.Redis):
# assert [(0, 5.0), (5, 6.0)] == res[0]["1"][1]
-@pytest.mark.onlynoncluster
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_mrange_latest(r: redis.Redis):
timeseries = r.ts()
timeseries.create("t1")
@@ -560,8 +595,7 @@ def test_mrange_latest(r: redis.Redis):
]
-@pytest.mark.onlynoncluster
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_multi_reverse_range(r: redis.Redis):
r.ts().create(1, labels={"Test": "This", "team": "ny"})
r.ts().create(2, labels={"Test": "This", "Taste": "That", "team": "sf"})
@@ -635,8 +669,7 @@ def test_multi_reverse_range(r: redis.Redis):
# assert [(1, 10.0), (0, 1.0)] == res[0]["1"][1]
-@pytest.mark.onlynoncluster
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_mrevrange_latest(r: redis.Redis):
timeseries = r.ts()
timeseries.create("t1")
@@ -660,7 +693,7 @@ def test_mrevrange_latest(r: redis.Redis):
]
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_get(r: redis.Redis):
name = "test"
r.ts().create(name)
@@ -671,8 +704,7 @@ def test_get(r: redis.Redis):
assert 4 == r.ts().get(name)[1]
-@pytest.mark.onlynoncluster
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_get_latest(r: redis.Redis):
timeseries = r.ts()
timeseries.create("t1")
@@ -686,7 +718,7 @@ def test_get_latest(r: redis.Redis):
assert timeseries.get("t2", latest=True) == (0, 4.0)
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_mget_errors(r: redis.Redis):
r.ts().create(1, labels={"Test": "This"})
r.ts().create(2, labels={"Test": "This", "Taste": "That"})
@@ -699,7 +731,7 @@ def test_mget_errors(r: redis.Redis):
assert str(e.value) == "TSDB: failed parsing labels"
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_mget(r: redis.Redis):
r.ts().create(1, labels={"Test": "This"})
r.ts().create(2, labels={"Test": "This", "Taste": "That"})
@@ -721,8 +753,7 @@ def test_mget(r: redis.Redis):
assert {"Taste": "That", "Test": "This"} == res[0]["2"][0]
-@pytest.mark.onlynoncluster
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_mget_latest(r: redis.Redis):
timeseries = r.ts()
timeseries.create("t1")
@@ -738,7 +769,7 @@ def test_mget_latest(r: redis.Redis):
assert res == [{"t2": [{}, 0, 4.0]}]
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_info(r: redis.Redis):
r.ts().create(1, retention_msecs=5, labels={"currentLabel": "currentData"})
info = r.ts().info(1)
@@ -746,7 +777,7 @@ def test_info(r: redis.Redis):
assert info["labels"]["currentLabel"] == "currentData"
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def testInfoDuplicatePolicy(r: redis.Redis):
r.ts().create(1, retention_msecs=5, labels={"currentLabel": "currentData"})
info = r.ts().info(1)
@@ -757,8 +788,7 @@ def testInfoDuplicatePolicy(r: redis.Redis):
assert info.get("duplicate_policy") == "min"
-@pytest.mark.onlynoncluster
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_query_index(r: redis.Redis):
r.ts().create(1, labels={"Test": "This"})
r.ts().create(2, labels={"Test": "This", "Taste": "That"})
@@ -767,7 +797,7 @@ def test_query_index(r: redis.Redis):
assert r.ts().queryindex(["Taste=That"]) == [2]
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_pipeline(r: redis.Redis):
pipeline = r.ts().pipeline()
pipeline.create("with_pipeline")
@@ -782,7 +812,7 @@ def test_pipeline(r: redis.Redis):
assert r.ts().get("with_pipeline")[1] == 99 * 1.1
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_uncompressed(r: redis.Redis):
r.ts().create("compressed")
r.ts().create("uncompressed", uncompressed=True)
@@ -792,7 +822,7 @@ def test_uncompressed(r: redis.Redis):
assert compressed_info["memory_usage"] != uncompressed_info["memory_usage"]
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_create_rule_green(r: redis.Redis):
r.ts().create(1)
r.ts().create(2)
@@ -808,7 +838,7 @@ def test_create_rule_green(r: redis.Redis):
assert info["source_key"] == b"1"
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_create_rule_bad_aggregator(r: redis.Redis):
r.ts().create(1)
r.ts().create(2)
@@ -817,14 +847,14 @@ def test_create_rule_bad_aggregator(r: redis.Redis):
assert str(e.value) == msgs.TIMESERIES_BAD_AGGREGATION_TYPE
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_create_rule_key_not_exist(r: redis.Redis):
with pytest.raises(redis.ResponseError) as e:
r.ts().createrule(1, 2, "avg", 100)
assert str(e.value) == msgs.TIMESERIES_KEY_DOES_NOT_EXIST
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_create_rule_with_rule_to_dest_key_exists(r: redis.Redis):
r.ts().create(1)
r.ts().create(2)
diff --git a/test/test_stack/test_topk.py b/test/test_stack/test_topk.py
index 26df13dd..9581df60 100644
--- a/test/test_stack/test_topk.py
+++ b/test/test_stack/test_topk.py
@@ -4,7 +4,7 @@
topk_tests = pytest.importorskip("probables")
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_topk_incrby(r: redis.Redis):
assert r.topk().reserve("topk", 3, 10, 3, 1)
assert [None, None, None] == r.topk().incrby("topk", ["bar", "baz", "42"], [3, 6, 2])
@@ -13,7 +13,7 @@ def test_topk_incrby(r: redis.Redis):
assert [3, 6, 10, 4, 0] == r.topk().count("topk", "bar", "baz", "42", "xyzzy", 4)
-@pytest.mark.unsupported_server_types("dragonfly")
+@pytest.mark.unsupported_server_types("dragonfly", "valkey")
def test_topk(r: redis.Redis):
# test list with empty buckets
assert r.topk().reserve("topk", 3, 50, 4, 0.9)
diff --git a/test/test_tcp_server/__init__.py b/test/test_tcp_server/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/test/test_tcp_server/test_connectivity.py b/test/test_tcp_server/test_connectivity.py
new file mode 100644
index 00000000..414ed587
--- /dev/null
+++ b/test/test_tcp_server/test_connectivity.py
@@ -0,0 +1,23 @@
+import sys
+import time
+from threading import Thread
+
+import pytest
+import redis
+
+from fakeredis import TcpFakeServer
+
+if sys.version_info < (3, 11):
+ pytest.skip("TcpFakeServer is only available in Python 3.11+", allow_module_level=True)
+
+
+def test_tcp_server_started():
+ server_address = ("127.0.0.1", 19000)
+ server = TcpFakeServer(server_address)
+ t = Thread(target=server.serve_forever, daemon=True)
+ t.start()
+ time.sleep(0.1)
+ r = redis.Redis(host=server_address[0], port=server_address[1])
+ r.set("foo", "bar")
+ assert r.get("foo") == b"bar"
+ server.shutdown()