From 0e56165d4a26001cff288dc11f728f6c353fc62b Mon Sep 17 00:00:00 2001
From: David Dougherty
Date: Thu, 31 Aug 2023 02:00:58 -0700
Subject: [PATCH] DOC-2544: Adding new doctest to support updated VSS article (#2886)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Viktor Ivanov
Co-authored-by: Sergey Prokazov
Co-authored-by: Anuragkillswitch <70265851+Anuragkillswitch@users.noreply.github.com>
Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com>
Co-authored-by: Alex Schmitz
Co-authored-by: Alex Schmitz
Co-authored-by: Chayim
Co-authored-by: Bar Shaul <88437685+barshaul@users.noreply.github.com>
Co-authored-by: CrimsonGlory
Co-authored-by: Raymond Yin
Co-authored-by: zach.lee
Co-authored-by: James R T
Co-authored-by: dvora-h
Co-authored-by: Marc Schöchlin
Co-authored-by: Nick Gerow
Co-authored-by: Igor Malinovskiy
Co-authored-by: Chayim I. Kirshen
Co-authored-by: Leibale Eidelman
Co-authored-by: Thiago Bellini Ribeiro
Co-authored-by: woutdenolf
Co-authored-by: shacharPash <93581407+shacharPash@users.noreply.github.com>
Co-authored-by: Mirek Długosz
Co-authored-by: Oran Avraham <252748+oranav@users.noreply.github.com>
Co-authored-by: mzdehbashi-github <85902780+mzdehbashi-github@users.noreply.github.com>
Co-authored-by: Tyler Hutcherson
Co-authored-by: Felipe Machado <462154+felipou@users.noreply.github.com>
Co-authored-by: AYMEN Mohammed <53928879+AYMENJD@users.noreply.github.com>
Co-authored-by: Marc Schöchlin
Co-authored-by: Avasam
Co-authored-by: Markus Gerstel <2102431+Anthchirp@users.noreply.github.com>
Co-authored-by: Kristján Valur Jónsson
Co-authored-by: Nick Gerow
Co-authored-by: Cristian Matache
Co-authored-by: Anurag Bandyopadhyay
Co-authored-by: Seongchuel Ahn
Co-authored-by: Alibi
Co-authored-by: Smit Parmar
Co-authored-by: Brad MacPhee
Co-authored-by: Shahar Lev
Co-authored-by: Vladimir Mihailenco
Co-authored-by: Kevin James
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: David Pacsuta <34983281+ant1fact@users.noreply.github.com>
Co-authored-by: Rich Bowen
Co-authored-by: gmbnomis
Co-authored-by: Vivanov98 <66319645+Vivanov98@users.noreply.github.com>
Co-authored-by: Kosuke
Co-authored-by: Sergey Prokazov
Co-authored-by: jmcbailey
Co-authored-by: Galtozzy <14139502+Galtozzy@users.noreply.github.com>
Co-authored-by: Abhishek Kumar Sinha
Co-authored-by: Eom Taegyung "Iggy
Co-authored-by: Mehdi ABAAKOUK
Co-authored-by: Dongkeun Lee <3315213+zakaf@users.noreply.github.com>
Co-authored-by: woutdenolf
Co-authored-by: Kurt McKee
Co-authored-by: Juraj Páll
Co-authored-by: Joan Fontanals
Co-authored-by: Stanislav Zmiev

fix (#2566)
Fix unlink in cluster pipeline (#2562)
Fix issue 2540: Synchronise concurrent command calls to single-client mode. (#2568)
Fix: tuple function cannot be passed more than one argument (#2573)
Fix issue 2567: NoneType check before raising exception (#2569)
Fix issue 2349: Let async HiredisParser finish parsing after a Connection.disconnect() (#2557)
Fix issue with `pack_commands` returning an empty byte sequence (#2416)
Fix #2581: 'UnixDomainSocketConnection' object has no attribute '_command_packer' (#2583)
Fix for `lpop` and `rpop` return typing (#2590)
Fixed CredentialsProvider examples (#2587)
Fixed issue #2598 - make Document class subscriptable
fix: replace async_timeout by asyncio.timeout (#2602)
Fix behaviour of async PythonParser to match RedisParser as for issue #2349 (#2582)
Fix (#2641)
fix: do not use asyncio's timeout lib before 3.11.2 (#2659)
Fix issue 2660: PytestUnraisableExceptionWarning from asyncio client (#2669)
Fixing cancelled async futures (#2666)
Fix async (#2673)
Fix memory leak caused by hiredis (#2693) (#2694)
Fix incorrect usage of once flag in async Sentinel (#2718)
Fix topk list example. (#2724)
Fix `ClusterCommandProtocol` not itself being marked as a protocol (#2729)
Fix potential race condition during disconnection (#2719)
fix CI (#2748)
fix parse_slowlog_get (#2732)
fixes for issue #1128
fix create single_connection_client from url (#2752)
Fix `xadd` to allow non-negative maxlen (#2739)
Fix JSON.MERGE Summary (#2786)
Fixed key error in parse_xinfo_stream (#2788)
Fix dead weakref in sentinel connection causing ReferenceError (#2767) (#2771)
Fix dead weakref in sentinel conn (#2767)
fix redirects and some small cleanups (#2801)
Fix type hint for retry_on_error in async cluster (#2804)
Fix CI (#2809)
Fix async client with resp3 (#2657)
Fix `COMMAND` response in resp3 (redis 7+) (#2740)
Fix protocol version checking (#2737)
Fix parse resp3 dict response: don't use dict comprehension (#2757)
Fixing asyncio import (#2759)
fix (#2799)
fix async tests (#2806)
Fix socket garbage collection (#2859)
Fixing doc builds (#2869)
Fix a duplicate word in `CONTRIBUTING.md` (#2848)
Fix timeout retrying on Redis pipeline execution (#2812)
Fix type hints in SearchCommands (#2817)
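For context on the headline change (the new doctests/search_vss.py): the snippet below is a minimal sketch, not taken from the patch, of the kind of vector-similarity query such a doctest exercises. The index name idx:bikes_vss, the field names vector/brand/model, and the 768-dimension float32 embedding are illustrative assumptions in the style of the VSS article, not its exact code.

import numpy as np
import redis
from redis.commands.search.query import Query

r = redis.Redis(host="localhost", port=6379, decode_responses=True)

# Any embedding model can produce the query vector; its dimension must match
# the schema the index was created with (768 assumed here).
query_vec = np.random.rand(768).astype(np.float32).tobytes()

# KNN query: the three documents whose "vector" field is closest to the
# query vector, with the vector distance returned as "score".
q = (
    Query("(*)=>[KNN 3 @vector $vec AS score]")
    .sort_by("score")
    .return_fields("score", "brand", "model")
    .dialect(2)
)
res = r.ft("idx:bikes_vss").search(q, query_params={"vec": query_vec})
for doc in res.docs:
    print(doc.score, doc.brand, doc.model)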
---
 .flake8 | 26 +
 .github/dependabot.yml | 8 +
 .github/spellcheck-settings.yml | 29 +
 .github/wordlist.txt | 142 ++
 .github/workflows/docs.yaml | 47 +
 .github/workflows/doctests.yml | 3 +-
 .github/workflows/integration.yaml | 85 +-
 .github/workflows/spellcheck.yml | 14 +
 .isort.cfg | 5 +
 CHANGES | 1088 +++++------
 CONTRIBUTING.md | 24 +-
 LICENSE | 2 +-
 README.md | 30 +-
 benchmarks/socket_read_size.py | 4 +-
 dev_requirements.txt | 6 +-
 docker-compose.yml | 109 ++
 docker/base/Dockerfile | 4 -
 docker/base/Dockerfile.cluster | 11 -
 docker/base/Dockerfile.cluster4 | 9 -
 docker/base/Dockerfile.cluster5 | 9 -
 docker/base/Dockerfile.redis4 | 4 -
 docker/base/Dockerfile.redis5 | 4 -
 docker/base/Dockerfile.redismod_cluster | 12 -
 docker/base/Dockerfile.sentinel | 4 -
 docker/base/Dockerfile.sentinel4 | 4 -
 docker/base/Dockerfile.sentinel5 | 4 -
 docker/base/Dockerfile.stunnel | 11 -
 docker/base/Dockerfile.unstable | 18 -
 docker/base/Dockerfile.unstable_cluster | 11 -
 docker/base/Dockerfile.unstable_sentinel | 17 -
 docker/base/README.md | 1 -
 docker/base/create_cluster4.sh | 26 -
 docker/base/create_cluster5.sh | 26 -
 docker/base/create_redismod_cluster.sh | 46 -
 docker/cluster/redis.conf | 3 -
 docker/redis4/master/redis.conf | 2 -
 docker/redis4/sentinel/sentinel_1.conf | 6 -
 docker/redis4/sentinel/sentinel_2.conf | 6 -
 docker/redis4/sentinel/sentinel_3.conf | 6 -
 docker/redis5/master/redis.conf | 2 -
 docker/redis5/replica/redis.conf | 3 -
 docker/redis5/sentinel/sentinel_1.conf | 6 -
 docker/redis5/sentinel/sentinel_2.conf | 6 -
 docker/redis5/sentinel/sentinel_3.conf | 6 -
 docker/redis6.2/master/redis.conf | 2 -
 docker/redis6.2/replica/redis.conf | 3 -
 docker/redis6.2/sentinel/sentinel_2.conf | 6 -
 docker/redis6.2/sentinel/sentinel_3.conf | 6 -
 docker/redis7/master/redis.conf | 4 -
 docker/redismod_cluster/redis.conf | 8 -
 docker/unstable/redis.conf | 3 -
 docker/unstable_cluster/redis.conf | 4 -
 dockers/Dockerfile.cluster | 7 +
 dockers/cluster.redis.conf | 8 +
 {docker/base => dockers}/create_cluster.sh | 7 +-
 .../sentinel_1.conf => dockers/sentinel.conf | 4 +-
 {docker => dockers}/stunnel/README | 0
 {docker => dockers}/stunnel/conf/redis.conf | 2 +-
 {docker => dockers}/stunnel/create_certs.sh | 0
 {docker => dockers}/stunnel/keys/ca-cert.pem | 0
 {docker => dockers}/stunnel/keys/ca-key.pem | 0
 .../stunnel/keys/client-cert.pem | 0
 .../stunnel/keys/client-key.pem | 0
 .../stunnel/keys/client-req.pem | 0
 .../stunnel/keys/server-cert.pem | 0
 .../stunnel/keys/server-key.pem | 0
 .../stunnel/keys/server-req.pem | 0
 docs/advanced_features.rst | 76 +-
 docs/clustering.rst | 20 +-
 docs/conf.py | 4 +-
 docs/examples/asyncio_examples.ipynb | 210 +-
 docs/examples/connection_examples.ipynb | 151 +-
 docs/examples/opentelemetry/main.py | 3 +-
 docs/examples/pipeline_examples.ipynb | 2 +-
 docs/examples/redis-stream-example.ipynb | 2 +-
 .../search_vector_similarity_examples.ipynb | 610 +++++-
 docs/examples/ssl_connection_examples.ipynb | 21 +
 docs/examples/timeseries_examples.ipynb | 41 +
 docs/index.rst | 9 +-
 docs/lua_scripting.rst | 8 +-
 docs/opentelemetry.rst | 73 +-
 docs/redismodules.rst | 10 +-
 docs/requirements.txt | 1 +
 docs/resp3_features.rst | 69 +
 doctests/README.md | 11 +-
 doctests/requirements.txt | 5 +
 doctests/search_quickstart.py | 12 +-
 doctests/search_vss.py | 308 +++
 pytest.ini | 13 +
 redis/__init__.py | 5 +-
 redis/_parsers/__init__.py | 20 +
 redis/_parsers/base.py | 232 +++
 .../parser.py => _parsers/commands.py} | 201 +-
 redis/_parsers/encoders.py | 44 +
 redis/_parsers/helpers.py | 852 +++++++++
 redis/_parsers/hiredis.py | 217 +++
 redis/_parsers/resp2.py | 132 ++
 redis/_parsers/resp3.py | 261 +++
 redis/_parsers/socket.py | 162 ++
 redis/asyncio/__init__.py | 4 +-
 redis/asyncio/client.py | 104 +-
 redis/asyncio/cluster.py | 64 +-
 redis/asyncio/connection.py | 831 +++-----
 redis/asyncio/parser.py | 94 -
 redis/asyncio/sentinel.py | 17 +-
 redis/client.py | 924 ++------
 redis/cluster.py | 190 +-
 redis/commands/__init__.py | 2 -
 redis/commands/bf/__init__.py | 85 +-
 redis/commands/bf/commands.py | 3 -
 redis/commands/bf/info.py | 33 +
 redis/commands/cluster.py | 19 +-
 redis/commands/core.py | 342 +++-
 redis/commands/helpers.py | 8 +
 redis/commands/json/__init__.py | 44 +-
 redis/commands/json/commands.py | 56 +-
 redis/commands/search/__init__.py | 27 +-
 redis/commands/search/commands.py | 219 +--
 redis/commands/search/document.py | 4 +
 redis/commands/timeseries/__init__.py | 38 +-
 redis/commands/timeseries/info.py | 9 +
 redis/compat.py | 7 +-
 redis/connection.py | 990 ++++------
 redis/exceptions.py | 13 +
 redis/ocsp.py | 1 -
 redis/py.typed | 0
 redis/sentinel.py | 108 +-
 redis/typing.py | 21 +-
 redis/utils.py | 37 +
 setup.py | 8 +-
 tasks.py | 66 +-
 tests/conftest.py | 95 +-
 tests/ssl_utils.py | 14 +
 tests/test_asyncio/conftest.py | 52 +-
 tests/test_asyncio/test_bloom.py | 514 ++---
 tests/test_asyncio/test_cluster.py | 503 +++--
 tests/test_asyncio/test_commands.py | 673 +++++--
 tests/test_asyncio/test_connect.py | 144 ++
 tests/test_asyncio/test_connection.py | 211 +-
 tests/test_asyncio/test_connection_pool.py | 30 +-
 tests/test_asyncio/test_credentials.py | 1 -
 tests/test_asyncio/test_cwe_404.py | 250 +++
 tests/test_asyncio/test_encoding.py | 2 +-
 tests/test_asyncio/test_graph.py | 132 +-
 tests/test_asyncio/test_json.py | 752 ++++----
 tests/test_asyncio/test_lock.py | 1 -
 tests/test_asyncio/test_monitor.py | 1 -
 tests/test_asyncio/test_pipeline.py | 3 -
 tests/test_asyncio/test_pubsub.py | 66 +-
 tests/test_asyncio/test_retry.py | 1 -
 tests/test_asyncio/test_scripting.py | 1 -
 tests/test_asyncio/test_search.py | 1562 ++++++++++-----
 tests/test_asyncio/test_sentinel.py | 3 +-
 .../test_sentinel_managed_connection.py | 1 -
 tests/test_asyncio/test_timeseries.py | 870 +++++----
 tests/test_bloom.py | 131 +-
 tests/test_cluster.py | 566 +++++-
 tests/test_command_parser.py | 50 +-
 tests/test_commands.py | 916 ++++++---
 tests/test_connect.py | 191 ++
 tests/test_connection.py | 70 +-
 tests/test_connection_pool.py | 15 +-
 tests/test_credentials.py | 7 +-
 tests/test_encoding.py | 6 +-
 tests/test_function.py | 46 +-
 tests/test_graph.py | 34 +-
 tests/test_graph_utils/test_edge.py | 1 -
 tests/test_graph_utils/test_node.py | 1 -
 tests/test_graph_utils/test_path.py | 1 -
 tests/test_json.py | 410 ++--
 tests/test_lock.py | 7 +-
 tests/test_multiprocessing.py | 1 -
 tests/test_pipeline.py | 3 -
 tests/test_pubsub.py | 424 ++++-
 tests/test_retry.py | 1 -
 tests/test_scripting.py | 1 -
 tests/test_search.py | 1695 +++++++++++------
 tests/test_sentinel.py | 10 +-
 tests/test_ssl.py | 16 +-
 tests/test_timeseries.py | 628 ++++--
 tox.ini | 379 ----
 whitelist.py | 1 -
 182 files changed, 14214 insertions(+), 7385 deletions(-)
 create mode 100644 .flake8
 create mode 100644 .github/dependabot.yml
 create mode 100644 .github/spellcheck-settings.yml
 create mode 100644 .github/wordlist.txt
 create mode 100644 .github/workflows/docs.yaml
 create mode 100644 .github/workflows/spellcheck.yml
 create mode 100644 .isort.cfg
 create mode 100644 docker-compose.yml
 delete mode 100644 docker/base/Dockerfile
 delete mode 100644 docker/base/Dockerfile.cluster
 delete mode 100644 docker/base/Dockerfile.cluster4
 delete mode 100644 docker/base/Dockerfile.cluster5
 delete mode 100644 docker/base/Dockerfile.redis4
 delete mode 100644 docker/base/Dockerfile.redis5
 delete mode 100644 docker/base/Dockerfile.redismod_cluster
 delete mode 100644 docker/base/Dockerfile.sentinel
 delete mode 100644 docker/base/Dockerfile.sentinel4
 delete mode 100644 docker/base/Dockerfile.sentinel5
 delete mode 100644 docker/base/Dockerfile.stunnel
 delete mode 100644 docker/base/Dockerfile.unstable
 delete mode 100644 docker/base/Dockerfile.unstable_cluster
 delete mode 100644 docker/base/Dockerfile.unstable_sentinel
 delete mode 100644 docker/base/README.md
 delete mode 100755 docker/base/create_cluster4.sh
 delete mode 100755 docker/base/create_cluster5.sh
 delete mode 100755 docker/base/create_redismod_cluster.sh
 delete mode 100644 docker/cluster/redis.conf
 delete mode 100644 docker/redis4/master/redis.conf
 delete mode 100644 docker/redis4/sentinel/sentinel_1.conf
 delete mode 100644 docker/redis4/sentinel/sentinel_2.conf
 delete mode 100644 docker/redis4/sentinel/sentinel_3.conf
 delete mode 100644 docker/redis5/master/redis.conf
 delete mode 100644 docker/redis5/replica/redis.conf
 delete mode 100644 docker/redis5/sentinel/sentinel_1.conf
 delete mode 100644 docker/redis5/sentinel/sentinel_2.conf
 delete mode 100644 docker/redis5/sentinel/sentinel_3.conf
 delete mode 100644 docker/redis6.2/master/redis.conf
 delete mode 100644 docker/redis6.2/replica/redis.conf
 delete mode 100644 docker/redis6.2/sentinel/sentinel_2.conf
 delete mode 100644 docker/redis6.2/sentinel/sentinel_3.conf
 delete mode 100644 docker/redis7/master/redis.conf
 delete mode 100644 docker/redismod_cluster/redis.conf
 delete mode 100644 docker/unstable/redis.conf
 delete mode 100644 docker/unstable_cluster/redis.conf
 create mode 100644
dockers/Dockerfile.cluster create mode 100644 dockers/cluster.redis.conf rename {docker/base => dockers}/create_cluster.sh (75%) mode change 100755 => 100644 rename docker/redis6.2/sentinel/sentinel_1.conf => dockers/sentinel.conf (73%) rename {docker => dockers}/stunnel/README (100%) rename {docker => dockers}/stunnel/conf/redis.conf (83%) rename {docker => dockers}/stunnel/create_certs.sh (100%) rename {docker => dockers}/stunnel/keys/ca-cert.pem (100%) rename {docker => dockers}/stunnel/keys/ca-key.pem (100%) rename {docker => dockers}/stunnel/keys/client-cert.pem (100%) rename {docker => dockers}/stunnel/keys/client-key.pem (100%) rename {docker => dockers}/stunnel/keys/client-req.pem (100%) rename {docker => dockers}/stunnel/keys/server-cert.pem (100%) rename {docker => dockers}/stunnel/keys/server-key.pem (100%) rename {docker => dockers}/stunnel/keys/server-req.pem (100%) create mode 100644 docs/resp3_features.rst create mode 100644 doctests/requirements.txt create mode 100644 doctests/search_vss.py create mode 100644 pytest.ini create mode 100644 redis/_parsers/__init__.py create mode 100644 redis/_parsers/base.py rename redis/{commands/parser.py => _parsers/commands.py} (56%) create mode 100644 redis/_parsers/encoders.py create mode 100644 redis/_parsers/helpers.py create mode 100644 redis/_parsers/hiredis.py create mode 100644 redis/_parsers/resp2.py create mode 100644 redis/_parsers/resp3.py create mode 100644 redis/_parsers/socket.py delete mode 100644 redis/asyncio/parser.py mode change 100755 => 100644 redis/connection.py create mode 100644 redis/py.typed create mode 100644 tests/ssl_utils.py create mode 100644 tests/test_asyncio/test_connect.py create mode 100644 tests/test_asyncio/test_cwe_404.py create mode 100644 tests/test_connect.py delete mode 100644 tox.ini diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000000..6c663473e4 --- /dev/null +++ b/.flake8 @@ -0,0 +1,26 @@ +[flake8] +max-line-length = 88 +exclude = + *.egg-info, + *.pyc, + .git, + .tox, + .venv*, + build, + docs/*, + dist, + docker, + venv*, + .venv*, + whitelist.py, + tasks.py +ignore = + E126 + E203 + F405 + N801 + N802 + N803 + N806 + N815 + W503 diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000000..ac71d74297 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,8 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + labels: + - "maintenance" + schedule: + interval: "monthly" diff --git a/.github/spellcheck-settings.yml b/.github/spellcheck-settings.yml new file mode 100644 index 0000000000..96abbe6da8 --- /dev/null +++ b/.github/spellcheck-settings.yml @@ -0,0 +1,29 @@ +matrix: +- name: Markdown + expect_match: false + apsell: + lang: en + d: en_US + ignore-case: true + dictionary: + wordlists: + - .github/wordlist.txt + output: wordlist.dic + pipeline: + - pyspelling.filters.markdown: + markdown_extensions: + - markdown.extensions.extra: + - pyspelling.filters.html: + comments: false + attributes: + - alt + ignores: + - ':matches(code, pre)' + - code + - pre + - blockquote + - img + sources: + - '*.md' + - 'docs/*.rst' + - 'docs/*.ipynb' diff --git a/.github/wordlist.txt b/.github/wordlist.txt new file mode 100644 index 0000000000..be16c437ff --- /dev/null +++ b/.github/wordlist.txt @@ -0,0 +1,142 @@ +APM +ARGV +BFCommands +CFCommands +CMSCommands +ClusterNode +ClusterNodes +ClusterPipeline +ClusterPubSub +ConnectionPool +CoreCommands +EVAL +EVALSHA +GraphCommands +Grokzen's +INCR +IOError +Instrumentations 
+JSONCommands +Jaeger +Ludovico +Magnocavallo +McCurdy +NOSCRIPT +NUMPAT +NUMPT +NUMSUB +OSS +OpenCensus +OpenTelemetry +OpenTracing +Otel +PubSub +READONLY +RediSearch +RedisBloom +RedisCluster +RedisClusterCommands +RedisClusterException +RedisClusters +RedisGraph +RedisInstrumentor +RedisJSON +RedisTimeSeries +SHA +SearchCommands +SentinelCommands +SentinelConnectionPool +Sharded +Solovyov +SpanKind +Specfiying +StatusCode +TCP +TOPKCommands +TimeSeriesCommands +Uptrace +ValueError +WATCHed +WatchError +api +args +async +asyncio +autoclass +automodule +backoff +bdb +behaviour +bool +boolean +booleans +bysource +charset +del +dev +eg +exc +firsttimersonly +fo +genindex +gmail +hiredis +http +idx +iff +ini +json +keyslot +keyspace +kwarg +linters +localhost +lua +makeapullrequest +maxdepth +mget +microservice +microservices +mset +multikey +mykey +nonatomic +observability +opentelemetry +oss +performant +pmessage +png +pre +psubscribe +pubsub +punsubscribe +py +pypi +quickstart +readonly +readwrite +redis +redismodules +reinitialization +replicaof +repo +runtime +sedrik +sharded +ssl +str +stunnel +subcommands +thevalueofmykey +timeseries +toctree +topk +tox +triaging +txt +un +unicode +url +virtualenv +www diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml new file mode 100644 index 0000000000..61ec76e9f8 --- /dev/null +++ b/.github/workflows/docs.yaml @@ -0,0 +1,47 @@ +name: Docs CI + +on: + push: + branches: + - master + - '[0-9].[0-9]' + pull_request: + branches: + - master + - '[0-9].[0-9]' + schedule: + - cron: '0 1 * * *' # nightly build + +concurrency: + group: ${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +permissions: + contents: read # to fetch code (actions/checkout) + +jobs: + + build-docs: + name: Build docs + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: 3.9 + cache: 'pip' + - name: install deps + run: | + sudo apt-get update -yqq + sudo apt-get install -yqq pandoc make + - name: run code linters + run: | + pip install -r requirements.txt -r dev_requirements.txt -r docs/requirements.txt + invoke build-docs + + - name: upload docs + uses: actions/upload-artifact@v3 + with: + name: redis-py-docs + path: | + docs/_build/html diff --git a/.github/workflows/doctests.yml b/.github/workflows/doctests.yml index 3e2537f961..49254b0766 100644 --- a/.github/workflows/doctests.yml +++ b/.github/workflows/doctests.yml @@ -32,7 +32,8 @@ jobs: run: | pip install -r dev_requirements.txt pip install -r requirements.txt - isort --check-only --diff doctests/*.py + pip install -r doctests/requirements.txt + isort -l 80 --profile black --check-only --diff doctests/*.py black -l 80 --target-version py39 --check --diff doctests/*.py - name: run tests diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index 8d38cd45c7..7e0fea2e41 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -16,6 +16,10 @@ on: schedule: - cron: '0 1 * * *' # nightly build +concurrency: + group: ${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + permissions: contents: read # to fetch code (actions/checkout) @@ -26,7 +30,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - uses: pypa/gh-action-pip-audit@v1.0.0 + - uses: pypa/gh-action-pip-audit@v1.0.8 with: inputs: requirements.txt dev_requirements.txt ignore-vulns: | @@ -48,11 +52,12 @@ jobs: run-tests: runs-on: 
ubuntu-latest - timeout-minutes: 30 + timeout-minutes: 60 strategy: max-parallel: 15 + fail-fast: false matrix: - python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', 'pypy-3.7', 'pypy-3.8'] + python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', 'pypy-3.7', 'pypy-3.8', 'pypy-3.9'] test-type: ['standalone', 'cluster'] connection-type: ['hiredis', 'plain'] env: @@ -67,32 +72,77 @@ jobs: - name: run tests run: | pip install -U setuptools wheel + pip install -r requirements.txt pip install -r dev_requirements.txt - tox -e ${{matrix.test-type}}-${{matrix.connection-type}} - - uses: actions/upload-artifact@v2 + if [ "${{matrix.connection-type}}" == "hiredis" ]; then + pip install hiredis + fi + invoke devenv + sleep 10 # time to settle + invoke ${{matrix.test-type}}-tests + + - uses: actions/upload-artifact@v3 if: success() || failure() with: - name: pytest-results-${{matrix.test-type}} + name: pytest-results-${{matrix.test-type}}-${{matrix.connection-type}}-${{matrix.python-version}} path: '${{matrix.test-type}}*results.xml' + - name: Upload codecov coverage uses: codecov/codecov-action@v3 + if: ${{matrix.python-version == '3.11'}} with: fail_ci_if_error: false - # - name: View Test Results - # uses: dorny/test-reporter@v1 - # if: success() || failure() - # with: - # name: Test Results ${{matrix.python-version}} ${{matrix.test-type}}-${{matrix.connection-type}} - # path: '${{matrix.test-type}}*results.xml' - # reporter: java-junit - # list-suites: failed - # list-tests: failed - # max-annotations: 10 + + - name: View Test Results + uses: dorny/test-reporter@v1 + if: success() || failure() + continue-on-error: true + with: + name: Test Results ${{matrix.python-version}} ${{matrix.test-type}}-${{matrix.connection-type}} + path: '*.xml' + reporter: java-junit + list-suites: all + list-tests: all + max-annotations: 10 + fail-on-error: 'false' + + resp3_tests: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ['3.7', '3.11'] + test-type: ['standalone', 'cluster'] + connection-type: ['hiredis', 'plain'] + protocol: ['3'] + env: + ACTIONS_ALLOW_UNSECURE_COMMANDS: true + name: RESP3 [${{ matrix.python-version }} ${{matrix.test-type}}-${{matrix.connection-type}}] + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + - name: run tests + run: | + pip install -U setuptools wheel + pip install -r requirements.txt + pip install -r dev_requirements.txt + if [ "${{matrix.connection-type}}" == "hiredis" ]; then + pip install hiredis + fi + invoke devenv + sleep 5 # time to settle + invoke ${{matrix.test-type}}-tests + invoke ${{matrix.test-type}}-tests --uvloop build_and_test_package: name: Validate building and installing the package runs-on: ubuntu-latest + needs: [run-tests] strategy: + fail-fast: false matrix: extension: ['tar.gz', 'whl'] steps: @@ -108,8 +158,9 @@ jobs: name: Install package from commit hash runs-on: ubuntu-latest strategy: + fail-fast: false matrix: - python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', 'pypy-3.7'] + python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', 'pypy-3.7', 'pypy-3.8', 'pypy-3.9'] steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 diff --git a/.github/workflows/spellcheck.yml b/.github/workflows/spellcheck.yml new file mode 100644 index 0000000000..e152841553 --- /dev/null +++ b/.github/workflows/spellcheck.yml @@ -0,0 +1,14 @@ +name: spellcheck +on: + pull_request: +jobs: + check-spelling: + runs-on: ubuntu-latest + 
steps: + - name: Checkout + uses: actions/checkout@v3 + - name: Check Spelling + uses: rojopolis/spellcheck-github-actions@0.33.1 + with: + config_path: .github/spellcheck-settings.yml + task_name: Markdown diff --git a/.isort.cfg b/.isort.cfg new file mode 100644 index 0000000000..039f0337a2 --- /dev/null +++ b/.isort.cfg @@ -0,0 +1,5 @@ +[settings] +profile=black +multi_line_output=3 +src_paths = ["redis", "tests"] +skip_glob=benchmarks/* \ No newline at end of file diff --git a/CHANGES b/CHANGES index 02daf5ee4c..7b3b4c5ac2 100644 --- a/CHANGES +++ b/CHANGES @@ -1,3 +1,21 @@ + * Fix #2831, add auto_close_connection_pool=True arg to asyncio.Redis.from_url() + * Fix incorrect redis.asyncio.Cluster type hint for `retry_on_error` + * Fix dead weakref in sentinel connection causing ReferenceError (#2767) + * Fix #2768, Fix KeyError: 'first-entry' in parse_xinfo_stream. + * Fix #2749, remove unnecessary __del__ logic to close connections. + * Fix #2754, adding a missing argument to SentinelManagedConnection + * Fix `xadd` command to accept non-negative `maxlen` including 0 + * Revert #2104, #2673, add `disconnect_on_error` option to `read_response()` (issues #2506, #2624) + * Add `address_remap` parameter to `RedisCluster` + * Fix incorrect usage of once flag in async Sentinel + * asyncio: Fix memory leak caused by hiredis (#2693) + * Allow data to drain from async PythonParser when reading during a disconnect() + * Use asyncio.timeout() instead of async_timeout.timeout() for python >= 3.11 (#2602) + * Add a Dependabot configuration to auto-update GitHub action versions. + * Add test and fix async HiredisParser when reading during a disconnect() (#2349) + * Use hiredis-py pack_command if available. + * Support `.unlink()` in ClusterPipeline + * Simplify synchronous SocketBuffer state management * Fix string cleanse in Redis Graph * Make PythonParser resumable in case of error (#2510) * Add `timeout=None` in `SentinelConnectionManager.read_response` @@ -34,323 +52,325 @@ * Fix Sentinel.execute_command doesn't execute across the entire sentinel cluster bug (#2458) * Added a replacement for the default cluster node in the event of failure (#2463) * Fix for Unhandled exception related to self.host with unix socket (#2496) + * Improve error output for master discovery + * Make `ClusterCommandsProtocol` an actual Protocol * 4.1.3 (Feb 8, 2022) - * Fix flushdb and flushall (#1926) - * Add redis5 and redis4 dockers (#1871) - * Change json.clear test multi to be up to date with redisjson (#1922) - * Fixing volume for unstable_cluster docker (#1914) - * Update changes file with changes since 4.0.0-beta2 (#1915) + * Fix flushdb and flushall (#1926) + * Add redis5 and redis4 dockers (#1871) + * Change json.clear test multi to be up to date with redisjson (#1922) + * Fixing volume for unstable_cluster docker (#1914) + * Update changes file with changes since 4.0.0-beta2 (#1915) * 4.1.2 (Jan 27, 2022) - * Invalid OCSP certificates should raise ConnectionError on failed validation (#1907) - * Added retry mechanism on socket timeouts when connecting to the server (#1895) - * LMOVE, BLMOVE return incorrect responses (#1906) - * Fixing AttributeError in UnixDomainSocketConnection (#1903) - * Fixing TypeError in GraphCommands.explain (#1901) - * For tests, increasing wait time for the cluster (#1908) - * Increased pubsub's wait_for_messages timeout to prevent flaky tests (#1893) - * README code snippets formatted to highlight properly (#1888) - * Fix link in the main page (#1897) - * Documentation fixes: JSON 
Example, SSL Connection Examples, RTD version (#1887) - * Direct link to readthedocs (#1885) + * Invalid OCSP certificates should raise ConnectionError on failed validation (#1907) + * Added retry mechanism on socket timeouts when connecting to the server (#1895) + * LMOVE, BLMOVE return incorrect responses (#1906) + * Fixing AttributeError in UnixDomainSocketConnection (#1903) + * Fixing TypeError in GraphCommands.explain (#1901) + * For tests, increasing wait time for the cluster (#1908) + * Increased pubsub's wait_for_messages timeout to prevent flaky tests (#1893) + * README code snippets formatted to highlight properly (#1888) + * Fix link in the main page (#1897) + * Documentation fixes: JSON Example, SSL Connection Examples, RTD version (#1887) + * Direct link to readthedocs (#1885) * 4.1.1 (Jan 17, 2022) - * Add retries to connections in Sentinel Pools (#1879) - * OCSP Stapling Support (#1873) - * Define incr/decr as aliases of incrby/decrby (#1874) - * FT.CREATE - support MAXTEXTFIELDS, TEMPORARY, NOHL, NOFREQS, SKIPINITIALSCAN (#1847) - * Timeseries docs fix (#1877) - * get_connection: catch OSError too (#1832) - * Set keys var otherwise variable not created (#1853) - * Clusters should optionally require full slot coverage (#1845) - * Triple quote docstrings in client.py PEP 257 (#1876) - * syncing requirements (#1870) - * Typo and typing in GraphCommands documentation (#1855) - * Allowing poetry and redis-py to install together (#1854) - * setup.py: Add project_urls for PyPI (#1867) - * Support test with redis unstable docker (#1850) - * Connection examples (#1835) - * Documentation cleanup (#1841) + * Add retries to connections in Sentinel Pools (#1879) + * OCSP Stapling Support (#1873) + * Define incr/decr as aliases of incrby/decrby (#1874) + * FT.CREATE - support MAXTEXTFIELDS, TEMPORARY, NOHL, NOFREQS, SKIPINITIALSCAN (#1847) + * Timeseries docs fix (#1877) + * get_connection: catch OSError too (#1832) + * Set keys var otherwise variable not created (#1853) + * Clusters should optionally require full slot coverage (#1845) + * Triple quote docstrings in client.py PEP 257 (#1876) + * syncing requirements (#1870) + * Typo and typing in GraphCommands documentation (#1855) + * Allowing poetry and redis-py to install together (#1854) + * setup.py: Add project_urls for PyPI (#1867) + * Support test with redis unstable docker (#1850) + * Connection examples (#1835) + * Documentation cleanup (#1841) * 4.1.0 (Dec 26, 2021) - * OCSP stapling support (#1820) - * Support for SELECT (#1825) - * Support for specifying error types with retry (#1817) - * Support for RESET command since Redis 6.2.0 (#1824) - * Support CLIENT TRACKING (#1612) - * Support WRITE in CLIENT PAUSE (#1549) - * JSON set_file and set_path support (#1818) - * Allow ssl_ca_path with rediss:// urls (#1814) - * Support for password-encrypted SSL private keys (#1782) - * Support SYNC and PSYNC (#1741) - * Retry on error exception and timeout fixes (#1821) - * Fixing read race condition during pubsub (#1737) - * Fixing exception in listen (#1823) - * Fixed MovedError, and stopped iterating through startup nodes when slots are fully covered (#1819) - * Socket not closing after server disconnect (#1797) - * Single sourcing the package version (#1791) - * Ensure redis_connect_func is set on uds connection (#1794) - * SRTALGO - Skip for redis versions greater than 7.0.0 (#1831) - * Documentation updates (#1822) - * Add CI action to install package from repository commit hash (#1781) (#1790) - * Fix link in lmove docstring 
(#1793) - * Disabling JSON.DEBUG tests (#1787) - * Migrated targeted nodes to kwargs in Cluster Mode (#1762) - * Added support for MONITOR in clusters (#1756) - * Adding ROLE Command (#1610) - * Integrate RedisBloom support (#1683) - * Adding RedisGraph support (#1556) - * Allow overriding connection class via keyword arguments (#1752) - * Aggregation LOAD * support for RediSearch (#1735) - * Adding cluster, bloom, and graph docs (#1779) - * Add packaging to setup_requires, and use >= to play nice to setup.py (fixes #1625) (#1780) - * Fixing the license link in the readme (#1778) - * Removing distutils from tests (#1773) - * Fix cluster ACL tests (#1774) - * Improved RedisCluster's reinitialize_steps and documentation (#1765) - * Added black and isort (#1734) - * Link Documents for all module commands (#1711) - * Pyupgrade + flynt + f-strings (#1759) - * Remove unused aggregation subclasses in RediSearch (#1754) - * Adding RedisCluster client to support Redis Cluster Mode (#1660) - * Support RediSearch FT.PROFILE command (#1727) - * Adding support for non-decodable commands (#1731) - * COMMAND GETKEYS support (#1738) - * RedisJSON 2.0.4 behaviour support (#1747) - * Removing deprecating distutils (PEP 632) (#1730) - * Updating PR template (#1745) - * Removing duplication of Script class (#1751) - * Splitting documentation for read the docs (#1743) - * Improve code coverage for aggregation tests (#1713) - * Fixing COMMAND GETKEYS tests (#1750) - * GitHub release improvements (#1684) + * OCSP stapling support (#1820) + * Support for SELECT (#1825) + * Support for specifying error types with retry (#1817) + * Support for RESET command since Redis 6.2.0 (#1824) + * Support CLIENT TRACKING (#1612) + * Support WRITE in CLIENT PAUSE (#1549) + * JSON set_file and set_path support (#1818) + * Allow ssl_ca_path with rediss:// urls (#1814) + * Support for password-encrypted SSL private keys (#1782) + * Support SYNC and PSYNC (#1741) + * Retry on error exception and timeout fixes (#1821) + * Fixing read race condition during pubsub (#1737) + * Fixing exception in listen (#1823) + * Fixed MovedError, and stopped iterating through startup nodes when slots are fully covered (#1819) + * Socket not closing after server disconnect (#1797) + * Single sourcing the package version (#1791) + * Ensure redis_connect_func is set on uds connection (#1794) + * SRTALGO - Skip for redis versions greater than 7.0.0 (#1831) + * Documentation updates (#1822) + * Add CI action to install package from repository commit hash (#1781) (#1790) + * Fix link in lmove docstring (#1793) + * Disabling JSON.DEBUG tests (#1787) + * Migrated targeted nodes to kwargs in Cluster Mode (#1762) + * Added support for MONITOR in clusters (#1756) + * Adding ROLE Command (#1610) + * Integrate RedisBloom support (#1683) + * Adding RedisGraph support (#1556) + * Allow overriding connection class via keyword arguments (#1752) + * Aggregation LOAD * support for RediSearch (#1735) + * Adding cluster, bloom, and graph docs (#1779) + * Add packaging to setup_requires, and use >= to play nice to setup.py (fixes #1625) (#1780) + * Fixing the license link in the readme (#1778) + * Removing distutils from tests (#1773) + * Fix cluster ACL tests (#1774) + * Improved RedisCluster's reinitialize_steps and documentation (#1765) + * Added black and isort (#1734) + * Link Documents for all module commands (#1711) + * Pyupgrade + flynt + f-strings (#1759) + * Remove unused aggregation subclasses in RediSearch (#1754) + * Adding RedisCluster client to support 
Redis Cluster Mode (#1660) + * Support RediSearch FT.PROFILE command (#1727) + * Adding support for non-decodable commands (#1731) + * COMMAND GETKEYS support (#1738) + * RedisJSON 2.0.4 behaviour support (#1747) + * Removing deprecating distutils (PEP 632) (#1730) + * Updating PR template (#1745) + * Removing duplication of Script class (#1751) + * Splitting documentation for read the docs (#1743) + * Improve code coverage for aggregation tests (#1713) + * Fixing COMMAND GETKEYS tests (#1750) + * GitHub release improvements (#1684) * 4.0.2 (Nov 22, 2021) - * Restoring Sentinel commands to redis client (#1723) - * Better removal of hiredis warning (#1726) - * Adding links to redis documents in function calls (#1719) + * Restoring Sentinel commands to redis client (#1723) + * Better removal of hiredis warning (#1726) + * Adding links to redis documents in function calls (#1719) * 4.0.1 (Nov 17, 2021) - * Removing command on initial connections (#1722) - * Removing hiredis warning when not installed (#1721) + * Removing command on initial connections (#1722) + * Removing hiredis warning when not installed (#1721) * 4.0.0 (Nov 15, 2021) - * FT.EXPLAINCLI intentionally raising NotImplementedError - * Restoring ZRANGE desc for Redis < 6.2.0 (#1697) - * Response parsing occasionally fails to parse floats (#1692) - * Re-enabling read-the-docs (#1707) - * Call HSET after FT.CREATE to avoid keyspace scan (#1706) - * Unit tests fixes for compatibility (#1703) - * Improve documentation about Locks (#1701) - * Fixes to allow --redis-url to pass through all tests (#1700) - * Fix unit tests running against Redis 4.0.0 (#1699) - * Search alias test fix (#1695) - * Adding RediSearch/RedisJSON tests (#1691) - * Updating codecov rules (#1689) - * Tests to validate custom JSON decoders (#1681) - * Added breaking icon to release drafter (#1702) - * Removing dependency on six (#1676) - * Re-enable pipeline support for JSON and TimeSeries (#1674) - * Export Sentinel, and SSL like other classes (#1671) - * Restore zrange functionality for older versions of Redis (#1670) - * Fixed garbage collection deadlock (#1578) - * Tests to validate built python packages (#1678) - * Sleep for flaky search test (#1680) - * Test function renames, to match standards (#1679) - * Docstring improvements for Redis class (#1675) - * Fix georadius tests (#1672) - * Improvements to JSON coverage (#1666) - * Add python_requires setuptools check for python > 3.6 (#1656) - * SMISMEMBER support (#1667) - * Exposing the module version in loaded_modules (#1648) - * RedisTimeSeries support (#1652) - * Support for json multipath ($) (#1663) - * Added boolean parsing to PEXPIRE and PEXPIREAT (#1665) - * Add python_requires setuptools check for python > 3.6 (#1656) - * Adding vulture for static analysis (#1655) - * Starting to clean the docs (#1657) - * Update README.md (#1654) - * Adding description format for package (#1651) - * Publish to pypi as releases are generated with the release drafter (#1647) - * Restore actions to prs (#1653) - * Fixing the package to include commands (#1649) - * Re-enabling codecov as part of CI process (#1646) - * Adding support for redisearch (#1640) Thanks @chayim - * redisjson support (#1636) Thanks @chayim - * Sentinel: Add SentinelManagedSSLConnection (#1419) Thanks @AbdealiJK - * Enable floating parameters in SET (ex and px) (#1635) Thanks @AvitalFineRedis - * Add warning when hiredis not installed. Recommend installation. 
(#1621) Thanks @adiamzn - * Raising NotImplementedError for SCRIPT DEBUG and DEBUG SEGFAULT (#1624) Thanks @chayim - * CLIENT REDIR command support (#1623) Thanks @chayim - * REPLICAOF command implementation (#1622) Thanks @chayim - * Add support to NX XX and CH to GEOADD (#1605) Thanks @AvitalFineRedis - * Add support to ZRANGE and ZRANGESTORE parameters (#1603) Thanks @AvitalFineRedis - * Pre 6.2 redis should default to None for script flush (#1641) Thanks @chayim - * Add FULL option to XINFO SUMMARY (#1638) Thanks @agusdmb - * Geosearch test should use any=True (#1594) Thanks @Andrew-Chen-Wang - * Removing packaging dependency (#1626) Thanks @chayim - * Fix client_kill_filter docs for skimpy (#1596) Thanks @Andrew-Chen-Wang - * Normalize minid and maxlen docs (#1593) Thanks @Andrew-Chen-Wang - * Update docs for multiple usernames for ACL DELUSER (#1595) Thanks @Andrew-Chen-Wang - * Fix grammar of get param in set command (#1588) Thanks @Andrew-Chen-Wang - * Fix docs for client_kill_filter (#1584) Thanks @Andrew-Chen-Wang - * Convert README & CONTRIBUTING from rst to md (#1633) Thanks @davidylee - * Test BYLEX param in zrangestore (#1634) Thanks @AvitalFineRedis - * Tox integrations with invoke and docker (#1632) Thanks @chayim - * Adding the release drafter to help simplify release notes (#1618). Thanks @chayim - * BACKWARDS INCOMPATIBLE: Removed support for end of life Python 2.7. #1318 - * BACKWARDS INCOMPATIBLE: All values within Redis URLs are unquoted via + * FT.EXPLAINCLI intentionally raising NotImplementedError + * Restoring ZRANGE desc for Redis < 6.2.0 (#1697) + * Response parsing occasionally fails to parse floats (#1692) + * Re-enabling read-the-docs (#1707) + * Call HSET after FT.CREATE to avoid keyspace scan (#1706) + * Unit tests fixes for compatibility (#1703) + * Improve documentation about Locks (#1701) + * Fixes to allow --redis-url to pass through all tests (#1700) + * Fix unit tests running against Redis 4.0.0 (#1699) + * Search alias test fix (#1695) + * Adding RediSearch/RedisJSON tests (#1691) + * Updating codecov rules (#1689) + * Tests to validate custom JSON decoders (#1681) + * Added breaking icon to release drafter (#1702) + * Removing dependency on six (#1676) + * Re-enable pipeline support for JSON and TimeSeries (#1674) + * Export Sentinel, and SSL like other classes (#1671) + * Restore zrange functionality for older versions of Redis (#1670) + * Fixed garbage collection deadlock (#1578) + * Tests to validate built python packages (#1678) + * Sleep for flaky search test (#1680) + * Test function renames, to match standards (#1679) + * Docstring improvements for Redis class (#1675) + * Fix georadius tests (#1672) + * Improvements to JSON coverage (#1666) + * Add python_requires setuptools check for python > 3.6 (#1656) + * SMISMEMBER support (#1667) + * Exposing the module version in loaded_modules (#1648) + * RedisTimeSeries support (#1652) + * Support for json multipath ($) (#1663) + * Added boolean parsing to PEXPIRE and PEXPIREAT (#1665) + * Add python_requires setuptools check for python > 3.6 (#1656) + * Adding vulture for static analysis (#1655) + * Starting to clean the docs (#1657) + * Update README.md (#1654) + * Adding description format for package (#1651) + * Publish to pypi as releases are generated with the release drafter (#1647) + * Restore actions to prs (#1653) + * Fixing the package to include commands (#1649) + * Re-enabling codecov as part of CI process (#1646) + * Adding support for redisearch (#1640) Thanks @chayim + * redisjson 
support (#1636) Thanks @chayim + * Sentinel: Add SentinelManagedSSLConnection (#1419) Thanks @AbdealiJK + * Enable floating parameters in SET (ex and px) (#1635) Thanks @AvitalFineRedis + * Add warning when hiredis not installed. Recommend installation. (#1621) Thanks @adiamzn + * Raising NotImplementedError for SCRIPT DEBUG and DEBUG SEGFAULT (#1624) Thanks @chayim + * CLIENT REDIR command support (#1623) Thanks @chayim + * REPLICAOF command implementation (#1622) Thanks @chayim + * Add support to NX XX and CH to GEOADD (#1605) Thanks @AvitalFineRedis + * Add support to ZRANGE and ZRANGESTORE parameters (#1603) Thanks @AvitalFineRedis + * Pre 6.2 redis should default to None for script flush (#1641) Thanks @chayim + * Add FULL option to XINFO SUMMARY (#1638) Thanks @agusdmb + * Geosearch test should use any=True (#1594) Thanks @Andrew-Chen-Wang + * Removing packaging dependency (#1626) Thanks @chayim + * Fix client_kill_filter docs for skimpy (#1596) Thanks @Andrew-Chen-Wang + * Normalize minid and maxlen docs (#1593) Thanks @Andrew-Chen-Wang + * Update docs for multiple usernames for ACL DELUSER (#1595) Thanks @Andrew-Chen-Wang + * Fix grammar of get param in set command (#1588) Thanks @Andrew-Chen-Wang + * Fix docs for client_kill_filter (#1584) Thanks @Andrew-Chen-Wang + * Convert README & CONTRIBUTING from rst to md (#1633) Thanks @davidylee + * Test BYLEX param in zrangestore (#1634) Thanks @AvitalFineRedis + * Tox integrations with invoke and docker (#1632) Thanks @chayim + * Adding the release drafter to help simplify release notes (#1618). Thanks @chayim + * BACKWARDS INCOMPATIBLE: Removed support for end of life Python 2.7. #1318 + * BACKWARDS INCOMPATIBLE: All values within Redis URLs are unquoted via urllib.parse.unquote. Prior versions of redis-py supported this by specifying the ``decode_components`` flag to the ``from_url`` functions. This is now done by default and cannot be disabled. #589 - * POTENTIALLY INCOMPATIBLE: Redis commands were moved into a mixin + * POTENTIALLY INCOMPATIBLE: Redis commands were moved into a mixin (see commands.py). Anyone importing ``redis.client`` to access commands directly should import ``redis.commands``. #1534, #1550 - * Removed technical debt on REDIS_6_VERSION placeholder. Thanks @chayim #1582. - * Various docs fixes. Thanks @Andrew-Chen-Wang #1585, #1586. - * Support for LOLWUT command, available since Redis 5.0.0. + * Removed technical debt on REDIS_6_VERSION placeholder. Thanks @chayim #1582. + * Various docs fixes. Thanks @Andrew-Chen-Wang #1585, #1586. + * Support for LOLWUT command, available since Redis 5.0.0. Thanks @brainix #1568. - * Added support for CLIENT REPLY, available in Redis 3.2.0. + * Added support for CLIENT REPLY, available in Redis 3.2.0. Thanks @chayim #1581. - * Support for Auto-reconnect PubSub on get_message. Thanks @luhn #1574. - * Fix RST syntax error in README. Thanks @JanCBrammer #1451. - * IDLETIME and FREQ support for RESTORE. Thanks @chayim #1580. - * Supporting args with MODULE LOAD. Thanks @chayim #1579. - * Updating RedisLabs with Redis. Thanks @gkorland #1575. - * Added support for ASYNC to SCRIPT FLUSH available in Redis 6.2.0. + * Support for Auto-reconnect PubSub on get_message. Thanks @luhn #1574. + * Fix RST syntax error in README. Thanks @JanCBrammer #1451. + * IDLETIME and FREQ support for RESTORE. Thanks @chayim #1580. + * Supporting args with MODULE LOAD. Thanks @chayim #1579. + * Updating RedisLabs with Redis. Thanks @gkorland #1575. 
+ * Added support for ASYNC to SCRIPT FLUSH available in Redis 6.2.0. Thanks @chayim. #1567 - * Added CLIENT LIST fix to support multiple client ids available in + * Added CLIENT LIST fix to support multiple client ids available in Redis 2.8.12. Thanks @chayim #1563. - * Added DISCARD support for pipelines available in Redis 2.0.0. + * Added DISCARD support for pipelines available in Redis 2.0.0. Thanks @chayim #1565. - * Added ACL DELUSER support for deleting lists of users available in + * Added ACL DELUSER support for deleting lists of users available in Redis 6.2.0. Thanks @chayim. #1562 - * Added CLIENT TRACKINFO support available in Redis 6.2.0. + * Added CLIENT TRACKINFO support available in Redis 6.2.0. Thanks @chayim. #1560 - * Added GEOSEARCH and GEOSEARCHSTORE support available in Redis 6.2.0. + * Added GEOSEARCH and GEOSEARCHSTORE support available in Redis 6.2.0. Thanks @AvitalFineRedis. #1526 - * Added LPUSHX support for lists available in Redis 4.0.0. + * Added LPUSHX support for lists available in Redis 4.0.0. Thanks @chayim. #1559 - * Added support for QUIT available in Redis 1.0.0. + * Added support for QUIT available in Redis 1.0.0. Thanks @chayim. #1558 - * Added support for COMMAND COUNT available in Redis 2.8.13. + * Added support for COMMAND COUNT available in Redis 2.8.13. Thanks @chayim. #1554. - * Added CREATECONSUMER support for XGROUP available in Redis 6.2.0. + * Added CREATECONSUMER support for XGROUP available in Redis 6.2.0. Thanks @AvitalFineRedis. #1553 - * Including slowlog complexity info if available. + * Including slowlog complexity info if available. Thanks @ian28223 #1489. - * Added support for STRALGO available in Redis 6.0.0. + * Added support for STRALGO available in Redis 6.0.0. Thanks @AvitalFineRedis. #1528 - * Added support for ZMSCORE available in Redis 6.2.0. + * Added support for ZMSCORE available in Redis 6.2.0. Thanks @2014BDuck and @jiekun.zhu. #1437 - * Support MINID and LIMIT on XADD available in Redis 6.2.0. + * Support MINID and LIMIT on XADD available in Redis 6.2.0. Thanks @AvitalFineRedis. #1548 - * Added sentinel commands FLUSHCONFIG, CKQUORUM, FAILOVER, and RESET + * Added sentinel commands FLUSHCONFIG, CKQUORUM, FAILOVER, and RESET available in Redis 2.8.12. Thanks @otherpirate. #834 - * Migrated Version instead of StrictVersion for Python 3.10. + * Migrated Version instead of StrictVersion for Python 3.10. Thanks @tirkarthi. #1552 - * Added retry mechanism with backoff. Thanks @nbraun-amazon. #1494 - * Migrated commands to a mixin. Thanks @chayim. #1534 - * Added support for ZUNION, available in Redis 6.2.0. Thanks + * Added retry mechanism with backoff. Thanks @nbraun-amazon. #1494 + * Migrated commands to a mixin. Thanks @chayim. #1534 + * Added support for ZUNION, available in Redis 6.2.0. Thanks @AvitalFineRedis. #1522 - * Added support for CLIENT LIST with ID, available in Redis 6.2.0. + * Added support for CLIENT LIST with ID, available in Redis 6.2.0. Thanks @chayim. #1505 - * Added support for MINID and LIMIT with xtrim, available in Redis 6.2.0. + * Added support for MINID and LIMIT with xtrim, available in Redis 6.2.0. Thanks @chayim. #1508 - * Implemented LMOVE and BLMOVE commands, available in Redis 6.2.0. + * Implemented LMOVE and BLMOVE commands, available in Redis 6.2.0. Thanks @chayim. #1504 - * Added GET argument to SET command, available in Redis 6.2.0. + * Added GET argument to SET command, available in Redis 6.2.0. Thanks @2014BDuck. #1412 - * Documentation fixes. Thanks @enjoy-binbin @jonher937. 
#1496 #1532 - * Added support for XAUTOCLAIM, available in Redis 6.2.0. + * Documentation fixes. Thanks @enjoy-binbin @jonher937. #1496 #1532 + * Added support for XAUTOCLAIM, available in Redis 6.2.0. Thanks @AvitalFineRedis. #1529 - * Added IDLE support for XPENDING, available in Redis 6.2.0. + * Added IDLE support for XPENDING, available in Redis 6.2.0. Thanks @AvitalFineRedis. #1523 - * Add a count parameter to lpop/rpop, available in Redis 6.2.0. + * Add a count parameter to lpop/rpop, available in Redis 6.2.0. Thanks @wavenator. #1487 - * Added a (pypy) trove classifier for Python 3.9. + * Added a (pypy) trove classifier for Python 3.9. Thanks @D3X. #1535 - * Added ZINTER support, available in Redis 6.2.0. + * Added ZINTER support, available in Redis 6.2.0. Thanks @AvitalFineRedis. #1520 - * Added ZINTER support, available in Redis 6.2.0. + * Added ZINTER support, available in Redis 6.2.0. Thanks @AvitalFineRedis. #1520 - * Added ZDIFF and ZDIFFSTORE support, available in Redis 6.2.0. + * Added ZDIFF and ZDIFFSTORE support, available in Redis 6.2.0. Thanks @AvitalFineRedis. #1518 - * Added ZRANGESTORE support, available in Redis 6.2.0. + * Added ZRANGESTORE support, available in Redis 6.2.0. Thanks @AvitalFineRedis. #1521 - * Added LT and GT support for ZADD, available in Redis 6.2.0. + * Added LT and GT support for ZADD, available in Redis 6.2.0. Thanks @chayim. #1509 - * Added ZRANDMEMBER support, available in Redis 6.2.0. + * Added ZRANDMEMBER support, available in Redis 6.2.0. Thanks @AvitalFineRedis. #1519 - * Added GETDEL support, available in Redis 6.2.0. + * Added GETDEL support, available in Redis 6.2.0. Thanks @AvitalFineRedis. #1514 - * Added CLIENT KILL laddr filter, available in Redis 6.2.0. + * Added CLIENT KILL laddr filter, available in Redis 6.2.0. Thanks @chayim. #1506 - * Added CLIENT UNPAUSE, available in Redis 6.2.0. + * Added CLIENT UNPAUSE, available in Redis 6.2.0. Thanks @chayim. #1512 - * Added NOMKSTREAM support for XADD, available in Redis 6.2.0. + * Added NOMKSTREAM support for XADD, available in Redis 6.2.0. Thanks @chayim. #1507 - * Added HRANDFIELD support, available in Redis 6.2.0. + * Added HRANDFIELD support, available in Redis 6.2.0. Thanks @AvitalFineRedis. #1513 - * Added CLIENT INFO support, available in Redis 6.2.0. + * Added CLIENT INFO support, available in Redis 6.2.0. Thanks @AvitalFineRedis. #1517 - * Added GETEX support, available in Redis 6.2.0. + * Added GETEX support, available in Redis 6.2.0. Thanks @AvitalFineRedis. #1515 - * Added support for COPY command, available in Redis 6.2.0. + * Added support for COPY command, available in Redis 6.2.0. Thanks @malinaa96. #1492 - * Provide a development and testing environment via docker. Thanks + * Provide a development and testing environment via docker. Thanks @abrookins. #1365 - * Added support for the LPOS command available in Redis 6.0.6. Thanks + * Added support for the LPOS command available in Redis 6.0.6. Thanks @aparcar #1353/#1354 - * Added support for the ACL LOG command available in Redis 6. Thanks + * Added support for the ACL LOG command available in Redis 6. Thanks @2014BDuck. #1307 - * Added support for ABSTTL option of the RESTORE command available in + * Added support for ABSTTL option of the RESTORE command available in Redis 5.0. Thanks @charettes. #1423 * 3.5.3 (June 1, 2020) - * Restore try/except clauses to __del__ methods. These will be removed + * Restore try/except clauses to __del__ methods. 
These will be removed in 4.0 when more explicit resource management is enforced. #1339 - * Update the master_address when Sentinels promote a new master. #847 - * Update SentinelConnectionPool to not forcefully disconnect other in-use + * Update the master_address when Sentinels promote a new master. #847 + * Update SentinelConnectionPool to not forcefully disconnect other in-use connections which can negatively affect threaded applications. #1345 * 3.5.2 (May 14, 2020) - * Tune the locking in ConnectionPool.get_connection so that the lock is + * Tune the locking in ConnectionPool.get_connection so that the lock is not held while waiting for the socket to establish and validate the TCP connection. * 3.5.1 (May 9, 2020) - * Fix for HSET argument validation to allow any non-None key. Thanks + * Fix for HSET argument validation to allow any non-None key. Thanks @AleksMat, #1337, #1341 * 3.5.0 (April 29, 2020) - * Removed exception trapping from __del__ methods. redis-py objects that + * Removed exception trapping from __del__ methods. redis-py objects that hold various resources implement __del__ cleanup methods to release those resources when the object goes out of scope. This provides a fallback for when these objects aren't explicitly closed by user code. Prior to this change any errors encountered in closing these resources would be hidden from the user. Thanks @jdufresne. #1281 - * Expanded support for connection strings specifying a username connecting + * Expanded support for connection strings specifying a username connecting to pre-v6 servers. #1274 - * Optimized Lock's blocking_timeout and sleep. If the lock cannot be + * Optimized Lock's blocking_timeout and sleep. If the lock cannot be acquired and the sleep value would cause the loop to sleep beyond blocking_timeout, fail immediately. Thanks @clslgrnc. #1263 - * Added support for passing Python memoryviews to Redis command args that + * Added support for passing Python memoryviews to Redis command args that expect strings or bytes. The memoryview instance is sent directly to the socket such that there are zero copies made of the underlying data during command packing. Thanks @Cody-G. #1265, #1285 - * HSET command can now accept multiple pairs. HMSET has been marked as + * HSET command can now accept multiple pairs. HMSET has been marked as deprecated now. Thanks to @laixintao #1271 - * Don't manually DISCARD when encountering an ExecAbortError. + * Don't manually DISCARD when encountering an ExecAbortError. Thanks @nickgaya, #1300/#1301 - * Reset the watched state of pipelines after calling exec. This saves + * Reset the watched state of pipelines after calling exec. This saves a roundtrip to the server by not having to call UNWATCH within Pipeline.reset(). Thanks @nickgaya, #1299/#1302 - * Added the KEEPTTL option for the SET command. Thanks + * Added the KEEPTTL option for the SET command. Thanks @laixintao #1304/#1280 - * Added the MEMORY STATS command. #1268 - * Lock.extend() now has a new option, `replace_ttl`. When False (the + * Added the MEMORY STATS command. #1268 + * Lock.extend() now has a new option, `replace_ttl`. When False (the default), Lock.extend() adds the `additional_time` to the lock's existing TTL. When replace_ttl=True, the lock's existing TTL is replaced with the value of `additional_time`. (Usage sketches for the memoryview and `replace_ttl` changes appear at the end of this excerpt.) - * Add testing and support for PyPy. + * Add testing and support for PyPy. 
* 3.4.1 - * Move the username argument in the Redis and Connection classes to the + * Move the username argument in the Redis and Connection classes to the end of the argument list. This helps those poor souls that specify all their connection options as non-keyword arguments. #1276 - * Prior to ACL support, redis-py ignored the username component of + * Prior to ACL support, redis-py ignored the username component of Connection URLs. With ACL support, usernames are no longer ignored and are used to authenticate against an ACL rule. Some cloud vendors with managed Redis instances (like Heroku) provide connection URLs with a @@ -358,33 +378,33 @@ username to Redis servers < 6.0.0 results in an error. Attempt to detect this condition and retry the AUTH command with only the password such that authentication continues to work for these users. #1274 - * Removed the __eq__ hooks to Redis and ConnectionPool that were added + * Removed the __eq__ hooks to Redis and ConnectionPool that were added in 3.4.0. This ended up being a bad idea as two separate connection pools be considered equal yet manage a completely separate set of connections. * 3.4.0 - * Allow empty pipelines to be executed if there are WATCHed keys. + * Allow empty pipelines to be executed if there are WATCHed keys. This is a convenient way to test if any of the watched keys changed without actually running any other commands. Thanks @brianmaissy. #1233, #1234 - * Removed support for end of life Python 3.4. - * Added support for all ACL commands in Redis 6. Thanks @IAmATeaPot418 + * Removed support for end of life Python 3.4. + * Added support for all ACL commands in Redis 6. Thanks @IAmATeaPot418 for helping. - * Pipeline instances now always evaluate to True. Prior to this change, + * Pipeline instances now always evaluate to True. Prior to this change, pipeline instances relied on __len__ for boolean evaluation which meant that pipelines with no commands on the stack would be considered False. #994 - * Client instances and Connection pools now support a 'client_name' + * Client instances and Connection pools now support a 'client_name' argument. If supplied, all connections created will call CLIENT SETNAME as soon as the connection is opened. Thanks to @Habbie for supplying the basis of this change. #802 - * Added the 'ssl_check_hostname' argument to specify whether SSL + * Added the 'ssl_check_hostname' argument to specify whether SSL connections should require the server hostname to match the hostname specified in the SSL cert. By default 'ssl_check_hostname' is False for backwards compatibility. #1196 - * Slightly optimized command packing. Thanks @Deneby67. #1255 - * Added support for the TYPE argument to SCAN. Thanks @netocp. #1220 - * Better thread and fork safety in ConnectionPool and + * Slightly optimized command packing. Thanks @Deneby67. #1255 + * Added support for the TYPE argument to SCAN. Thanks @netocp. #1220 + * Better thread and fork safety in ConnectionPool and BlockingConnectionPool. Added better locking to synchronize critical sections rather than relying on CPython-specific implementation details relating to atomic operations. Adjusted how the pools identify and @@ -392,636 +412,636 @@ raised by child processes in the very unlikely chance that a deadlock is encountered. Thanks @gmbnomis, @mdellweg, @yht804421715. #1270, #1138, #1178, #906, #1262 - * Added __eq__ hooks to the Redis and ConnectionPool classes. + * Added __eq__ hooks to the Redis and ConnectionPool classes. Thanks @brainix. 
 * 3.3.11
-    * Further fix for the SSLError -> TimeoutError mapping to work
+    * Further fix for the SSLError -> TimeoutError mapping to work on obscure releases of Python 2.7.
 * 3.3.10
-    * Fixed a potential error handling bug for the SSLError -> TimeoutError
+    * Fixed a potential error handling bug for the SSLError -> TimeoutError mapping introduced in 3.3.9. Thanks @zbristow. #1224
 * 3.3.9
-    * Mapped Python 2.7 SSLError to TimeoutError where appropriate. Timeouts
+    * Mapped Python 2.7 SSLError to TimeoutError where appropriate. Timeouts should now consistently raise TimeoutErrors on Python 2.7 for both unsecured and secured connections. Thanks @zbristow. #1222
 * 3.3.8
-    * Fixed MONITOR parsing to properly parse IPv6 client addresses, unix
+    * Fixed MONITOR parsing to properly parse IPv6 client addresses, unix socket connections and commands issued from Lua. Thanks @kukey. #1201
 * 3.3.7
-    * Fixed a regression introduced in 3.3.0 where socket.error exceptions
+    * Fixed a regression introduced in 3.3.0 where socket.error exceptions (or subclasses) could potentially be raised instead of redis.exceptions.ConnectionError. #1202
 * 3.3.6
-    * Fixed a regression in 3.3.5 that caused PubSub.get_message() to raise
+    * Fixed a regression in 3.3.5 that caused PubSub.get_message() to raise a socket.timeout exception when passing a timeout value. #1200
 * 3.3.5
-    * Fix an issue where socket.timeout errors could be handled by the wrong
+    * Fix an issue where socket.timeout errors could be handled by the wrong exception handler in Python 2.7.
 * 3.3.4
-    * More specifically identify nonblocking read errors for both SSL and
+    * More specifically identify nonblocking read errors for both SSL and non-SSL connections. 3.3.1, 3.3.2 and 3.3.3 on Python 2.7 could potentially mask a ConnectionError. #1197
 * 3.3.3
-    * The SSL module in Python < 2.7.9 handles non-blocking sockets
+    * The SSL module in Python < 2.7.9 handles non-blocking sockets differently than 2.7.9+. This patch accommodates older versions. #1197
 * 3.3.2
-    * Further fixed a regression introduced in 3.3.0 involving SSL and
+    * Further fixed a regression introduced in 3.3.0 involving SSL and non-blocking sockets. #1197
 * 3.3.1
-    * Fixed a regression introduced in 3.3.0 involving SSL and non-blocking
+    * Fixed a regression introduced in 3.3.0 involving SSL and non-blocking sockets. #1197
 * 3.3.0
-    * Resolve a race condition with the PubSubWorkerThread. #1150
-    * Cleanup socket read error messages. Thanks Vic Yu. #1159
-    * Cleanup the Connection's selector correctly. Thanks Bruce Merry. #1153
-    * Added a Monitor object to make working with MONITOR output easy.
+    * Resolve a race condition with the PubSubWorkerThread. #1150
+    * Cleanup socket read error messages. Thanks Vic Yu. #1159
+    * Cleanup the Connection's selector correctly. Thanks Bruce Merry. #1153
+    * Added a Monitor object to make working with MONITOR output easy. Thanks Roey Prat #1033
-    * Internal cleanup: Removed the legacy Token class which was necessary
+    * Internal cleanup: Removed the legacy Token class which was necessary with older versions of Python that are no longer supported. #1066
-    * Response callbacks are now case insensitive. This allows users that
+    * Response callbacks are now case insensitive. This allows users that call Redis.execute_command() directly to pass lower-case command names and still get reasonable responses. #1168
-    * Added support for hiredis-py 1.0.0 encoding error support. This should
+    * Added support for hiredis-py 1.0.0 encoding error support. This should make the PythonParser and the HiredisParser behave identically when encountering encoding errors. Thanks Brian Candler. #1161/#1162
-    * All authentication errors now properly raise AuthenticationError.
+    * All authentication errors now properly raise AuthenticationError. AuthenticationError is now a subclass of ConnectionError, which will cause the connection to be disconnected and cleaned up appropriately. #923
-    * Add READONLY and READWRITE commands. Thanks @theodesp. #1114
-    * Remove selectors in favor of nonblocking sockets. Selectors had
+    * Add READONLY and READWRITE commands. Thanks @theodesp. #1114
+    * Remove selectors in favor of nonblocking sockets. Selectors had issues in some environments including eventlet and gevent. This should resolve those issues with no other side effects.
-    * Fixed an issue with XCLAIM and previously claimed but not removed
+    * Fixed an issue with XCLAIM and previously claimed but not removed messages. Thanks @thomdask. #1192/#1191
-    * Allow for single connection client instances. These instances
+    * Allow for single connection client instances. These instances are not thread safe but offer other benefits including a subtle performance increase.
-    * Added extensive health checks that keep the connections lively.
+    * Added extensive health checks that keep the connections lively. Passing the "health_check_interval=N" option to the Redis client class or to a ConnectionPool ensures that a round trip PING/PONG is successful before any command if the underlying connection has been idle for more than N seconds. ConnectionErrors and TimeoutErrors are automatically retried once for health checks.
-    * Changed the PubSubWorkerThread to use a threading.Event object rather
+    * Changed the PubSubWorkerThread to use a threading.Event object rather than a boolean to control the thread's life cycle. Thanks Timothy Rule. #1194/#1195.
-    * Fixed a bug in Pipeline error handling that would incorrectly retry
+    * Fixed a bug in Pipeline error handling that would incorrectly retry ConnectionErrors.
 * 3.2.1
-    * Fix SentinelConnectionPool to work in multiprocess/forked environments.
+    * Fix SentinelConnectionPool to work in multiprocess/forked environments.
 * 3.2.0
-    * Added support for `select.poll` to test whether data can be read
+    * Added support for `select.poll` to test whether data can be read on a socket. This should allow for significantly more connections to be used with pubsub. Fixes #486/#1115
-    * Attempt to guarantee that the ConnectionPool hands out healthy
+    * Attempt to guarantee that the ConnectionPool hands out healthy connections. Healthy connections are those that have an established socket connection to the Redis server, are ready to accept a command and have no data available to read. Fixes #1127/#886
-    * Use the socket.IPPROTO_TCP constant instead of socket.SOL_TCP.
+    * Use the socket.IPPROTO_TCP constant instead of socket.SOL_TCP. IPPROTO_TCP is available on more interpreters (Jython for instance). Thanks @Junnplus. #1130
-    * Fixed a regression introduced in 3.0 that mishandles exceptions not
+    * Fixed a regression introduced in 3.0 that mishandles exceptions not derived from the base Exception class, KeyboardInterrupt and gevent.timeout notably. Thanks Christian Fersch. #1128/#1129
-    * Significant improvements to handing connections with forked processes.
+    * Significant improvements to handling connections with forked processes. Parent and child processes no longer trample on each others' connections. Thanks to Jay Rolette for the patch and highlighting this issue. #504/#732/#784/#863
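The `health_check_interval` option described above can be set on either the client or a pool; a minimal sketch, assuming a local server:

```python
import redis

# A PING/PONG round trip runs before a command whenever the connection
# has been idle for more than 30 seconds.
r = redis.Redis(host="localhost", port=6379, health_check_interval=30)

# The same option works when configured on a shared pool.
pool = redis.ConnectionPool(host="localhost", port=6379, health_check_interval=30)
r2 = redis.Redis(connection_pool=pool)
```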
-    * PythonParser no longer closes the associated connection's socket. The
+    * PythonParser no longer closes the associated connection's socket. The connection itself will close the socket. #1108/#1085
 * 3.1.0
-    * Connection URLs must have one of the following schemes:
+    * Connection URLs must have one of the following schemes: redis://, rediss://, unix://. Thanks @jdupl123. #961/#969
-    * Fixed an issue with retry_on_timeout logic that caused some TimeoutErrors
+    * Fixed an issue with retry_on_timeout logic that caused some TimeoutErrors to be retried. Thanks Aaron Yang. #1022/#1023
-    * Added support for SNI for SSL. Thanks @oridistor and Roey Prat. #1087
-    * Fixed ConnectionPool repr for pools with no connections. Thanks
+    * Added support for SNI for SSL. Thanks @oridistor and Roey Prat. #1087
+    * Fixed ConnectionPool repr for pools with no connections. Thanks Cody Scott. #1043/#995
-    * Fixed GEOHASH to return a None value when specifying a place that
+    * Fixed GEOHASH to return a None value when specifying a place that doesn't exist on the server. Thanks @guybe7. #1126
-    * Fixed XREADGROUP to return an empty dictionary for messages that
+    * Fixed XREADGROUP to return an empty dictionary for messages that have been deleted but still exist in the unacknowledged queue. Thanks @xeizmendi. #1116
-    * Added an owned method to Lock objects. owned returns a boolean
+    * Added an owned method to Lock objects. owned returns a boolean indicating whether the current lock instance still owns the lock. Thanks Dave Johansen. #1112
-    * Allow lock.acquire() to accept an optional token argument. If
+    * Allow lock.acquire() to accept an optional token argument. If provided, the token argument is used as the unique value used to claim the lock. Thanks Dave Johansen. #1112
-    * Added a reacquire method to Lock objects. reacquire attempts to renew
+    * Added a reacquire method to Lock objects. reacquire attempts to renew the lock such that the timeout is extended to the same value that the lock was initially acquired with. Thanks Ihor Kalnytskyi. #1014
-    * Stream names found within XREAD and XREADGROUP responses now properly
+    * Stream names found within XREAD and XREADGROUP responses now properly respect the decode_responses flag.
-    * XPENDING_RANGE now requires the user the specify the min, max and
+    * XPENDING_RANGE now requires the user to specify the min, max and count arguments. Newer versions of Redis prevent count from being infinite so it's left to the user to specify these values explicitly.
-    * ZADD now returns None when xx=True and incr=True and an element
+    * ZADD now returns None when xx=True and incr=True and an element is specified that doesn't exist in the sorted set. This matches what the server returns in this case. #1084
-    * Added client_kill_filter that accepts various filters to identify
+    * Added client_kill_filter that accepts various filters to identify and kill clients. Thanks Theofanis Despoudis. #1098
-    * Fixed a race condition that occurred when unsubscribing and
+    * Fixed a race condition that occurred when unsubscribing and resubscribing to the same channel or pattern in rapid succession. Thanks Marcin Raczyński. #764
-    * Added a LockNotOwnedError that is raised when trying to extend or
+    * Added a LockNotOwnedError that is raised when trying to extend or release a lock that is no longer owned. This is a subclass of LockError so previous code should continue to work as expected. Thanks Joshua Harlow. #1095
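The 3.1.0 Lock additions above (`token`, `owned`, `reacquire`) in one illustrative snippet; the lock name and token are made up:

```python
import redis

r = redis.Redis(host="localhost", port=6379)
lock = r.lock("report-job", timeout=30)

# token supplies the unique value that claims the lock.
if lock.acquire(token="worker-42"):
    print(lock.owned())  # True while this instance still holds the lock
    lock.reacquire()     # reset the TTL back to the original 30 seconds
    lock.release()
```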
-    * Fixed a bug in GEORADIUS that forced decoding of places without
+    * Fixed a bug in GEORADIUS that forced decoding of places without respecting the decode_responses option. Thanks Bo Bayles. #1082
 * 3.0.1
-    * Fixed regression with UnixDomainSocketConnection caused by 3.0.0.
+    * Fixed regression with UnixDomainSocketConnection caused by 3.0.0. Thanks Jyrki Muukkonen
-    * Fixed an issue with the new asynchronous flag on flushdb and flushall.
+    * Fixed an issue with the new asynchronous flag on flushdb and flushall. Thanks rogeryen
-    * Updated Lock.locked() method to indicate whether *any* process has
+    * Updated Lock.locked() method to indicate whether *any* process has acquired the lock, not just the current one. This is in line with the behavior of threading.Lock. Thanks Alan Justino da Silva
 * 3.0.0
 BACKWARDS INCOMPATIBLE CHANGES
-    * When using a Lock as a context manager and the lock fails to be acquired
+    * When using a Lock as a context manager and the lock fails to be acquired, a LockError is now raised. This prevents the code block inside the context manager from being executed if the lock could not be acquired.
-    * Renamed LuaLock to Lock.
-    * Removed the pipeline based Lock implementation in favor of the LuaLock
+    * Renamed LuaLock to Lock.
+    * Removed the pipeline based Lock implementation in favor of the LuaLock implementation.
-    * Only bytes, strings and numbers (ints, longs and floats) are acceptable
+    * Only bytes, strings and numbers (ints, longs and floats) are acceptable for keys and values. Previously redis-py attempted to cast other types to str() and store the result. This caused much confusion and frustration when passing boolean values (cast to 'True' and 'False') or None values (cast to 'None'). It is now the user's responsibility to cast all key names and values to bytes, strings or numbers before passing the value to redis-py.
-    * The StrictRedis class has been renamed to Redis. StrictRedis will
+    * The StrictRedis class has been renamed to Redis. StrictRedis will continue to exist as an alias of Redis for the foreseeable future.
-    * The legacy Redis client class has been removed. It caused much confusion
+    * The legacy Redis client class has been removed. It caused much confusion to users.
-    * ZINCRBY arguments 'value' and 'amount' have swapped order to match the
+    * ZINCRBY arguments 'value' and 'amount' have swapped order to match the Redis server. The new argument order is: keyname, amount, value.
-    * MGET no longer raises an error if zero keys are passed in. Instead an
+    * MGET no longer raises an error if zero keys are passed in. Instead an empty list is returned.
-    * MSET and MSETNX now require all keys/values to be specified in a single
+    * MSET and MSETNX now require all keys/values to be specified in a single dictionary argument named mapping. This was changed to allow for future options to these commands.
-    * ZADD now requires all element names/scores be specified in a single
+    * ZADD now requires all element names/scores be specified in a single dictionary argument named mapping. This was required to allow the NX, XX, CH and INCR options to be specified.
-    * ssl_cert_reqs now has a default value of 'required' by default. This
+    * ssl_cert_reqs now defaults to 'required'. This should make connecting to a remote Redis server over SSL more secure. Thanks u2mejc
-    * Removed support for EOL Python 2.6 and 3.3.
+    * Removed support for EOL Python 2.6 and 3.3. Thanks jdufresne
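The 3.0 `mapping`-style signatures above look like this in practice; an illustrative sketch, assuming a local server:

```python
import redis

r = redis.Redis(host="localhost", port=6379)

# MSET takes a single mapping argument...
r.mset({"page:1": "home", "page:2": "about"})

# ...and so does ZADD, which frees up keyword slots for nx/xx/ch/incr.
r.zadd("leaderboard", {"alice": 120, "bob": 95}, nx=True)

# With xx=True and incr=True, a missing member yields None.
print(r.zadd("leaderboard", {"carol": 5}, xx=True, incr=True))  # None
```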
 OTHER CHANGES
-    * Added missing DECRBY command. Thanks derek-dchu
-    * CLUSTER INFO and CLUSTER NODES responses are now properly decoded to
+    * Added missing DECRBY command. Thanks derek-dchu
+    * CLUSTER INFO and CLUSTER NODES responses are now properly decoded to strings.
-    * Added a 'locked()' method to Lock objects. This method returns True
+    * Added a 'locked()' method to Lock objects. This method returns True if the lock has been acquired and owned by the current process, otherwise False.
-    * EXISTS now supports multiple keys. It's return value is now the number
+    * EXISTS now supports multiple keys. Its return value is now the number of keys in the list that exist.
-    * Ensure all commands can accept key names as bytes. This fixes issues
+    * Ensure all commands can accept key names as bytes. This fixes issues with BLPOP, BRPOP and SORT.
-    * All errors resulting from bad user input are raised as DataError
+    * All errors resulting from bad user input are raised as DataError exceptions. DataError is a subclass of RedisError so this should be transparent to anyone previously catching these.
-    * Added support for NX, XX, CH and INCR options to ZADD
-    * Added support for the MIGRATE command
-    * Added support for the MEMORY USAGE and MEMORY PURGE commands. Thanks
+    * Added support for NX, XX, CH and INCR options to ZADD
+    * Added support for the MIGRATE command
+    * Added support for the MEMORY USAGE and MEMORY PURGE commands. Thanks Itamar Haber
-    * Added support for the 'asynchronous' argument to FLUSHDB and FLUSHALL
+    * Added support for the 'asynchronous' argument to FLUSHDB and FLUSHALL commands. Thanks Itamar Haber
-    * Added support for the BITFIELD command. Thanks Charles Leifer and
+    * Added support for the BITFIELD command. Thanks Charles Leifer and Itamar Haber
-    * Improved performance on pipeline requests with large chunks of data.
+    * Improved performance on pipeline requests with large chunks of data. Thanks tzickel
-    * Fixed test suite to not fail if another client is connected to the
+    * Fixed test suite to not fail if another client is connected to the server the tests are running against.
-    * Added support for SWAPDB. Thanks Itamar Haber
-    * Added support for all STREAM commands. Thanks Roey Prat and Itamar Haber
-    * SHUTDOWN now accepts the 'save' and 'nosave' arguments. Thanks
+    * Added support for SWAPDB. Thanks Itamar Haber
+    * Added support for all STREAM commands. Thanks Roey Prat and Itamar Haber
+    * SHUTDOWN now accepts the 'save' and 'nosave' arguments. Thanks dwilliams-kenzan
-    * Added support for ZPOPMAX, ZPOPMIN, BZPOPMAX, BZPOPMIN. Thanks
+    * Added support for ZPOPMAX, ZPOPMIN, BZPOPMAX, BZPOPMIN. Thanks Itamar Haber
-    * Added support for the 'type' argument in CLIENT LIST. Thanks Roey Prat
-    * Added support for CLIENT PAUSE. Thanks Roey Prat
-    * Added support for CLIENT ID and CLIENT UNBLOCK. Thanks Itamar Haber
-    * GEODIST now returns a None value when referencing a place that does
+    * Added support for the 'type' argument in CLIENT LIST. Thanks Roey Prat
+    * Added support for CLIENT PAUSE. Thanks Roey Prat
+    * Added support for CLIENT ID and CLIENT UNBLOCK. Thanks Itamar Haber
+    * GEODIST now returns a None value when referencing a place that does not exist. Thanks qingping209
-    * Added a ping() method to pubsub objects. Thanks krishan-carbon
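And the multi-key EXISTS and UNLINK behavior from this section, sketched briefly with invented keys:

```python
import redis

r = redis.Redis(host="localhost", port=6379)
r.mset({"k1": "a", "k2": "b"})

# EXISTS takes several keys and returns how many of them exist.
print(r.exists("k1", "k2", "missing"))  # 2

# UNLINK deletes keys, reclaiming memory asynchronously on the server.
r.unlink("k1", "k2")
```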
+    * Fixed a bug with keys in the INFO dict that contained ':' symbols. Thanks mzalimeni
-    * Fixed the select system call retry compatibility with Python 2.x.
+    * Fixed the select system call retry compatibility with Python 2.x. Thanks lddubeau
-    * max_connections is now a valid querystring argument for creating
+    * max_connections is now a valid querystring argument for creating connection pools from URLs. Thanks mmaslowskicc
-    * Added the UNLINK command. Thanks yozel
-    * Added socket_type option to Connection for configurability.
+    * Added the UNLINK command. Thanks yozel
+    * Added socket_type option to Connection for configurability. Thanks garlicnation
-    * Lock.do_acquire now atomically sets acquires the lock and sets the
+    * Lock.do_acquire now atomically acquires the lock and sets the expire value via set(nx=True, px=timeout). Thanks 23doors
-    * Added 'count' argument to SPOP. Thanks AlirezaSadeghi
-    * Fixed an issue parsing client_list responses that contained an '='.
+    * Added 'count' argument to SPOP. Thanks AlirezaSadeghi
+    * Fixed an issue parsing client_list responses that contained an '='. Thanks swilly22
 * 2.10.6
-    * Various performance improvements. Thanks cjsimpson
-    * Fixed a bug with SRANDMEMBER where the behavior for `number=0` did
+    * Various performance improvements. Thanks cjsimpson
+    * Fixed a bug with SRANDMEMBER where the behavior for `number=0` did not match the spec. Thanks Alex Wang
-    * Added HSTRLEN command. Thanks Alexander Putilin
-    * Added the TOUCH command. Thanks Anis Jonischkeit
-    * Remove unnecessary calls to the server when registering Lua scripts.
+    * Added HSTRLEN command. Thanks Alexander Putilin
+    * Added the TOUCH command. Thanks Anis Jonischkeit
+    * Remove unnecessary calls to the server when registering Lua scripts. Thanks Ben Greenberg
-    * SET's EX and PX arguments now allow values of zero. Thanks huangqiyin
-    * Added PUBSUB {CHANNELS, NUMPAT, NUMSUB} commands. Thanks Angus Pearson
-    * PubSub connections that encounter `InterruptedError`s now
+    * SET's EX and PX arguments now allow values of zero. Thanks huangqiyin
+    * Added PUBSUB {CHANNELS, NUMPAT, NUMSUB} commands. Thanks Angus Pearson
+    * PubSub connections that encounter `InterruptedError`s now retry automatically. Thanks Carlton Gibson and Seth M. Larson
-    * LPUSH and RPUSH commands run on PyPy now correctly returns the number
+    * LPUSH and RPUSH commands run on PyPy now correctly return the number of items in the list. Thanks Jeong YunWon
-    * Added support to automatically retry socket EINTR errors. Thanks
+    * Added support to automatically retry socket EINTR errors. Thanks Thomas Steinacher
-    * PubSubWorker threads started with `run_in_thread` are now daemonized
+    * PubSubWorker threads started with `run_in_thread` are now daemonized so the thread shuts down when the running process goes away. Thanks Keith Ainsworth
-    * Added support for GEO commands. Thanks Pau Freixes, Alex DeBrie and
+    * Added support for GEO commands. Thanks Pau Freixes, Alex DeBrie and Abraham Toriz
-    * Made client construction from URLs smarter. Thanks Tim Savage
-    * Added support for CLUSTER * commands. Thanks Andy Huang
-    * The RESTORE command now accepts an optional `replace` boolean.
+    * Made client construction from URLs smarter. Thanks Tim Savage
+    * Added support for CLUSTER * commands. Thanks Andy Huang
+    * The RESTORE command now accepts an optional `replace` boolean. Thanks Yoshinari Takaoka
-    * Attempt to connect to a new Sentinel if a TimeoutError occurs. Thanks
+    * Attempt to connect to a new Sentinel if a TimeoutError occurs. Thanks Bo Lopker
-    * Fixed a bug in the client's `__getitem__` where a KeyError would be
+    * Fixed a bug in the client's `__getitem__` where a KeyError would be raised if the value returned by the server is an empty string. Thanks Javier Candeira.
-    * Socket timeouts when connecting to a server are now properly raised
+    * Socket timeouts when connecting to a server are now properly raised as TimeoutErrors.
 * 2.10.5
-    * Allow URL encoded parameters in Redis URLs. Characters like a "/" can
+    * Allow URL encoded parameters in Redis URLs. Characters like a "/" can now be URL encoded and redis-py will correctly decode them. Thanks Paul Keene.
-    * Added support for the WAIT command. Thanks https://github.com/eshizhan
-    * Better shutdown support for the PubSub Worker Thread. It now properly
+    * Added support for the WAIT command. Thanks @eshizhan
+    * Better shutdown support for the PubSub Worker Thread. It now properly cleans up the connection, unsubscribes from any channels and patterns previously subscribed to and consumes any waiting messages on the socket.
-    * Added the ability to sleep for a brief period in the event of a
+    * Added the ability to sleep for a brief period in the event of a WatchError occurring. Thanks Joshua Harlow.
-    * Fixed a bug with pipeline error reporting when dealing with characters
+    * Fixed a bug with pipeline error reporting when dealing with characters in error messages that could not be encoded to the connection's character set. Thanks Hendrik Muhs.
-    * Fixed a bug in Sentinel connections that would inadvertently connect
+    * Fixed a bug in Sentinel connections that would inadvertently connect to the master when the connection pool resets. Thanks
-      https://github.com/df3n5
-    * Better timeout support in Pubsub get_message. Thanks Andy Isaacson.
-    * Fixed a bug with the HiredisParser that would cause the parser to
+      @df3n5
+    * Better timeout support in Pubsub get_message. Thanks Andy Isaacson.
+    * Fixed a bug with the HiredisParser that would cause the parser to get stuck in an endless loop if a specific number of bytes were delivered from the socket. This fix also increases performance of parsing large responses from the Redis server.
-    * Added support for ZREVRANGEBYLEX.
-    * ConnectionErrors are now raised if Redis refuses a connection due to
+    * Added support for ZREVRANGEBYLEX.
+    * ConnectionErrors are now raised if Redis refuses a connection due to the maxclients limit being exceeded. Thanks Roman Karpovich.
-    * max_connections can now be set when instantiating client instances.
+    * max_connections can now be set when instantiating client instances. Thanks Ohad Perry.
 * 2.10.4 (skipped due to a PyPI snafu)
 * 2.10.3
-    * Fixed a bug with the bytearray support introduced in 2.10.2. Thanks
+    * Fixed a bug with the bytearray support introduced in 2.10.2. Thanks Josh Owen.
 * 2.10.2
-    * Added support for Hiredis's new bytearray support. Thanks
-      https://github.com/tzickel
+    * Added support for Hiredis's new bytearray support. Thanks
+      @tzickel
-    * POSSIBLE BACKWARDS INCOMPATIBLE CHANGE: Fixed a possible race condition
+    * POSSIBLE BACKWARDS INCOMPATIBLE CHANGE: Fixed a possible race condition when multiple threads share the same Lock instance with a timeout. Lock tokens are now stored in thread local storage by default. If you have code that acquires a lock in one thread and passes that lock instance to another thread to release it, you need to disable thread local storage.
 Refer to the docstrings on the Lock class for more information about the thread_local argument.
-    * Fixed a regression in from_url where "charset" and "errors" weren't
+    * Fixed a regression in from_url where "charset" and "errors" weren't valid options. "encoding" and "encoding_errors" are still accepted and preferred.
-    * The "charset" and "errors" options have been deprecated. Passing
+    * The "charset" and "errors" options have been deprecated. Passing either to StrictRedis.__init__ or from_url will still work but will also emit a DeprecationWarning. Instead use the "encoding" and "encoding_errors" options.
-    * Fixed a compatibility bug with Python 3 when the server closes a
+    * Fixed a compatibility bug with Python 3 when the server closes a connection.
-    * Added BITPOS command. Thanks https://github.com/jettify.
-    * Fixed a bug when attempting to send large values to Redis in a Pipeline.
+    * Added BITPOS command. Thanks @jettify.
+    * Fixed a bug when attempting to send large values to Redis in a Pipeline.
 * 2.10.1
-    * Fixed a bug where Sentinel connections to a server that's no longer a
+    * Fixed a bug where Sentinel connections to a server that's no longer a master and receives a READONLY error will disconnect and reconnect to the master.
 * 2.10.0
-    * Discontinued support for Python 2.5. Upgrade. You'll be happier.
-    * The HiRedis parser will now properly raise ConnectionErrors.
-    * Completely refactored PubSub support. Fixes all known PubSub bugs and
+    * Discontinued support for Python 2.5. Upgrade. You'll be happier.
+    * The HiRedis parser will now properly raise ConnectionErrors.
+    * Completely refactored PubSub support. Fixes all known PubSub bugs and adds a bunch of new features. Docs can be found in the README under the new "Publish / Subscribe" section.
-    * Added the new HyperLogLog commands (PFADD, PFCOUNT, PFMERGE). Thanks
+    * Added the new HyperLogLog commands (PFADD, PFCOUNT, PFMERGE). Thanks Pepijn de Vos and Vincent Ohprecio.
-    * Updated TTL and PTTL commands with Redis 2.8+ semantics. Thanks Markus
+    * Updated TTL and PTTL commands with Redis 2.8+ semantics. Thanks Markus Kaiserswerth.
-    * *SCAN commands now return a long (int on Python3) cursor value rather
+    * *SCAN commands now return a long (int on Python3) cursor value rather than the string representation. This might be slightly backwards
-      incompatible in code using *SCAN commands loops such as
+      incompatible in code using *SCAN command loops such as "while cursor != '0':".
-    * Added extra *SCAN commands that return iterators instead of the normal
+    * Added extra *SCAN commands that return iterators instead of the normal [cursor, data] type. Use scan_iter, hscan_iter, sscan_iter, and zscan_iter for iterators. Thanks Mathieu Longtin.
-    * Added support for SLOWLOG commands. Thanks Rick van Hattem.
-    * Added lexicographical commands ZRANGEBYLEX, ZREMRANGEBYLEX, and ZLEXCOUNT
+    * Added support for SLOWLOG commands. Thanks Rick van Hattem.
+    * Added lexicographical commands ZRANGEBYLEX, ZREMRANGEBYLEX, and ZLEXCOUNT for sorted sets.
-    * Connection objects now support an optional argument, socket_read_size,
+    * Connection objects now support an optional argument, socket_read_size, indicating how much data to read during each socket.recv() call. After benchmarking, increased the default size to 64k, which dramatically improves performance when fetching large values, such as many results in a pipeline or a large (>1MB) string value.
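The iterator variants mentioned above hide the cursor entirely; a brief sketch (the match pattern is invented):

```python
import redis

r = redis.Redis(host="localhost", port=6379)

# The *_iter helpers handle cursor bookkeeping internally.
for key in r.scan_iter(match="user:*", count=100):
    print(key)

# The raw SCAN call returns (cursor, keys), and the cursor is an int.
cursor, keys = r.scan(cursor=0, match="user:*")
```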
-    * Improved the pack_command and send_packed_command functions to increase
+    * Improved the pack_command and send_packed_command functions to increase performance when sending large (>1MB) values.
-    * Sentinel Connections to master servers now detect when a READONLY error
+    * Sentinel Connections to master servers now detect when a READONLY error is encountered and disconnect themselves and all other active connections to the same master so that the new master can be discovered.
-    * Fixed Sentinel state parsing on Python 3.
-    * Added support for SENTINEL MONITOR, SENTINEL REMOVE, and SENTINEL SET
+    * Fixed Sentinel state parsing on Python 3.
+    * Added support for SENTINEL MONITOR, SENTINEL REMOVE, and SENTINEL SET commands. Thanks Greg Murphy.
-    * INFO output that doesn't follow the "key:value" format will now be
+    * INFO output that doesn't follow the "key:value" format will now be appended to a key named "__raw__" in the INFO dictionary. Thanks Pedro Larroy.
-    * The "vagrant" directory contains a complete vagrant environment for
+    * The "vagrant" directory contains a complete vagrant environment for redis-py developers. The environment runs a Redis master, a Redis slave, and 3 Sentinels. Future iterations of the test suite will incorporate more integration style tests, ensuring things like failover happen correctly.
-    * It's now possible to create connection pool instances from a URL.
+    * It's now possible to create connection pool instances from a URL. StrictRedis.from_url() now uses this feature to create a connection pool instance and use that when creating a new client instance. Thanks
-      https://github.com/chillipino
+      @chillipino
-    * When creating client instances or connection pool instances from an URL,
+    * When creating client instances or connection pool instances from a URL, it's now possible to pass additional options to the connection pool with querystring arguments.
-    * Fixed a bug where some encodings (like utf-16) were unusable on Python 3
+    * Fixed a bug where some encodings (like utf-16) were unusable on Python 3 as command names and literals would get encoded.
-    * Added an SSLConnection class that allows for secure connections through
+    * Added an SSLConnection class that allows for secure connections through stunnel or other means. Construct an SSL connection with the ssl=True option on client classes, using the rediss:// scheme from a URL, or by passing the SSLConnection class to a connection pool's
-      connection_class argument. Thanks https://github.com/oranagra.
+      connection_class argument. Thanks @oranagra.
-    * Added a socket_connect_timeout option to control how long to wait while
+    * Added a socket_connect_timeout option to control how long to wait while establishing a TCP connection before timing out. This lets the client fail fast when attempting to connect to a downed server while keeping a more lenient timeout for all other socket operations.
-    * Added TCP Keep-alive support by passing use the socket_keepalive=True
+    * Added TCP Keep-alive support via the socket_keepalive=True option. Finer grain control can be achieved using the socket_keepalive_options option which expects a dictionary with any of the keys (socket.TCP_KEEPIDLE, socket.TCP_KEEPCNT, socket.TCP_KEEPINTVL) and integers for values. Thanks Yossi Gottlieb.
-    * Added a `retry_on_timeout` option that controls how socket.timeout errors
+    * Added a `retry_on_timeout` option that controls how socket.timeout errors are handled. By default it is set to False and will cause the client to raise a TimeoutError anytime a socket.timeout is encountered. If `retry_on_timeout` is set to True, the client will retry a command that timed out once, like other `socket.error`s.
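The connection-tuning options above combine naturally on one client; an illustrative sketch with placeholder values:

```python
import redis

# Fail fast on connect, keep idle connections alive, and retry once
# on a read timeout.
r = redis.Redis(
    host="localhost",
    port=6379,
    socket_connect_timeout=2,  # seconds to wait for the TCP handshake
    socket_timeout=5,          # seconds for all other socket operations
    socket_keepalive=True,
    retry_on_timeout=True,
)
```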
-    * Completely refactored the Lock system. There is now a LuaLock class
+    * Completely refactored the Lock system. There is now a LuaLock class that's used when the Redis server is capable of running Lua scripts along with a fallback class for Redis servers < 2.6. The new locks fix several subtle race conditions that the old lock could face. In addition, a new method, "extend", is available on lock instances that allows a lock owner to extend the amount of time they have the lock for. Thanks to
-      Eli Finkelshteyn and https://github.com/chillipino for contributions.
+      Eli Finkelshteyn and @chillipino for contributions.
 * 2.9.1
-    * IPv6 support. Thanks https://github.com/amashinchi
+    * IPv6 support. Thanks @amashinchi
 * 2.9.0
-    * Performance improvement for packing commands when using the PythonParser.
+    * Performance improvement for packing commands when using the PythonParser. Thanks Guillaume Viot.
-    * Executing an empty pipeline transaction no longer sends MULTI/EXEC to
+    * Executing an empty pipeline transaction no longer sends MULTI/EXEC to the server. Thanks EliFinkelshteyn.
-    * Errors when authenticating (incorrect password) and selecting a database
+    * Errors when authenticating (incorrect password) and selecting a database now close the socket.
-    * Full Sentinel support thanks to Vitja Makarov. Thanks!
-    * Better repr support for client and connection pool instances. Thanks
+    * Full Sentinel support thanks to Vitja Makarov. Thanks!
+    * Better repr support for client and connection pool instances. Thanks Mark Roberts.
-    * Error messages that the server sends to the client are now included
+    * Error messages that the server sends to the client are now included in the client error message. Thanks Sangjin Lim.
-    * Added the SCAN, SSCAN, HSCAN, and ZSCAN commands. Thanks Jingchao Hu.
-    * ResponseErrors generated by pipeline execution provide addition context
+    * Added the SCAN, SSCAN, HSCAN, and ZSCAN commands. Thanks Jingchao Hu.
+    * ResponseErrors generated by pipeline execution provide additional context including the position of the command in the pipeline and the actual command text that generated the error.
-    * ConnectionPools now play nicer in threaded environments that fork. Thanks
+    * ConnectionPools now play nicer in threaded environments that fork. Thanks Christian Joergensen.
 * 2.8.0
-    * redis-py should play better with gevent when a gevent Timeout is raised.
+    * redis-py should play better with gevent when a gevent Timeout is raised. Thanks leifkb.
-    * Added SENTINEL command. Thanks Anna Janackova.
-    * Fixed a bug where pipelines could potentially corrupt a connection
+    * Added SENTINEL command. Thanks Anna Janackova.
+    * Fixed a bug where pipelines could potentially corrupt a connection if the MULTI command generated a ResponseError. Thanks EliFinkelshteyn for the report.
-    * Connections now call socket.shutdown() prior to socket.close() to
+    * Connections now call socket.shutdown() prior to socket.close() to ensure communication ends immediately per the note at
-      https://docs.python.org/2/library/socket.html#socket.socket.close
+      https://docs.python.org/2/library/socket.html#socket.socket.close. Thanks to David Martin for pointing this out.
-    * Lock checks are now based on floats rather than ints. Thanks
+    * Lock checks are now based on floats rather than ints. Thanks Vitja Makarov.
 * 2.7.6
-    * Added CONFIG RESETSTAT command. Thanks Yossi Gottlieb.
-    * Fixed a bug introduced in 2.7.3 that caused issues with script objects
+    * Added CONFIG RESETSTAT command. Thanks Yossi Gottlieb.
+    * Fixed a bug introduced in 2.7.3 that caused issues with script objects and pipelines. Thanks Carpentier Pierre-Francois.
-    * Converted redis-py's test suite to use the awesome py.test library.
-    * Fixed a bug introduced in 2.7.5 that prevented a ConnectionError from
+    * Converted redis-py's test suite to use the awesome py.test library.
+    * Fixed a bug introduced in 2.7.5 that prevented a ConnectionError from being raised when the Redis server is LOADING data.
-    * Added a BusyLoadingError exception that's raised when the Redis server
+    * Added a BusyLoadingError exception that's raised when the Redis server is starting up and not accepting commands yet. BusyLoadingError subclasses ConnectionError, which this state previously returned. Thanks Yossi Gottlieb.
 * 2.7.5
-    * DEL, HDEL and ZREM commands now return the numbers of keys deleted
+    * DEL, HDEL and ZREM commands now return the number of keys deleted instead of just True/False.
-    * from_url now supports URIs with a port number. Thanks Aaron Westendorf.
+    * from_url now supports URIs with a port number. Thanks Aaron Westendorf.
 * 2.7.4
-    * Added missing INCRBY method. Thanks Krzysztof Dorosz.
-    * SET now accepts the EX, PX, NX and XX options from Redis 2.6.12. These
+    * Added missing INCRBY method. Thanks Krzysztof Dorosz.
+    * SET now accepts the EX, PX, NX and XX options from Redis 2.6.12. These options will generate errors if these options are used when connected to a Redis server < 2.6.12. Thanks George Yoshida.
 * 2.7.3
-    * Fixed a bug with BRPOPLPUSH and lists with empty strings.
-    * All empty except: clauses have been replaced to only catch Exception
+    * Fixed a bug with BRPOPLPUSH and lists with empty strings.
+    * All empty except: clauses have been replaced to only catch Exception subclasses. This prevents a KeyboardInterrupt from triggering exception handlers. Thanks Lucian Branescu Mihaila.
-    * All exceptions that are the result of redis server errors now share a
+    * All exceptions that are the result of redis server errors now share a common Exception subclass, ServerError. Thanks Matt Robenolt.
-    * Prevent DISCARD from being called if MULTI wasn't also called. Thanks
+    * Prevent DISCARD from being called if MULTI wasn't also called. Thanks Pete Aykroyd.
-    * SREM now returns an integer indicating the number of items removed from
-      the set. Thanks https://github.com/ronniekk.
+    * SREM now returns an integer indicating the number of items removed from
+      the set. Thanks @ronniekk.
-    * Fixed a bug with BGSAVE and BGREWRITEAOF response callbacks with Python3.
+    * Fixed a bug with BGSAVE and BGREWRITEAOF response callbacks with Python3. Thanks Nathan Wan.
-    * Added CLIENT GETNAME and CLIENT SETNAME commands.
-      Thanks https://github.com/bitterb.
+    * Added CLIENT GETNAME and CLIENT SETNAME commands.
+      Thanks @bitterb.
-    * It's now possible to use len() on a pipeline instance to determine the
+    * It's now possible to use len() on a pipeline instance to determine the number of commands that will be executed. Thanks Jon Parise.
-    * Fixed a bug in INFO's parse routine with floating point numbers. Thanks
+    * Fixed a bug in INFO's parse routine with floating point numbers. Thanks Ali Onur Uyar.
-    * Fixed a bug with BITCOUNT to allow `start` and `end` to both be zero.
+    * Fixed a bug with BITCOUNT to allow `start` and `end` to both be zero. Thanks Tim Bart.
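The len() support above can be used to inspect a pipeline before executing it; a quick sketch:

```python
import redis

r = redis.Redis(host="localhost", port=6379)
pipe = r.pipeline()
pipe.set("a", 1)
pipe.incr("a")

# len() reports how many commands are queued on the pipeline.
print(len(pipe))  # 2
results = pipe.execute()
```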
-    * The transaction() method now accepts a boolean keyword argument,
+    * The transaction() method now accepts a boolean keyword argument, value_from_callable. By default, or if False is passed, the transaction() method will return the value of the pipeline's execution. Otherwise, it will return whatever func() returns.
-    * Python3 compatibility fix ensuring we're not already bytes(). Thanks
+    * Python3 compatibility fix ensuring we're not already bytes(). Thanks Salimane Adjao Moustapha.
-    * Added PSETEX. Thanks YAMAMOTO Takashi.
-    * Added a BlockingConnectionPool to limit the number of connections that
+    * Added PSETEX. Thanks YAMAMOTO Takashi.
+    * Added a BlockingConnectionPool to limit the number of connections that can be created. Thanks James Arthur.
-    * SORT now accepts a `groups` option that if specified, will return
+    * SORT now accepts a `groups` option that, if specified, will return tuples of n-length, where n is the number of keys specified in the GET argument. This allows for convenient row-based iteration. Thanks Ionuț Arțăriși.
 * 2.7.2
-    * Parse errors are now *always* raised on multi/exec pipelines, regardless
+    * Parse errors are now *always* raised on multi/exec pipelines, regardless of the `raise_on_error` flag. See
-      https://groups.google.com/forum/?hl=en&fromgroups=#!topic/redis-db/VUiEFT8U8U0
+      https://groups.google.com/forum/?hl=en&fromgroups=#!topic/redis-db/VUiEFT8U8U0 for more info.
 * 2.7.1
-    * Packaged tests with source code
+    * Packaged tests with source code
 * 2.7.0
-    * Added BITOP and BITCOUNT commands. Thanks Mark Tozzi.
-    * Added the TIME command. Thanks Jason Knight.
-    * Added support for LUA scripting. Thanks to Angus Peart, Drew Smathers,
+    * Added BITOP and BITCOUNT commands. Thanks Mark Tozzi.
+    * Added the TIME command. Thanks Jason Knight.
+    * Added support for LUA scripting. Thanks to Angus Peart, Drew Smathers, Issac Kelly, Louis-Philippe Perron, Sean Bleier, Jeffrey Kaditz, and Dvir Volk for various patches and contributions to this feature.
-    * Changed the default error handling in pipelines. By default, the first
+    * Changed the default error handling in pipelines. By default, the first error in a pipeline will now be raised. A new parameter to the pipeline's execute, `raise_on_error`, can be set to False to keep the old behavior of embedding the exception instances in the result.
-    * Fixed a bug with pipelines where parse errors won't corrupt the
+    * Fixed a bug with pipelines so that parse errors won't corrupt the socket.
-    * Added the optional `number` argument to SRANDMEMBER for use with
+    * Added the optional `number` argument to SRANDMEMBER for use with Redis 2.6+ servers.
-    * Added PEXPIRE/PEXPIREAT/PTTL commands. Thanks Luper Rouch.
-    * Added INCRBYFLOAT/HINCRBYFLOAT commands. Thanks Nikita Uvarov.
-    * High precision floating point values won't lose their precision when
+    * Added PEXPIRE/PEXPIREAT/PTTL commands. Thanks Luper Rouch.
+    * Added INCRBYFLOAT/HINCRBYFLOAT commands. Thanks Nikita Uvarov.
+    * High precision floating point values won't lose their precision when being sent to the Redis server. Thanks Jason Oster and Oleg Pudeyev.
-    * Added CLIENT LIST/CLIENT KILL commands
+    * Added CLIENT LIST/CLIENT KILL commands
 * 2.6.2
-    * `from_url` is now available as a classmethod on client classes. Thanks
+    * `from_url` is now available as a classmethod on client classes. Thanks Jon Parise for the patch.
-    * Fixed several encoding errors resulting from the Python 3.x support.
+    * Fixed several encoding errors resulting from the Python 3.x support.
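The `groups` option to SORT described above returns row tuples; a small illustrative example with invented keys:

```python
import redis

r = redis.Redis(host="localhost", port=6379)
r.rpush("ids", 1, 2)
r.mset({"name_1": "ada", "score_1": 10, "name_2": "bob", "score_2": 7})

# groups=True yields one tuple per element, matching the GET patterns,
# which makes row-based iteration convenient.
for name, score in r.sort("ids", get=["name_*", "score_*"], groups=True):
    print(name, score)
```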
 * 2.6.1
-    * Python 3.x support! Big thanks to Alex Grönholm.
-    * Fixed a bug in the PythonParser's read_response that could hide an error
+    * Python 3.x support! Big thanks to Alex Grönholm.
+    * Fixed a bug in the PythonParser's read_response that could hide an error from the client (#251).
 * 2.6.0
-    * Changed (p)subscribe and (p)unsubscribe to no longer return messages
+    * Changed (p)subscribe and (p)unsubscribe to no longer return messages indicating the channel was subscribed/unsubscribed to. These messages are available in the listen() loop instead. This is to prevent the following scenario:
-        * Client A is subscribed to "foo"
-        * Client B publishes message to "foo"
-        * Client A subscribes to channel "bar" at the same time.
+        * Client A is subscribed to "foo"
+        * Client B publishes message to "foo"
+        * Client A subscribes to channel "bar" at the same time.
 Prior to this change, the subscribe() call would return the published messages on "foo" rather than the subscription confirmation to "bar".
-    * Added support for GETRANGE, thanks Jean-Philippe Caruana
-    * A new setting "decode_responses" specifies whether return values from
+    * Added support for GETRANGE, thanks Jean-Philippe Caruana
+    * A new setting "decode_responses" specifies whether return values from Redis commands get decoded automatically using the client's charset value. Thanks to Frankie Dintino for the patch.
 * 2.4.13
-    * redis.from_url() can take an URL representing a Redis connection string
+    * redis.from_url() can take a URL representing a Redis connection string and return a client object. Thanks Kenneth Reitz for the patch.
 * 2.4.12
-    * ConnectionPool is now fork-safe. Thanks Josiah Carson for the patch.
+    * ConnectionPool is now fork-safe. Thanks Josiah Carson for the patch.
 * 2.4.11
-    * AuthenticationError will now be correctly raised if an invalid password
+    * AuthenticationError will now be correctly raised if an invalid password is supplied.
-    * If Hiredis is unavailable, the HiredisParser will raise a RedisError
+    * If Hiredis is unavailable, the HiredisParser will raise a RedisError if selected manually.
-    * Made the INFO command more tolerant of Redis changes formatting. Fix
+    * Made the INFO command more tolerant of changes in Redis's formatting. Fix for #217.
 * 2.4.10
-    * Buffer reads from socket in the PythonParser. Fix for a Windows-specific
+    * Buffer reads from socket in the PythonParser. Fix for a Windows-specific bug (#205).
-    * Added the OBJECT and DEBUG OBJECT commands.
-    * Added __del__ methods for classes that hold on to resources that need to
+    * Added the OBJECT and DEBUG OBJECT commands.
+    * Added __del__ methods for classes that hold on to resources that need to be cleaned up. This should prevent resource leakage when these objects leave scope due to misuse or unhandled exceptions. Thanks David Wolever for the suggestion.
-    * Added the ECHO command for completeness.
-    * Fixed a bug where attempting to subscribe to a PubSub channel of a Redis
+    * Added the ECHO command for completeness.
+    * Fixed a bug where attempting to subscribe to a PubSub channel of a Redis server that's down would blow out the stack. Fixes #179 and #195. Thanks Ovidiu Predescu for the test case.
-    * StrictRedis's TTL command now returns a -1 when querying a key with no
+    * StrictRedis's TTL command now returns a -1 when querying a key with no expiration. The Redis class continues to return None.
-    * ZADD and SADD now return integer values indicating the number of items
+    * ZADD and SADD now return integer values indicating the number of items added. Thanks Homer Strong.
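A brief sketch of the from_url() pattern referenced in the 2.4.13 entry above, using today's classmethod form and a querystring pool option:

```python
import redis

# from_url builds a client (and its pool) straight from a connection string.
r = redis.Redis.from_url("redis://localhost:6379/0")

# Extra pool options can ride along as querystring arguments.
r2 = redis.Redis.from_url("redis://localhost:6379/0?max_connections=10")
```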
-    * Renamed the base client class to StrictRedis, replacing ZADD and LREM in
+    * Renamed the base client class to StrictRedis, replacing ZADD and LREM in favor of their official argument order. The Redis class is now a subclass of StrictRedis, implementing the legacy redis-py implementations of ZADD and LREM. Docs have been updated to suggest the use of StrictRedis.
-    * SETEX in StrictRedis is now compliant with official Redis SETEX command.
+    * SETEX in StrictRedis is now compliant with the official Redis SETEX command. The name, value, time implementation moved to "Redis" for backwards compatibility.
 * 2.4.9
-    * Removed socket retry logic in Connection. This is the responsibility of
+    * Removed socket retry logic in Connection. It is the responsibility of the caller to determine if the command is safe and can be retried. Thanks David Wolever.
-    * Added some extra guards around various types of exceptions being raised
+    * Added some extra guards around various types of exceptions being raised when sending or parsing data. Thanks David Wolever and Denis Bilenko.
 * 2.4.8
-    * Imported with_statement from __future__ for Python 2.5 compatibility.
+    * Imported with_statement from __future__ for Python 2.5 compatibility.
 * 2.4.7
-    * Fixed a bug where some connections were not getting released back to the
+    * Fixed a bug where some connections were not getting released back to the connection pool after pipeline execution.
-    * Pipelines can now be used as context managers. This is the preferred way
+    * Pipelines can now be used as context managers. This is the preferred way to ensure that connections get cleaned up properly. Thanks David Wolever.
-    * Added a convenience method called transaction() on the base Redis class.
+    * Added a convenience method called transaction() on the base Redis class. This method eliminates much of the boilerplate used when using pipelines to watch Redis keys. See the documentation for details on usage.
 * 2.4.6
-    * Variadic arguments for SADD, SREM, ZREN, HDEL, LPUSH, and RPUSH. Thanks
+    * Variadic arguments for SADD, SREM, ZREM, HDEL, LPUSH, and RPUSH. Thanks Raphaël Vinot.
-    * (CRITICAL) Fixed an error in the Hiredis parser that occasionally caused
+    * (CRITICAL) Fixed an error in the Hiredis parser that occasionally caused the socket connection to become corrupted and unusable. This became noticeable once connection pools started to be used.
-    * ZRANGE, ZREVRANGE, ZRANGEBYSCORE, and ZREVRANGEBYSCORE now take an
+    * ZRANGE, ZREVRANGE, ZRANGEBYSCORE, and ZREVRANGEBYSCORE now take an additional optional argument, score_cast_func, which is a callable used to cast the score value in the return type. The default is float.
-    * Removed the PUBLISH method from the PubSub class. Connections that are
+    * Removed the PUBLISH method from the PubSub class. Connections that are [P]SUBSCRIBEd cannot issue PUBLISH commands, so it doesn't make sense to have it here.
-    * Pipelines now contain WATCH and UNWATCH. Calling WATCH or UNWATCH from
+    * Pipelines now contain WATCH and UNWATCH. Calling WATCH or UNWATCH from the base client class will result in a deprecation warning. After WATCHing one or more keys, the pipeline will be placed in immediate execution mode until UNWATCH or MULTI are called. Refer to the new pipeline docs in the README for more information. Thanks to David Wolever and Randall Leeds for greatly helping with this.
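The 2.4.7 context-manager and WATCH workflow above, sketched with an invented key (real code would retry on WatchError):

```python
import redis

r = redis.Redis(host="localhost", port=6379)

# The context manager guarantees the connection is returned to the pool,
# even if an error occurs mid-transaction.
with r.pipeline() as pipe:
    pipe.watch("balance")             # immediate-execution mode
    balance = int(pipe.get("balance") or 0)
    pipe.multi()                      # back to buffered mode
    pipe.set("balance", balance + 10)
    pipe.execute()
```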
 * 2.4.5
-    * The PythonParser now works better when reading zero length strings.
+    * The PythonParser now works better when reading zero length strings.
 * 2.4.4
-    * Fixed a typo introduced in 2.4.3
+    * Fixed a typo introduced in 2.4.3
 * 2.4.3
-    * Fixed a bug in the UnixDomainSocketConnection caused when trying to
+    * Fixed a bug in the UnixDomainSocketConnection caused when trying to form an error message after a socket error.
 * 2.4.2
-    * Fixed a bug in pipeline that caused an exception while trying to
+    * Fixed a bug in pipeline that caused an exception while trying to reconnect after a connection timeout.
 * 2.4.1
-    * Fixed a bug in the PythonParser if disconnect is called before connect.
+    * Fixed a bug in the PythonParser if disconnect is called before connect.
 * 2.4.0
-    * WARNING: 2.4 contains several backwards incompatible changes.
-    * Completely refactored Connection objects. Moved much of the Redis
+    * WARNING: 2.4 contains several backwards incompatible changes.
+    * Completely refactored Connection objects. Moved much of the Redis protocol packing for requests here, and eliminated the nasty dependencies it had on the client to do AUTH and SELECT commands on connect.
-    * Connection objects now have a parser attribute. Parsers are responsible
+    * Connection objects now have a parser attribute. Parsers are responsible for reading data Redis sends. Two parsers ship with redis-py: a PythonParser and the HiRedis parser. redis-py will automatically use the HiRedis parser if you have the Python hiredis module installed, otherwise it will fall back to the PythonParser. You can force one or the other, or even an external one, by passing the `parser_class` argument to ConnectionPool.
-    * Added a UnixDomainSocketConnection for users wanting to talk to the Redis
+    * Added a UnixDomainSocketConnection for users wanting to talk to the Redis instance running on a local machine only. You can use this connection by passing it to the `connection_class` argument of the ConnectionPool.
-    * Connections no longer derive from threading.local. See threading.local
+    * Connections no longer derive from threading.local. See threading.local note below.
-    * ConnectionPool has been completely refactored. The ConnectionPool now
+    * ConnectionPool has been completely refactored. The ConnectionPool now maintains a list of connections. The redis-py client only hangs on to a ConnectionPool instance, calling get_connection() anytime it needs to send a command. When get_connection() is called, the command name and
@@ -1030,83 +1050,83 @@
 belong to and return a connection to it. ConnectionPool also implements disconnect() to force all connections in the pool to disconnect from the Redis server.
-    * redis-py no longer support the SELECT command. You can still connect to
+    * redis-py no longer supports the SELECT command. You can still connect to a specific database by specifying it when instantiating a client instance or by creating a connection pool. If you need to talk to multiple databases within your application, you should use a separate client instance for each database you want to talk to.
-    * Completely refactored Publish/Subscribe support. The subscribe and listen
+    * Completely refactored Publish/Subscribe support. The subscribe and listen commands are no longer available on the redis-py Client class. Instead, the `pubsub` method returns an instance of the PubSub class which contains all publish/subscribe support. Note, you can still PUBLISH from the redis-py client class if you desire.
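The 2.4.0 PubSub split described above survives in the modern API; a minimal sketch using today's get_message() helper:

```python
import redis

r = redis.Redis(host="localhost", port=6379)

# All subscribe support lives on the PubSub object.
p = r.pubsub()
p.subscribe("news")
r.publish("news", "hello")

message = p.get_message(timeout=1.0)  # subscription confirmation
message = p.get_message(timeout=1.0)  # the published message
```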
-    * Removed support for all previously deprecated commands or options.
-    * redis-py no longer uses threading.local in any way. Since the Client
+    * Removed support for all previously deprecated commands or options.
+    * redis-py no longer uses threading.local in any way. Since the Client class no longer holds on to a connection, it's no longer needed. You can now pass client instances between threads, and commands run on those threads will retrieve an available connection from the pool, use it and release it. It should now be trivial to use redis-py with eventlet or greenlet.
-    * ZADD now accepts pairs of value=score keyword arguments. This should help
+    * ZADD now accepts pairs of value=score keyword arguments. This should help resolve the long-standing #72. The older value and score arguments have been deprecated in favor of the keyword argument style.
-    * Client instances now get their own copy of RESPONSE_CALLBACKS. The new
+    * Client instances now get their own copy of RESPONSE_CALLBACKS. The new set_response_callback method adds a user defined callback to the instance.
-    * Support Jython, fixing #97. Thanks to Adam Vandenberg for the patch.
-    * Using __getitem__ now properly raises a KeyError when the key is not
+    * Support Jython, fixing #97. Thanks to Adam Vandenberg for the patch.
+    * Using __getitem__ now properly raises a KeyError when the key is not found. Thanks Ionuț Arțăriși for the patch.
-    * Newer Redis versions return a LOADING message for some commands while
+    * Newer Redis versions return a LOADING message for some commands while the database is loading from disk during server start. This could cause problems with SELECT. We now force a socket disconnection prior to raising a ResponseError so subsequent connections have to reconnect and re-select the appropriate database. Thanks to Benjamin Anderson for finding and fixing this.
 * 2.2.4
-    * WARNING: Potential backwards incompatible change - Changed order of
+    * WARNING: Potential backwards incompatible change - Changed order of parameters of ZREVRANGEBYSCORE to match those of the actual Redis command. This is only backwards-incompatible if you were passing max and min via keyword args. If passing by normal args, nothing in user code should have to change. Thanks Stéphane Angel for the fix.
-    * Fixed INFO to properly parse the Redis data correctly for both 2.2.x and
+    * Fixed INFO to properly parse the Redis data for both 2.2.x and 2.3+. Thanks Stéphane Angel for the fix.
-    * Lock objects now store their timeout value as a float. This allows floats
+    * Lock objects now store their timeout value as a float. This allows floats to be used as timeout values. No changes to existing code required.
-    * WATCH now supports multiple keys. Thanks Rich Schumacher.
-    * Broke out some code that was Python 2.4 incompatible. redis-py should
+    * WATCH now supports multiple keys. Thanks Rich Schumacher.
+    * Broke out some code that was Python 2.4 incompatible. redis-py should now be usable on 2.4, but this hasn't actually been tested. Thanks Dan Colish for the patch.
-    * Optimized some code using izip and islice. Should have a pretty good
+    * Optimized some code using izip and islice. Should have a pretty good speed up on larger data sets. Thanks Dan Colish.
-    * Better error handling when submitting an empty mapping to HMSET. Thanks
+    * Better error handling when submitting an empty mapping to HMSET. Thanks Dan Colish.
-    * Subscription status is now reset after every (re)connection.
+    * Subscription status is now reset after every (re)connection.
 * 2.2.3
-    * Added support for Hiredis. To use, simply "pip install hiredis" or
+    * Added support for Hiredis. To use, simply "pip install hiredis" or "easy_install hiredis". Thanks to Pieter Noordhuis for the hiredis-py bindings and the patch to redis-py.
-    * The connection class is chosen based on whether hiredis is installed
+    * The connection class is chosen based on whether hiredis is installed or not. To force the use of the PythonConnection, simply create your own ConnectionPool instance with the connection_class argument assigned to the PythonConnection class.
-    * Added missing command ZREVRANGEBYSCORE. Thanks Jay Baird for the patch.
-    * The INFO command should be parsed correctly on 2.2.x server versions
+    * Added missing command ZREVRANGEBYSCORE. Thanks Jay Baird for the patch.
+    * The INFO command should be parsed correctly on 2.2.x server versions and is backwards compatible with older versions. Thanks Brett Hoerner.
 * 2.2.2
-    * Fixed a bug in ZREVRANK where retrieving the rank of a value not in
+    * Fixed a bug in ZREVRANK where retrieving the rank of a value not in the zset would raise an error.
-    * Fixed a bug in Connection.send where the errno import was getting
+    * Fixed a bug in Connection.send where the errno import was getting overwritten by a local variable.
-    * Fixed a bug in SLAVEOF when promoting an existing slave to a master.
-    * Reverted change of download URL back to redis-VERSION.tar.gz. 2.2.1's
+    * Fixed a bug in SLAVEOF when promoting an existing slave to a master.
+    * Reverted change of download URL back to redis-VERSION.tar.gz. 2.2.1's change of this actually broke PyPI for pip installs. Sorry!
 * 2.2.1
-    * Changed archive name to redis-py-VERSION.tar.gz to not conflict
+    * Changed archive name to redis-py-VERSION.tar.gz to not conflict with the Redis server archive.
 * 2.2.0
-    * Implemented SLAVEOF
-    * Implemented CONFIG as config_get and config_set
-    * Implemented GETBIT/SETBIT
-    * Implemented BRPOPLPUSH
-    * Implemented STRLEN
-    * Implemented PERSIST
-    * Implemented SETRANGE
+    * Implemented SLAVEOF
+    * Implemented CONFIG as config_get and config_set
+    * Implemented GETBIT/SETBIT
+    * Implemented BRPOPLPUSH
+    * Implemented STRLEN
+    * Implemented PERSIST
+    * Implemented SETRANGE
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index e31ec3491e..1081f4cb46 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -2,15 +2,15 @@
 ## Introduction
-First off, thank you for considering contributing to redis-py. We value
-community contributions!
+We appreciate your interest in considering contributing to redis-py.
+Community contributions mean a lot to us.
-## Contributions We Need
+## Contributions we need
-You may already know what you want to contribute \-- a fix for a bug you
+You may already know how you'd like to contribute, whether it's a fix for a bug you
 encountered, or a new feature your team wants to use.
-If you don't know what to contribute, keep an open mind! Improving
+If you don't know where to start: improving
 documentation, bug triaging, and writing tutorials are all examples of helpful contributions that mean less work for you.
@@ -38,8 +38,9 @@
 Here's how to get started with your code contribution:
    a. python -m venv .venv
    b. source .venv/bin/activate
    c. pip install -r dev_requirements.txt
+   d. pip install -r requirements.txt
-4. If you need a development environment, run `invoke devenv`
+4. If you need a development environment, run `invoke devenv`.
Note: this relies on docker-compose to build environments, and assumes that you have a version supporting [docker profiles](https://docs.docker.com/compose/profiles/). 5. While developing, make sure the tests pass by running `invoke tests` 6. If you like the change and think the project could use it, send a pull request @@ -59,7 +60,6 @@ can execute docker and its various commands. - Three sentinel Redis nodes - A redis cluster - An stunnel docker, fronting the master Redis node -- A Redis node, running unstable - the latest redis The replica node, is a replica of the master node, using the [leader-follower replication](https://redis.io/topics/replication) @@ -166,19 +166,19 @@ When filing an issue, make sure to answer these five questions: 4. What did you expect to see? 5. What did you see instead? -## How to Suggest a Feature or Enhancement +## Suggest a feature or enhancement If you'd like to contribute a new feature, make sure you check our issue list to see if someone has already proposed it. Work may already -be under way on the feature you want -- or we may have rejected a +be underway on the feature you want, or we may have rejected a feature like it already. If you don't see anything, open a new issue that describes the feature you would like and how it should work. -## Code Review Process +## Code review process -The core team looks at Pull Requests on a regular basis. We will give -feedback as as soon as possible. After feedback, we expect a response +The core team regularly looks at pull requests. We will provide +feedback as soon as possible. After receiving our feedback, please respond within two weeks. After that time, we may close your PR if it isn't showing any activity. diff --git a/LICENSE b/LICENSE index 00aee10d6a..8509ccd678 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2022, Redis, inc. +Copyright (c) 2022-2023, Redis, inc. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index c02483fe93..f8c3a78ae7 100644 --- a/README.md +++ b/README.md @@ -13,11 +13,9 @@ The Python interface to the Redis key-value store. --------------------------------------------- -## Python Notice +**Note:** redis-py 5.0 will be the last version of redis-py to support Python 3.7, as it has reached [end of life](https://devguide.python.org/versions/). redis-py 5.1 will support Python 3.8+. -redis-py 4.3.x will be the last generation of redis-py to support python 3.6 as it has been [End of Life'd](https://www.python.org/dev/peps/pep-0494/#schedule-last-security-only-release). Async support was introduced in redis-py 4.2.x thanks to [aioredis](https://github.com/aio-libs/aioredis-py), which necessitates this change. We will continue to maintain 3.6 support as long as possible - but the plan is for redis-py version 4.4+ to officially remove 3.6. - ---------------------------- +--------------------------------------------- ## Installation @@ -37,11 +35,24 @@ For faster performance, install redis with hiredis support, this provides a comp By default, if hiredis >= 1.0 is available, redis-py will attempt to use it for response parsing. ``` bash -$ pip install redis[hiredis] +$ pip install "redis[hiredis]" ``` Looking for a high-level library to handle object mapping? See [redis-om-python](https://github.com/redis/redis-om-python)!
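As a quick sanity check that the compiled parser is actually in use after installing the extra, a minimal sketch (it assumes `redis.utils.HIREDIS_AVAILABLE`, the flag redis-py sets when hiredis can be imported):

``` python
>>> import redis
>>> from redis.utils import HIREDIS_AVAILABLE  # True only when hiredis is installed
>>> HIREDIS_AVAILABLE
True
>>> r = redis.Redis(host='localhost', port=6379)
>>> r.ping()  # responses are parsed by the hiredis-backed parser when available
True
```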
+## Supported Redis Versions + +The most recent version of this library supports redis versions [5.0](https://github.com/redis/redis/blob/5.0/00-RELEASENOTES), [6.0](https://github.com/redis/redis/blob/6.0/00-RELEASENOTES), [6.2](https://github.com/redis/redis/blob/6.2/00-RELEASENOTES), and [7.0](https://github.com/redis/redis/blob/7.0/00-RELEASENOTES). + +The table below shows version compatibility between the most recent library versions and redis versions. + +| Library version | Supported redis versions | +|-----------------|-------------------| +| 3.5.3 | <= 6.2 Family of releases | +| >= 4.5.0 | Version 5.0 to 7.0 | +| >= 5.0.0 | Version 5.0 to current | + + ## Usage ### Basic Example @@ -57,6 +68,15 @@ b'bar' The above code connects to localhost on port 6379, sets a value in Redis, and retrieves it. All responses are returned as bytes in Python, to receive decoded strings, set *decode_responses=True*. For this, and more connection options, see [these examples](https://redis.readthedocs.io/en/stable/examples.html). + +#### RESP3 Support +To enable support for RESP3, ensure you have at least version 5.0 of the client, and change your connection object to include *protocol=3*. + +``` python +>>> import redis +>>> r = redis.Redis(host='localhost', port=6379, db=0, protocol=3) +``` + ### Connection Pools By default, redis-py uses a connection pool to manage connections. Each instance of a Redis class receives its own connection pool. You can however define your own [redis.ConnectionPool](https://redis.readthedocs.io/en/stable/connections.html#connection-pools). diff --git a/benchmarks/socket_read_size.py b/benchmarks/socket_read_size.py index 3427956ced..544c733178 100644 --- a/benchmarks/socket_read_size.py +++ b/benchmarks/socket_read_size.py @@ -1,12 +1,12 @@ from base import Benchmark -from redis.connection import HiredisParser, PythonParser +from redis.connection import PythonParser, _HiredisParser class SocketReadBenchmark(Benchmark): ARGUMENTS = ( - {"name": "parser", "values": [PythonParser, HiredisParser]}, + {"name": "parser", "values": [PythonParser, _HiredisParser]}, { "name": "value_size", "values": [10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000], diff --git a/dev_requirements.txt b/dev_requirements.txt index 8285b0456f..cdb3774ab6 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -1,18 +1,18 @@ click==8.0.4 black==22.3.0 flake8==5.0.4 +flake8-isort==6.0.0 flynt~=0.69.0 -isort==5.10.1 mock==4.0.3 packaging>=20.4 pytest==7.2.0 -pytest-timeout==2.0.1 +pytest-timeout==2.1.0 pytest-asyncio>=0.20.2 tox==3.27.1 -tox-docker==3.1.0 invoke==1.7.3 pytest-cov>=4.0.0 vulture>=2.3.0 ujson>=4.2.0 wheel>=0.30.0 +urllib3<2 uvloop diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000000..17d4b23977 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,109 @@ +--- + +version: "3.8" + +services: + + redis: + image: redis/redis-stack-server:edge + container_name: redis-standalone + ports: + - 6379:6379 + environment: + - "REDIS_ARGS=--enable-debug-command yes --enable-module-command yes" + profiles: + - standalone + - sentinel + - replica + - all + + replica: + image: redis/redis-stack-server:edge + container_name: redis-replica + depends_on: + - redis + environment: + - "REDIS_ARGS=--replicaof redis 6379" + ports: + - 6380:6379 + profiles: + - replica + - all + + cluster: + container_name: redis-cluster + build: + context: .
+ dockerfile: dockers/Dockerfile.cluster + ports: + - 16379:16379 + - 16380:16380 + - 16381:16381 + - 16382:16382 + - 16383:16383 + - 16384:16384 + volumes: + - "./dockers/cluster.redis.conf:/redis.conf:ro" + profiles: + - cluster + - all + + stunnel: + image: redisfab/stunnel:latest + depends_on: + - redis + ports: + - 6666:6666 + profiles: + - all + - standalone + - ssl + volumes: + - "./dockers/stunnel/conf:/etc/stunnel/conf.d:ro" + - "./dockers/stunnel/keys:/etc/stunnel/keys:ro" + + sentinel: + image: redis/redis-stack-server:edge + container_name: redis-sentinel + depends_on: + - redis + environment: + - "REDIS_ARGS=--port 26379" + entrypoint: "/opt/redis-stack/bin/redis-sentinel /redis.conf --port 26379" + ports: + - 26379:26379 + volumes: + - "./dockers/sentinel.conf:/redis.conf" + profiles: + - sentinel + - all + + sentinel2: + image: redis/redis-stack-server:edge + container_name: redis-sentinel2 + depends_on: + - redis + environment: + - "REDIS_ARGS=--port 26380" + entrypoint: "/opt/redis-stack/bin/redis-sentinel /redis.conf --port 26380" + ports: + - 26380:26380 + volumes: + - "./dockers/sentinel.conf:/redis.conf" + profiles: + - sentinel + - all + + sentinel3: + image: redis/redis-stack-server:edge + container_name: redis-sentinel3 + depends_on: + - redis + entrypoint: "/opt/redis-stack/bin/redis-sentinel /redis.conf --port 26381" + ports: + - 26381:26381 + volumes: + - "./dockers/sentinel.conf:/redis.conf" + profiles: + - sentinel + - all diff --git a/docker/base/Dockerfile b/docker/base/Dockerfile deleted file mode 100644 index c76d15db36..0000000000 --- a/docker/base/Dockerfile +++ /dev/null @@ -1,4 +0,0 @@ -# produces redisfab/redis-py:6.2.6 -FROM redis:6.2.6-buster - -CMD ["redis-server", "/redis.conf"] diff --git a/docker/base/Dockerfile.cluster b/docker/base/Dockerfile.cluster deleted file mode 100644 index 5c246dcf28..0000000000 --- a/docker/base/Dockerfile.cluster +++ /dev/null @@ -1,11 +0,0 @@ -# produces redisfab/redis-py-cluster:6.2.6 -FROM redis:6.2.6-buster - -COPY create_cluster.sh /create_cluster.sh -RUN chmod +x /create_cluster.sh - -EXPOSE 16379 16380 16381 16382 16383 16384 - -ENV START_PORT=16379 -ENV END_PORT=16384 -CMD /create_cluster.sh diff --git a/docker/base/Dockerfile.cluster4 b/docker/base/Dockerfile.cluster4 deleted file mode 100644 index 3158d6edd4..0000000000 --- a/docker/base/Dockerfile.cluster4 +++ /dev/null @@ -1,9 +0,0 @@ -# produces redisfab/redis-py-cluster:4.0 -FROM redis:4.0-buster - -COPY create_cluster4.sh /create_cluster4.sh -RUN chmod +x /create_cluster4.sh - -EXPOSE 16391 16392 16393 16394 16395 16396 - -CMD [ "/create_cluster4.sh"] \ No newline at end of file diff --git a/docker/base/Dockerfile.cluster5 b/docker/base/Dockerfile.cluster5 deleted file mode 100644 index 3becfc853a..0000000000 --- a/docker/base/Dockerfile.cluster5 +++ /dev/null @@ -1,9 +0,0 @@ -# produces redisfab/redis-py-cluster:5.0 -FROM redis:5.0-buster - -COPY create_cluster5.sh /create_cluster5.sh -RUN chmod +x /create_cluster5.sh - -EXPOSE 16385 16386 16387 16388 16389 16390 - -CMD [ "/create_cluster5.sh"] \ No newline at end of file diff --git a/docker/base/Dockerfile.redis4 b/docker/base/Dockerfile.redis4 deleted file mode 100644 index 7528ac1631..0000000000 --- a/docker/base/Dockerfile.redis4 +++ /dev/null @@ -1,4 +0,0 @@ -# produces redisfab/redis-py:4.0 -FROM redis:4.0-buster - -CMD ["redis-server", "/redis.conf"] \ No newline at end of file diff --git a/docker/base/Dockerfile.redis5 b/docker/base/Dockerfile.redis5 deleted file mode 100644 index 
6bcbe20bfc..0000000000 --- a/docker/base/Dockerfile.redis5 +++ /dev/null @@ -1,4 +0,0 @@ -# produces redisfab/redis-py:5.0 -FROM redis:5.0-buster - -CMD ["redis-server", "/redis.conf"] \ No newline at end of file diff --git a/docker/base/Dockerfile.redismod_cluster b/docker/base/Dockerfile.redismod_cluster deleted file mode 100644 index 5b80e495fb..0000000000 --- a/docker/base/Dockerfile.redismod_cluster +++ /dev/null @@ -1,12 +0,0 @@ -# produces redisfab/redis-py-modcluster:6.2.6 -FROM redislabs/redismod:edge - -COPY create_redismod_cluster.sh /create_redismod_cluster.sh -RUN chmod +x /create_redismod_cluster.sh - -EXPOSE 46379 46380 46381 46382 46383 46384 - -ENV START_PORT=46379 -ENV END_PORT=46384 -ENTRYPOINT [] -CMD /create_redismod_cluster.sh diff --git a/docker/base/Dockerfile.sentinel b/docker/base/Dockerfile.sentinel deleted file mode 100644 index ef659e3004..0000000000 --- a/docker/base/Dockerfile.sentinel +++ /dev/null @@ -1,4 +0,0 @@ -# produces redisfab/redis-py-sentinel:6.2.6 -FROM redis:6.2.6-buster - -CMD ["redis-sentinel", "/sentinel.conf"] diff --git a/docker/base/Dockerfile.sentinel4 b/docker/base/Dockerfile.sentinel4 deleted file mode 100644 index 45bb03e88e..0000000000 --- a/docker/base/Dockerfile.sentinel4 +++ /dev/null @@ -1,4 +0,0 @@ -# produces redisfab/redis-py-sentinel:4.0 -FROM redis:4.0-buster - -CMD ["redis-sentinel", "/sentinel.conf"] \ No newline at end of file diff --git a/docker/base/Dockerfile.sentinel5 b/docker/base/Dockerfile.sentinel5 deleted file mode 100644 index 6958154e46..0000000000 --- a/docker/base/Dockerfile.sentinel5 +++ /dev/null @@ -1,4 +0,0 @@ -# produces redisfab/redis-py-sentinel:5.0 -FROM redis:5.0-buster - -CMD ["redis-sentinel", "/sentinel.conf"] \ No newline at end of file diff --git a/docker/base/Dockerfile.stunnel b/docker/base/Dockerfile.stunnel deleted file mode 100644 index bf4510907c..0000000000 --- a/docker/base/Dockerfile.stunnel +++ /dev/null @@ -1,11 +0,0 @@ -# produces redisfab/stunnel:latest -FROM ubuntu:18.04 - -RUN apt-get update -qq --fix-missing -RUN apt-get upgrade -qqy -RUN apt install -qqy stunnel -RUN mkdir -p /etc/stunnel/conf.d -RUN echo "foreground = yes\ninclude = /etc/stunnel/conf.d" > /etc/stunnel/stunnel.conf -RUN chown -R root:root /etc/stunnel/ - -CMD ["/usr/bin/stunnel"] diff --git a/docker/base/Dockerfile.unstable b/docker/base/Dockerfile.unstable deleted file mode 100644 index ab5b7fc6fb..0000000000 --- a/docker/base/Dockerfile.unstable +++ /dev/null @@ -1,18 +0,0 @@ -# produces redisfab/redis-py:unstable -FROM ubuntu:bionic as builder -RUN apt-get update -RUN apt-get upgrade -y -RUN apt-get install -y build-essential git -RUN mkdir /build -WORKDIR /build -RUN git clone https://github.com/redis/redis -WORKDIR /build/redis -RUN make - -FROM ubuntu:bionic as runner -COPY --from=builder /build/redis/src/redis-server /usr/bin/redis-server -COPY --from=builder /build/redis/src/redis-cli /usr/bin/redis-cli -COPY --from=builder /build/redis/src/redis-sentinel /usr/bin/redis-sentinel - -EXPOSE 6379 -CMD ["redis-server", "/redis.conf"] diff --git a/docker/base/Dockerfile.unstable_cluster b/docker/base/Dockerfile.unstable_cluster deleted file mode 100644 index 2e3ed55371..0000000000 --- a/docker/base/Dockerfile.unstable_cluster +++ /dev/null @@ -1,11 +0,0 @@ -# produces redisfab/redis-py-cluster:6.2.6 -FROM redisfab/redis-py:unstable-bionic - -COPY create_cluster.sh /create_cluster.sh -RUN chmod +x /create_cluster.sh - -EXPOSE 6372 6373 6374 6375 6376 6377 - -ENV START_PORT=6372 -ENV END_PORT=6377 -CMD 
["/create_cluster.sh"] diff --git a/docker/base/Dockerfile.unstable_sentinel b/docker/base/Dockerfile.unstable_sentinel deleted file mode 100644 index fe6d062de8..0000000000 --- a/docker/base/Dockerfile.unstable_sentinel +++ /dev/null @@ -1,17 +0,0 @@ -# produces redisfab/redis-py-sentinel:unstable -FROM ubuntu:bionic as builder -RUN apt-get update -RUN apt-get upgrade -y -RUN apt-get install -y build-essential git -RUN mkdir /build -WORKDIR /build -RUN git clone https://github.com/redis/redis -WORKDIR /build/redis -RUN make - -FROM ubuntu:bionic as runner -COPY --from=builder /build/redis/src/redis-server /usr/bin/redis-server -COPY --from=builder /build/redis/src/redis-cli /usr/bin/redis-cli -COPY --from=builder /build/redis/src/redis-sentinel /usr/bin/redis-sentinel - -CMD ["redis-sentinel", "/sentinel.conf"] diff --git a/docker/base/README.md b/docker/base/README.md deleted file mode 100644 index a2f26a8106..0000000000 --- a/docker/base/README.md +++ /dev/null @@ -1 +0,0 @@ -Dockers in this folder are built, and uploaded to the redisfab dockerhub store. diff --git a/docker/base/create_cluster4.sh b/docker/base/create_cluster4.sh deleted file mode 100755 index a39da58784..0000000000 --- a/docker/base/create_cluster4.sh +++ /dev/null @@ -1,26 +0,0 @@ -#! /bin/bash -mkdir -p /nodes -touch /nodes/nodemap -for PORT in $(seq 16391 16396); do - mkdir -p /nodes/$PORT - if [[ -e /redis.conf ]]; then - cp /redis.conf /nodes/$PORT/redis.conf - else - touch /nodes/$PORT/redis.conf - fi - cat << EOF >> /nodes/$PORT/redis.conf -port ${PORT} -cluster-enabled yes -daemonize yes -logfile /redis.log -dir /nodes/$PORT -EOF - redis-server /nodes/$PORT/redis.conf - if [ $? -ne 0 ]; then - echo "Redis failed to start, exiting." - exit 3 - fi - echo 127.0.0.1:$PORT >> /nodes/nodemap -done -echo yes | redis-cli --cluster create $(seq -f 127.0.0.1:%g 16391 16396) --cluster-replicas 1 -tail -f /redis.log \ No newline at end of file diff --git a/docker/base/create_cluster5.sh b/docker/base/create_cluster5.sh deleted file mode 100755 index 0c63d8e910..0000000000 --- a/docker/base/create_cluster5.sh +++ /dev/null @@ -1,26 +0,0 @@ -#! /bin/bash -mkdir -p /nodes -touch /nodes/nodemap -for PORT in $(seq 16385 16390); do - mkdir -p /nodes/$PORT - if [[ -e /redis.conf ]]; then - cp /redis.conf /nodes/$PORT/redis.conf - else - touch /nodes/$PORT/redis.conf - fi - cat << EOF >> /nodes/$PORT/redis.conf -port ${PORT} -cluster-enabled yes -daemonize yes -logfile /redis.log -dir /nodes/$PORT -EOF - redis-server /nodes/$PORT/redis.conf - if [ $? -ne 0 ]; then - echo "Redis failed to start, exiting." - exit 3 - fi - echo 127.0.0.1:$PORT >> /nodes/nodemap -done -echo yes | redis-cli --cluster create $(seq -f 127.0.0.1:%g 16385 16390) --cluster-replicas 1 -tail -f /redis.log \ No newline at end of file diff --git a/docker/base/create_redismod_cluster.sh b/docker/base/create_redismod_cluster.sh deleted file mode 100755 index 20443a4c42..0000000000 --- a/docker/base/create_redismod_cluster.sh +++ /dev/null @@ -1,46 +0,0 @@ -#! /bin/bash - -mkdir -p /nodes -touch /nodes/nodemap -if [ -z ${START_PORT} ]; then - START_PORT=46379 -fi -if [ -z ${END_PORT} ]; then - END_PORT=46384 -fi -if [ ! 
-z "$3" ]; then - START_PORT=$2 - START_PORT=$3 -fi -echo "STARTING: ${START_PORT}" -echo "ENDING: ${END_PORT}" - -for PORT in `seq ${START_PORT} ${END_PORT}`; do - mkdir -p /nodes/$PORT - if [[ -e /redis.conf ]]; then - cp /redis.conf /nodes/$PORT/redis.conf - else - touch /nodes/$PORT/redis.conf - fi - cat << EOF >> /nodes/$PORT/redis.conf -port ${PORT} -cluster-enabled yes -daemonize yes -logfile /redis.log -dir /nodes/$PORT -EOF - - set -x - redis-server /nodes/$PORT/redis.conf - if [ $? -ne 0 ]; then - echo "Redis failed to start, exiting." - continue - fi - echo 127.0.0.1:$PORT >> /nodes/nodemap -done -if [ -z "${REDIS_PASSWORD}" ]; then - echo yes | redis-cli --cluster create `seq -f 127.0.0.1:%g ${START_PORT} ${END_PORT}` --cluster-replicas 1 -else - echo yes | redis-cli -a ${REDIS_PASSWORD} --cluster create `seq -f 127.0.0.1:%g ${START_PORT} ${END_PORT}` --cluster-replicas 1 -fi -tail -f /redis.log diff --git a/docker/cluster/redis.conf b/docker/cluster/redis.conf deleted file mode 100644 index dff658c79b..0000000000 --- a/docker/cluster/redis.conf +++ /dev/null @@ -1,3 +0,0 @@ -# Redis Cluster config file will be shared across all nodes. -# Do not change the following configurations that are already set: -# port, cluster-enabled, daemonize, logfile, dir diff --git a/docker/redis4/master/redis.conf b/docker/redis4/master/redis.conf deleted file mode 100644 index b7ed0ebf00..0000000000 --- a/docker/redis4/master/redis.conf +++ /dev/null @@ -1,2 +0,0 @@ -port 6381 -save "" diff --git a/docker/redis4/sentinel/sentinel_1.conf b/docker/redis4/sentinel/sentinel_1.conf deleted file mode 100644 index cfee17c051..0000000000 --- a/docker/redis4/sentinel/sentinel_1.conf +++ /dev/null @@ -1,6 +0,0 @@ -port 26385 - -sentinel monitor redis-py-test 127.0.0.1 6381 2 -sentinel down-after-milliseconds redis-py-test 5000 -sentinel failover-timeout redis-py-test 60000 -sentinel parallel-syncs redis-py-test 1 diff --git a/docker/redis4/sentinel/sentinel_2.conf b/docker/redis4/sentinel/sentinel_2.conf deleted file mode 100644 index 68d930aea8..0000000000 --- a/docker/redis4/sentinel/sentinel_2.conf +++ /dev/null @@ -1,6 +0,0 @@ -port 26386 - -sentinel monitor redis-py-test 127.0.0.1 6381 2 -sentinel down-after-milliseconds redis-py-test 5000 -sentinel failover-timeout redis-py-test 60000 -sentinel parallel-syncs redis-py-test 1 \ No newline at end of file diff --git a/docker/redis4/sentinel/sentinel_3.conf b/docker/redis4/sentinel/sentinel_3.conf deleted file mode 100644 index 60abf65c9b..0000000000 --- a/docker/redis4/sentinel/sentinel_3.conf +++ /dev/null @@ -1,6 +0,0 @@ -port 26387 - -sentinel monitor redis-py-test 127.0.0.1 6381 2 -sentinel down-after-milliseconds redis-py-test 5000 -sentinel failover-timeout redis-py-test 60000 -sentinel parallel-syncs redis-py-test 1 \ No newline at end of file diff --git a/docker/redis5/master/redis.conf b/docker/redis5/master/redis.conf deleted file mode 100644 index e479c48b28..0000000000 --- a/docker/redis5/master/redis.conf +++ /dev/null @@ -1,2 +0,0 @@ -port 6382 -save "" diff --git a/docker/redis5/replica/redis.conf b/docker/redis5/replica/redis.conf deleted file mode 100644 index a2dc9e0945..0000000000 --- a/docker/redis5/replica/redis.conf +++ /dev/null @@ -1,3 +0,0 @@ -port 6383 -save "" -replicaof master 6382 diff --git a/docker/redis5/sentinel/sentinel_1.conf b/docker/redis5/sentinel/sentinel_1.conf deleted file mode 100644 index c748a0ba72..0000000000 --- a/docker/redis5/sentinel/sentinel_1.conf +++ /dev/null @@ -1,6 +0,0 @@ -port 26382 - -sentinel 
monitor redis-py-test 127.0.0.1 6382 2 -sentinel down-after-milliseconds redis-py-test 5000 -sentinel failover-timeout redis-py-test 60000 -sentinel parallel-syncs redis-py-test 1 diff --git a/docker/redis5/sentinel/sentinel_2.conf b/docker/redis5/sentinel/sentinel_2.conf deleted file mode 100644 index 0a50c9a623..0000000000 --- a/docker/redis5/sentinel/sentinel_2.conf +++ /dev/null @@ -1,6 +0,0 @@ -port 26383 - -sentinel monitor redis-py-test 127.0.0.1 6382 2 -sentinel down-after-milliseconds redis-py-test 5000 -sentinel failover-timeout redis-py-test 60000 -sentinel parallel-syncs redis-py-test 1 \ No newline at end of file diff --git a/docker/redis5/sentinel/sentinel_3.conf b/docker/redis5/sentinel/sentinel_3.conf deleted file mode 100644 index a0e350ba0f..0000000000 --- a/docker/redis5/sentinel/sentinel_3.conf +++ /dev/null @@ -1,6 +0,0 @@ -port 26384 - -sentinel monitor redis-py-test 127.0.0.1 6383 2 -sentinel down-after-milliseconds redis-py-test 5000 -sentinel failover-timeout redis-py-test 60000 -sentinel parallel-syncs redis-py-test 1 \ No newline at end of file diff --git a/docker/redis6.2/master/redis.conf b/docker/redis6.2/master/redis.conf deleted file mode 100644 index 15a31b5a38..0000000000 --- a/docker/redis6.2/master/redis.conf +++ /dev/null @@ -1,2 +0,0 @@ -port 6379 -save "" diff --git a/docker/redis6.2/replica/redis.conf b/docker/redis6.2/replica/redis.conf deleted file mode 100644 index a76d402c5e..0000000000 --- a/docker/redis6.2/replica/redis.conf +++ /dev/null @@ -1,3 +0,0 @@ -port 6380 -save "" -replicaof master 6379 diff --git a/docker/redis6.2/sentinel/sentinel_2.conf b/docker/redis6.2/sentinel/sentinel_2.conf deleted file mode 100644 index 955621b872..0000000000 --- a/docker/redis6.2/sentinel/sentinel_2.conf +++ /dev/null @@ -1,6 +0,0 @@ -port 26380 - -sentinel monitor redis-py-test 127.0.0.1 6379 2 -sentinel down-after-milliseconds redis-py-test 5000 -sentinel failover-timeout redis-py-test 60000 -sentinel parallel-syncs redis-py-test 1 diff --git a/docker/redis6.2/sentinel/sentinel_3.conf b/docker/redis6.2/sentinel/sentinel_3.conf deleted file mode 100644 index 62c40512f1..0000000000 --- a/docker/redis6.2/sentinel/sentinel_3.conf +++ /dev/null @@ -1,6 +0,0 @@ -port 26381 - -sentinel monitor redis-py-test 127.0.0.1 6379 2 -sentinel down-after-milliseconds redis-py-test 5000 -sentinel failover-timeout redis-py-test 60000 -sentinel parallel-syncs redis-py-test 1 diff --git a/docker/redis7/master/redis.conf b/docker/redis7/master/redis.conf deleted file mode 100644 index ef57c1fe99..0000000000 --- a/docker/redis7/master/redis.conf +++ /dev/null @@ -1,4 +0,0 @@ -port 6379 -save "" -enable-debug-command yes -enable-module-command yes \ No newline at end of file diff --git a/docker/redismod_cluster/redis.conf b/docker/redismod_cluster/redis.conf deleted file mode 100644 index 48f06668a8..0000000000 --- a/docker/redismod_cluster/redis.conf +++ /dev/null @@ -1,8 +0,0 @@ -loadmodule /usr/lib/redis/modules/redisai.so -loadmodule /usr/lib/redis/modules/redisearch.so -loadmodule /usr/lib/redis/modules/redisgraph.so -loadmodule /usr/lib/redis/modules/redistimeseries.so -loadmodule /usr/lib/redis/modules/rejson.so -loadmodule /usr/lib/redis/modules/redisbloom.so -loadmodule /var/opt/redislabs/lib/modules/redisgears.so Plugin /var/opt/redislabs/modules/rg/plugin/gears_python.so Plugin /var/opt/redislabs/modules/rg/plugin/gears_jvm.so JvmOptions -Djava.class.path=/var/opt/redislabs/modules/rg/gear_runtime-jar-with-dependencies.jar JvmPath 
/var/opt/redislabs/modules/rg/OpenJDK/jdk-11.0.9.1+1/ - diff --git a/docker/unstable/redis.conf b/docker/unstable/redis.conf deleted file mode 100644 index 93a55cf3b3..0000000000 --- a/docker/unstable/redis.conf +++ /dev/null @@ -1,3 +0,0 @@ -port 6378 -protected-mode no -save "" diff --git a/docker/unstable_cluster/redis.conf b/docker/unstable_cluster/redis.conf deleted file mode 100644 index f307a63757..0000000000 --- a/docker/unstable_cluster/redis.conf +++ /dev/null @@ -1,4 +0,0 @@ -# Redis Cluster config file will be shared across all nodes. -# Do not change the following configurations that are already set: -# port, cluster-enabled, daemonize, logfile, dir -protected-mode no diff --git a/dockers/Dockerfile.cluster b/dockers/Dockerfile.cluster new file mode 100644 index 0000000000..3a0d73415e --- /dev/null +++ b/dockers/Dockerfile.cluster @@ -0,0 +1,7 @@ +FROM redis/redis-stack-server:edge as rss + +COPY dockers/create_cluster.sh /create_cluster.sh +RUN ls -R /opt/redis-stack +RUN chmod a+x /create_cluster.sh + +ENTRYPOINT [ "/create_cluster.sh"] diff --git a/dockers/cluster.redis.conf b/dockers/cluster.redis.conf new file mode 100644 index 0000000000..d4de46fbed --- /dev/null +++ b/dockers/cluster.redis.conf @@ -0,0 +1,8 @@ +protected-mode no +enable-debug-command yes +loadmodule /opt/redis-stack/lib/redisearch.so +loadmodule /opt/redis-stack/lib/redisgraph.so +loadmodule /opt/redis-stack/lib/redistimeseries.so +loadmodule /opt/redis-stack/lib/rejson.so +loadmodule /opt/redis-stack/lib/redisbloom.so +loadmodule /opt/redis-stack/lib/redisgears.so v8-plugin-path /opt/redis-stack/lib/libredisgears_v8_plugin.so diff --git a/docker/base/create_cluster.sh b/dockers/create_cluster.sh old mode 100755 new mode 100644 similarity index 75% rename from docker/base/create_cluster.sh rename to dockers/create_cluster.sh index fcb1b1cd8d..da9a0cb606 --- a/docker/base/create_cluster.sh +++ b/dockers/create_cluster.sh @@ -31,7 +31,8 @@ dir /nodes/$PORT EOF set -x - redis-server /nodes/$PORT/redis.conf + /opt/redis-stack/bin/redis-server /nodes/$PORT/redis.conf + sleep 1 if [ $? -ne 0 ]; then echo "Redis failed to start, exiting." 
continue @@ -39,8 +40,8 @@ EOF echo 127.0.0.1:$PORT >> /nodes/nodemap done if [ -z "${REDIS_PASSWORD}" ]; then - echo yes | redis-cli --cluster create `seq -f 127.0.0.1:%g ${START_PORT} ${END_PORT}` --cluster-replicas 1 + echo yes | /opt/redis-stack/bin/redis-cli --cluster create `seq -f 127.0.0.1:%g ${START_PORT} ${END_PORT}` --cluster-replicas 1 else - echo yes | redis-cli -a ${REDIS_PASSWORD} --cluster create `seq -f 127.0.0.1:%g ${START_PORT} ${END_PORT}` --cluster-replicas 1 + echo yes | /opt/redis-stack/bin/redis-cli -a ${REDIS_PASSWORD} --cluster create `seq -f 127.0.0.1:%g ${START_PORT} ${END_PORT}` --cluster-replicas 1 fi tail -f /redis.log diff --git a/docker/redis6.2/sentinel/sentinel_1.conf b/dockers/sentinel.conf similarity index 73% rename from docker/redis6.2/sentinel/sentinel_1.conf rename to dockers/sentinel.conf index bd2d830af3..1a33f53344 100644 --- a/docker/redis6.2/sentinel/sentinel_1.conf +++ b/dockers/sentinel.conf @@ -1,6 +1,4 @@ -port 26379 - sentinel monitor redis-py-test 127.0.0.1 6379 2 sentinel down-after-milliseconds redis-py-test 5000 sentinel failover-timeout redis-py-test 60000 -sentinel parallel-syncs redis-py-test 1 +sentinel parallel-syncs redis-py-test 1 \ No newline at end of file diff --git a/docker/stunnel/README b/dockers/stunnel/README similarity index 100% rename from docker/stunnel/README rename to dockers/stunnel/README diff --git a/docker/stunnel/conf/redis.conf b/dockers/stunnel/conf/redis.conf similarity index 83% rename from docker/stunnel/conf/redis.conf rename to dockers/stunnel/conf/redis.conf index 84f6d40133..a150d8b011 100644 --- a/docker/stunnel/conf/redis.conf +++ b/dockers/stunnel/conf/redis.conf @@ -1,6 +1,6 @@ [redis] accept = 6666 -connect = master:6379 +connect = redis:6379 cert = /etc/stunnel/keys/server-cert.pem key = /etc/stunnel/keys/server-key.pem verify = 0 diff --git a/docker/stunnel/create_certs.sh b/dockers/stunnel/create_certs.sh similarity index 100% rename from docker/stunnel/create_certs.sh rename to dockers/stunnel/create_certs.sh diff --git a/docker/stunnel/keys/ca-cert.pem b/dockers/stunnel/keys/ca-cert.pem similarity index 100% rename from docker/stunnel/keys/ca-cert.pem rename to dockers/stunnel/keys/ca-cert.pem diff --git a/docker/stunnel/keys/ca-key.pem b/dockers/stunnel/keys/ca-key.pem similarity index 100% rename from docker/stunnel/keys/ca-key.pem rename to dockers/stunnel/keys/ca-key.pem diff --git a/docker/stunnel/keys/client-cert.pem b/dockers/stunnel/keys/client-cert.pem similarity index 100% rename from docker/stunnel/keys/client-cert.pem rename to dockers/stunnel/keys/client-cert.pem diff --git a/docker/stunnel/keys/client-key.pem b/dockers/stunnel/keys/client-key.pem similarity index 100% rename from docker/stunnel/keys/client-key.pem rename to dockers/stunnel/keys/client-key.pem diff --git a/docker/stunnel/keys/client-req.pem b/dockers/stunnel/keys/client-req.pem similarity index 100% rename from docker/stunnel/keys/client-req.pem rename to dockers/stunnel/keys/client-req.pem diff --git a/docker/stunnel/keys/server-cert.pem b/dockers/stunnel/keys/server-cert.pem similarity index 100% rename from docker/stunnel/keys/server-cert.pem rename to dockers/stunnel/keys/server-cert.pem diff --git a/docker/stunnel/keys/server-key.pem b/dockers/stunnel/keys/server-key.pem similarity index 100% rename from docker/stunnel/keys/server-key.pem rename to dockers/stunnel/keys/server-key.pem diff --git a/docker/stunnel/keys/server-req.pem b/dockers/stunnel/keys/server-req.pem similarity index 100% rename from
docker/stunnel/keys/server-req.pem rename to dockers/stunnel/keys/server-req.pem diff --git a/docs/advanced_features.rst b/docs/advanced_features.rst index 4ad922fe72..fd29d2f684 100644 --- a/docs/advanced_features.rst +++ b/docs/advanced_features.rst @@ -37,7 +37,7 @@ the client and server. Pipelines are quite simple to use: -.. code:: pycon +.. code:: python >>> r = redis.Redis(...) >>> r.set('bing', 'baz') @@ -54,7 +54,7 @@ Pipelines are quite simple to use: For ease of use, all commands being buffered into the pipeline return the pipeline object itself. Therefore calls can be chained like: -.. code:: pycon +.. code:: python >>> pipe.set('foo', 'bar').sadd('faz', 'baz').incr('auto_number').execute() [True, True, 6] @@ -64,7 +64,7 @@ executed atomically as a group. This happens by default. If you want to disable the atomic nature of a pipeline but still want to buffer commands, you can turn off transactions. -.. code:: pycon +.. code:: python >>> pipe = r.pipeline(transaction=False) @@ -84,7 +84,7 @@ prior the execution of that transaction, the entire transaction will be canceled and a WatchError will be raised. To implement our own client-side INCR command, we could do something like this: -.. code:: pycon +.. code:: python >>> with r.pipeline() as pipe: ... while True: @@ -117,7 +117,7 @@ Pipeline is used as a context manager (as in the example above) reset() will be called automatically. Of course you can do this the manual way by explicitly calling reset(): -.. code:: pycon +.. code:: python >>> pipe = r.pipeline() >>> while True: @@ -137,7 +137,7 @@ that should expect a single parameter, a pipeline object, and any number of keys to be WATCHed. Our client-side INCR command above can be written like this, which is much easier to read: -.. code:: pycon +.. code:: python >>> def client_side_incr(pipe): ... current_value = pipe.get('OUR-SEQUENCE-KEY') @@ -162,10 +162,10 @@ instance will wait for all the nodes to respond before returning the result to the caller. Command responses are returned as a list sorted in the same order in which they were sent. Pipelines can be used to dramatically increase the throughput of Redis Cluster by significantly -reducing the the number of network round trips between the client and +reducing the number of network round trips between the client and the server. -.. code:: pycon +.. code:: python >>> with rc.pipeline() as pipe: ... pipe.set('foo', 'value1') @@ -198,7 +198,7 @@ Publish / Subscribe redis-py includes a PubSub object that subscribes to channels and listens for new messages. Creating a PubSub object is easy. -.. code:: pycon +.. code:: python >>> r = redis.Redis(...) >>> p = r.pubsub() @@ -206,7 +206,7 @@ listens for new messages. Creating a PubSub object is easy. Once a PubSub instance is created, channels and patterns can be subscribed to. -.. code:: pycon +.. code:: python >>> p.subscribe('my-first-channel', 'my-second-channel', ...) >>> p.psubscribe('my-*', ...) @@ -215,7 +215,7 @@ The PubSub instance is now subscribed to those channels/patterns. The subscription confirmations can be seen by reading messages from the PubSub instance. -.. code:: pycon +.. code:: python >>> p.get_message() {'pattern': None, 'type': 'subscribe', 'channel': b'my-second-channel', 'data': 1} @@ -240,7 +240,7 @@ following keys. Let's send a message now. -.. code:: pycon +.. code:: python # the publish method returns the number matching channel and pattern # subscriptions. 
'my-first-channel' matches both the 'my-first-channel' @@ -256,7 +256,7 @@ Let's send a message now. Unsubscribing works just like subscribing. If no arguments are passed to [p]unsubscribe, all channels or patterns will be unsubscribed from. -.. code:: pycon +.. code:: python >>> p.unsubscribe() >>> p.punsubscribe('my-*') @@ -279,7 +279,7 @@ the message dictionary is created and passed to the message handler. In this case, a None value is returned from get_message() since the message was already handled. -.. code:: pycon +.. code:: python >>> def my_handler(message): ... print('MY HANDLER: ', message['data']) @@ -305,7 +305,7 @@ passing ignore_subscribe_messages=True to r.pubsub(). This will cause all subscribe/unsubscribe messages to be read, but they won't bubble up to your application. -.. code:: pycon +.. code:: python >>> p = r.pubsub(ignore_subscribe_messages=True) >>> p.subscribe('my-channel') @@ -325,7 +325,7 @@ to a message handler. If there's no data to be read, get_message() will immediately return None. This makes it trivial to integrate into an existing event loop inside your application. -.. code:: pycon +.. code:: python >>> while True: >>> message = p.get_message() @@ -339,7 +339,7 @@ your application doesn't need to do anything else but receive and act on messages received from redis, listen() is an easy way to get up an running. -.. code:: pycon +.. code:: python >>> for message in p.listen(): ... # do something with the message @@ -360,7 +360,7 @@ handlers. Therefore, redis-py prevents you from calling run_in_thread() if you're subscribed to patterns or channels that don't have message handlers attached. -.. code:: pycon +.. code:: python >>> p.subscribe(**{'my-channel': my_handler}) >>> thread = p.run_in_thread(sleep_time=0.001) @@ -374,7 +374,7 @@ appropriately. The exception handler will take as arguments the exception itself, the pubsub object, and the worker thread returned by run_in_thread. -.. code:: pycon +.. code:: python >>> p.subscribe(**{'my-channel': my_handler}) >>> def exception_handler(ex, pubsub, thread): @@ -401,7 +401,7 @@ when reconnecting. Messages that were published while the client was disconnected cannot be delivered. When you're finished with a PubSub object, call its .close() method to shutdown the connection. -.. code:: pycon +.. code:: python >>> p = r.pubsub() >>> ... @@ -410,7 +410,7 @@ object, call its .close() method to shutdown the connection. The PUBSUB set of subcommands CHANNELS, NUMSUB and NUMPAT are also supported: -.. code:: pycon +.. code:: python >>> r.pubsub_channels() [b'foo', b'bar'] @@ -421,6 +421,38 @@ supported: >>> r.pubsub_numpat() 1204 +Sharded pubsub +~~~~~~~~~~~~~~ + +`Sharded pubsub `_ is a feature introduced with Redis 7.0, and fully supported by redis-py as of 5.0. It helps scale the usage of pub/sub in cluster mode, by having the cluster shard messages to nodes that own a slot for a shard channel. Here, the cluster ensures the published shard messages are forwarded to the appropriate nodes. Clients subscribe to a channel by connecting to either the master responsible for the slot, or any of its replicas. + +This makes use of the `SSUBSCRIBE `_ and `SPUBLISH `_ commands within Redis. + +The following, is a simplified example: + +.. 
code:: python + + >>> from redis.cluster import RedisCluster, ClusterNode + >>> r = RedisCluster(startup_nodes=[ClusterNode('localhost', 6379), ClusterNode('localhost', 6380)]) + >>> p = r.pubsub() + >>> p.ssubscribe('foo') + >>> # assume someone sends a message along the channel via a publish + >>> message = p.get_sharded_message() + +Similarly, the same process can be used to acquire sharded pubsub messages that have already been sent to a specific node, by passing the node to get_sharded_message: + +.. code:: python + + >>> from redis.cluster import RedisCluster, ClusterNode + >>> first_node = ClusterNode('localhost', 6379) + >>> second_node = ClusterNode('localhost', 6380) + >>> r = RedisCluster(startup_nodes=[first_node, second_node]) + >>> p = r.pubsub() + >>> p.ssubscribe('foo') + >>> # assume someone sends a message along the channel via a publish + >>> message = p.get_sharded_message(target_node=second_node) + + Monitor ~~~~~~~ @@ -428,7 +460,7 @@ redis-py includes a Monitor object that streams every command processed by the Redis server. Use listen() on the Monitor object to block until a command is received. -.. code:: pycon +.. code:: python >>> r = redis.Redis(...) >>> with r.monitor() as m: diff --git a/docs/clustering.rst b/docs/clustering.rst index 34cb7f1f69..9b4dee1c9f 100644 --- a/docs/clustering.rst +++ b/docs/clustering.rst @@ -26,7 +26,7 @@ cluster instance can be created: - Using ‘host’ and ‘port’ arguments: -.. code:: pycon +.. code:: python >>> from redis.cluster import RedisCluster as Redis >>> rc = Redis(host='localhost', port=6379) @@ -35,14 +35,14 @@ cluster instance can be created: - Using the Redis URL specification: -.. code:: pycon +.. code:: python >>> from redis.cluster import RedisCluster as Redis >>> rc = Redis.from_url("redis://localhost:6379/0") - Directly, via the ClusterNode class: -.. code:: pycon +.. code:: python >>> from redis.cluster import RedisCluster as Redis >>> from redis.cluster import ClusterNode @@ -77,7 +77,7 @@ you can change it using the ‘set_default_node’ method. The ‘target_nodes’ parameter is explained in the following section, ‘Specifying Target Nodes’. -.. code:: pycon +.. code:: python >>> # target-nodes: the node that holds 'foo1's key slot >>> rc.set('foo1', 'bar1') @@ -105,7 +105,7 @@ topology of the cluster changes during the execution of a command, the client will be able to resolve the nodes flag again with the new topology and attempt to retry executing the command. -.. code:: pycon +.. code:: python >>> from redis.cluster import RedisCluster as Redis >>> # run cluster-meet command on all of the cluster's nodes @@ -127,7 +127,7 @@ topology changes, a retry attempt will not be made, since the passed target node/s may no longer be valid, and the relevant cluster or connection error will be returned. -.. code:: pycon +.. code:: python >>> node = rc.get_node('localhost', 6379) >>> # Get the keys only for that specific node @@ -140,7 +140,7 @@ In addition, the RedisCluster instance can query the Redis instance of a specific node and execute commands on that node directly. The Redis client, however, does not handle cluster failures and retries. -.. code:: pycon +.. code:: python >>> cluster_node = rc.get_node(host='localhost', port=6379) >>> print(cluster_node) @@ -170,7 +170,7 @@ to the relevant slots, sending the commands to the slots’ node owners. Non-atomic operations batch the keys according to their hash value, and then each batch is sent separately to the slot’s owner. -.. code:: pycon +..
code:: python # Atomic operations can be used when all keys are mapped to the same slot >>> rc.mset({'{foo}1': 'bar1', '{foo}2': 'bar2'}) @@ -202,7 +202,7 @@ the commands are not currently recommended for use. See documentation `__ for more. -.. code:: pycon +.. code:: python >>> p1 = rc.pubsub() # p1 connection will be set to the node that holds 'foo' keyslot @@ -224,7 +224,7 @@ READONLY mode can be set at runtime by calling the readonly() method with target_nodes=‘replicas’, and read-write access can be restored by calling the readwrite() method. -.. code:: pycon +.. code:: python >>> from cluster import RedisCluster as Redis # Use 'debug' log level to print the node that the command is executed on diff --git a/docs/conf.py b/docs/conf.py index cdbeb02c9a..8849752404 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -60,7 +60,7 @@ # General information about the project. project = "redis-py" -copyright = "2022, Redis Inc" +copyright = "2023, Redis Inc" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -287,4 +287,4 @@ epub_title = "redis-py" epub_author = "Redis Inc" epub_publisher = "Redis Inc" -epub_copyright = "2022, Redis Inc" +epub_copyright = "2023, Redis Inc" diff --git a/docs/examples/asyncio_examples.ipynb b/docs/examples/asyncio_examples.ipynb index 855255c88d..f7e67e2ca7 100644 --- a/docs/examples/asyncio_examples.ipynb +++ b/docs/examples/asyncio_examples.ipynb @@ -21,6 +21,12 @@ { "cell_type": "code", "execution_count": 1, + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [ { "name": "stdout", @@ -36,29 +42,29 @@ "connection = redis.Redis()\n", "print(f\"Ping successful: {await connection.ping()}\")\n", "await connection.close()" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } + ] }, { "cell_type": "markdown", - "source": [ - "If you supply a custom `ConnectionPool` that is supplied to several `Redis` instances, you may want to disconnect the connection pool explicitly. Disconnecting the connection pool simply disconnects all connections hosted in the pool." - ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } - } + }, + "source": [ + "If you supply a custom `ConnectionPool` that is supplied to several `Redis` instances, you may want to disconnect the connection pool explicitly. Disconnecting the connection pool simply disconnects all connections hosted in the pool." + ] }, { "cell_type": "code", "execution_count": 2, + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "import redis.asyncio as redis\n", @@ -67,16 +73,36 @@ "await connection.close()\n", "# Or: await connection.close(close_connection_pool=False)\n", "await connection.connection_pool.disconnect()" - ], + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "By default, this library uses version 2 of the RESP protocol. 
To enable RESP version 3, you will want to set `protocol` to 3" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import redis.asyncio as redis\n", + "\n", + "connection = redis.Redis(protocol=3)\n", + "await connection.ping()\n", + "await connection.close()" + ] + }, + { + "cell_type": "markdown", "metadata": { "collapsed": false, "pycharm": { - "name": "#%%\n" + "name": "#%% md\n" } - } - }, - { - "cell_type": "markdown", "source": [ "## Transactions (Multi/Exec)\n", "\n", @@ -85,17 +111,17 @@ "The commands will not be reflected in Redis until execute() is called & awaited.\n", "\n", "Usually, when performing a bulk operation, taking advantage of a “transaction” (e.g., Multi/Exec) is to be desired, as it will also add a layer of atomicity to your bulk operation." - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } - } + ] }, { "cell_type": "code", "execution_count": 3, + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "import redis.asyncio as redis\n", @@ -105,31 +131,31 @@ " ok1, ok2 = await (pipe.set(\"key1\", \"value1\").set(\"key2\", \"value2\").execute())\n", "assert ok1\n", "assert ok2" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } + ] }, { "cell_type": "markdown", - "source": [ - "## Pub/Sub Mode\n", - "\n", - "Subscribing to specific channels:" - ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } - } + }, + "source": [ + "## Pub/Sub Mode\n", + "\n", + "Subscribing to specific channels:" + ] }, { "cell_type": "code", "execution_count": 4, + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [ { "name": "stdout", @@ -170,29 +196,29 @@ " await r.publish(\"channel:1\", STOPWORD)\n", "\n", " await future" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } + ] }, { "cell_type": "markdown", - "source": [ - "Subscribing to channels matching a glob-style pattern:" - ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } - } + }, + "source": [ + "Subscribing to channels matching a glob-style pattern:" + ] }, { "cell_type": "code", "execution_count": 5, + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [ { "name": "stdout", @@ -234,16 +260,16 @@ " await r.publish(\"channel:1\", STOPWORD)\n", "\n", " await future" - ], + ] + }, + { + "cell_type": "markdown", "metadata": { "collapsed": false, "pycharm": { - "name": "#%%\n" + "name": "#%% md\n" } - } - }, - { - "cell_type": "markdown", "source": [ "## Sentinel Client\n", "\n", @@ -252,17 +278,17 @@ "Calling aioredis.sentinel.Sentinel.master_for or aioredis.sentinel.Sentinel.slave_for methods will return Redis clients connected to specified services monitored by Sentinel.\n", "\n", "Sentinel client will detect failover and reconnect Redis clients automatically."
- ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } - } + ] }, { "cell_type": "code", "execution_count": null, + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "import asyncio\n", @@ -277,13 +303,61 @@ "assert ok\n", "val = await r.get(\"key\")\n", "assert val == b\"value\"" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Connecting to Redis instances by specifying a URL scheme.\n", + "Connection parameters are passed as query parameters on the URL.\n", + "\n", + "Three URL schemes are supported:\n", + "\n", + "- `redis://` creates a TCP socket connection. \n", + "- `rediss://` creates an SSL-wrapped TCP socket connection. \n", + "- `unix://` creates a Unix Domain Socket connection.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "metadata": {}, + "output_type": "display_data" } - } + ], + "source": [ + "import redis.asyncio as redis\n", + "url_connection = redis.from_url(\"redis://localhost:6379?decode_responses=True\")\n", + "await url_connection.ping()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To enable the RESP 3 protocol, append `protocol=3` to the URL." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import redis.asyncio as redis\n", + "\n", + "url_connection = redis.from_url(\"redis://localhost:6379?decode_responses=True&protocol=3\")\n", + "await url_connection.ping()" + ] } ], "metadata": { @@ -307,4 +381,4 @@ }, "nbformat": 4, "nbformat_minor": 1 -} \ No newline at end of file +} diff --git a/docs/examples/connection_examples.ipynb b/docs/examples/connection_examples.ipynb index a15b4c6cc0..cddded2865 100644 --- a/docs/examples/connection_examples.ipynb +++ b/docs/examples/connection_examples.ipynb @@ -41,7 +41,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### by default Redis return binary responses, to decode them use decode_responses=True" + "### By default Redis returns binary responses; to decode them, use decode_responses=True" ] }, { @@ -67,6 +67,25 @@ "decoded_connection.ping()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### By default this library uses the RESP 2 protocol. To enable RESP3, set protocol=3."
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "import redis\n", + "\n", + "r = redis.Redis(protocol=3)\n", + "rcon.ping()" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -99,14 +118,15 @@ }, { "cell_type": "markdown", + "metadata": {}, "source": [ "## Connecting to a redis instance with username and password credential provider" - ], - "metadata": {} + ] }, { "cell_type": "code", "execution_count": null, + "metadata": {}, "outputs": [], "source": [ "import redis\n", @@ -114,19 +134,19 @@ "creds_provider = redis.UsernamePasswordCredentialProvider(\"username\", \"password\")\n", "user_connection = redis.Redis(host=\"localhost\", port=6379, credential_provider=creds_provider)\n", "user_connection.ping()" - ], - "metadata": {} + ] }, { "cell_type": "markdown", + "metadata": {}, "source": [ "## Connecting to a redis instance with standard credential provider" - ], - "metadata": {} + ] }, { "cell_type": "code", "execution_count": null, + "metadata": {}, "outputs": [], "source": [ "from typing import Tuple\n", @@ -158,19 +178,19 @@ "user_connection = redis.Redis(host=\"localhost\", port=6379,\n", " credential_provider=creds_provider)\n", "user_connection.ping()" - ], - "metadata": {} + ] }, { "cell_type": "markdown", + "metadata": {}, "source": [ "## Connecting to a redis instance first with an initial credential set and then calling the credential provider" - ], - "metadata": {} + ] }, { "cell_type": "code", "execution_count": null, + "metadata": {}, "outputs": [], "source": [ "from typing import Union\n", @@ -194,8 +214,7 @@ " return self.username, self.password\n", "\n", "cred_provider = InitCredsSetCredentialProvider(username=\"init_user\", password=\"init_pass\")" - ], - "metadata": {} + ] }, { "cell_type": "markdown", @@ -222,18 +241,23 @@ "import json\n", "import cachetools.func\n", "\n", - "sm_client = boto3.client('secretsmanager')\n", - " \n", - "def sm_auth_provider(self, secret_id, version_id=None, version_stage='AWSCURRENT'):\n", - " @cachetools.func.ttl_cache(maxsize=128, ttl=24 * 60 * 60) #24h\n", - " def get_sm_user_credentials(secret_id, version_id, version_stage):\n", - " secret = sm_client.get_secret_value(secret_id, version_id)\n", - " return json.loads(secret['SecretString'])\n", - " creds = get_sm_user_credentials(secret_id, version_id, version_stage)\n", - " return creds['username'], creds['password']\n", + "class SecretsManagerProvider(redis.CredentialProvider):\n", + " def __init__(self, secret_id, version_id=None, version_stage='AWSCURRENT'):\n", + " self.sm_client = boto3.client('secretsmanager')\n", + " self.secret_id = secret_id\n", + " self.version_id = version_id\n", + " self.version_stage = version_stage\n", + "\n", + " def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]:\n", + " @cachetools.func.ttl_cache(maxsize=128, ttl=24 * 60 * 60) #24h\n", + " def get_sm_user_credentials(secret_id, version_id, version_stage):\n", + " secret = self.sm_client.get_secret_value(secret_id, version_id)\n", + " return json.loads(secret['SecretString'])\n", + " creds = get_sm_user_credentials(self.secret_id, self.version_id, self.version_stage)\n", + " return creds['username'], creds['password']\n", "\n", - "secret_id = \"EXAMPLE1-90ab-cdef-fedc-ba987SECRET1\"\n", - "creds_provider = redis.CredentialProvider(supplier=sm_auth_provider, secret_id=secret_id)\n", + "my_secret_id = \"EXAMPLE1-90ab-cdef-fedc-ba987SECRET1\"\n", + "creds_provider = SecretsManagerProvider(secret_id=my_secret_id)\n", "user_connection = 
redis.Redis(host=\"localhost\", port=6379, credential_provider=creds_provider)\n", "user_connection.ping()" ] @@ -262,23 +286,60 @@ } ], "source": [ + "from typing import Tuple, Union\n", + "from urllib.parse import ParseResult, urlencode, urlunparse\n", + "\n", + "import botocore.session\n", "import redis\n", - "import boto3\n", - "import cachetools.func\n", + "from botocore.model import ServiceId\n", + "from botocore.signers import RequestSigner\n", + "from cachetools import TTLCache, cached\n", "\n", - "ec_client = boto3.client('elasticache')\n", + "class ElastiCacheIAMProvider(redis.CredentialProvider):\n", + " def __init__(self, user, cluster_name, region=\"us-east-1\"):\n", + " self.user = user\n", + " self.cluster_name = cluster_name\n", + " self.region = region\n", "\n", - "def iam_auth_provider(self, user, endpoint, port=6379, region=\"us-east-1\"):\n", - " @cachetools.func.ttl_cache(maxsize=128, ttl=15 * 60) # 15m\n", - " def get_iam_auth_token(user, endpoint, port, region):\n", - " return ec_client.generate_iam_auth_token(user, endpoint, port, region)\n", - " iam_auth_token = get_iam_auth_token(endpoint, port, user, region)\n", - " return iam_auth_token\n", + " session = botocore.session.get_session()\n", + " self.request_signer = RequestSigner(\n", + " ServiceId(\"elasticache\"),\n", + " self.region,\n", + " \"elasticache\",\n", + " \"v4\",\n", + " session.get_credentials(),\n", + " session.get_component(\"event_emitter\"),\n", + " )\n", + "\n", + " # Generated IAM tokens are valid for 15 minutes\n", + " @cached(cache=TTLCache(maxsize=128, ttl=900))\n", + " def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]:\n", + " query_params = {\"Action\": \"connect\", \"User\": self.user}\n", + " url = urlunparse(\n", + " ParseResult(\n", + " scheme=\"https\",\n", + " netloc=self.cluster_name,\n", + " path=\"/\",\n", + " query=urlencode(query_params),\n", + " params=\"\",\n", + " fragment=\"\",\n", + " )\n", + " )\n", + " signed_url = self.request_signer.generate_presigned_url(\n", + " {\"method\": \"GET\", \"url\": url, \"body\": {}, \"headers\": {}, \"context\": {}},\n", + " operation_name=\"connect\",\n", + " expires_in=900,\n", + " region_name=self.region,\n", + " )\n", + " # RequestSigner only seems to work if the URL has a protocol, but\n", + " # Elasticache only accepts the URL without a protocol\n", + " # So strip it off the signed URL before returning\n", + " return (self.user, signed_url.removeprefix(\"https://\"))\n", "\n", "username = \"barshaul\"\n", + "cluster_name = \"test-001\"\n", "endpoint = \"test-001.use1.cache.amazonaws.com\"\n", - "creds_provider = redis.CredentialProvider(supplier=iam_auth_provider, user=username,\n", - " endpoint=endpoint)\n", + "creds_provider = ElastiCacheIAMProvider(user=username, cluster_name=cluster_name)\n", "user_connection = redis.Redis(host=endpoint, port=6379, credential_provider=creds_provider)\n", "user_connection.ping()" ] @@ -315,7 +376,23 @@ ], "source": [ "url_connection = redis.from_url(\"redis://localhost:6379?decode_responses=True&health_check_interval=2\")\n", - "\n", + "url_connection.ping()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Connecting to Redis instances by specifying a URL scheme and the RESP3 protocol.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "url_connection = redis.from_url(\"redis://localhost:6379?decode_responses=True&health_check_interval=2&protocol=3\")\n", "url_connection.ping()" ] }, @@ 
-362,4 +439,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} diff --git a/docs/examples/opentelemetry/main.py b/docs/examples/opentelemetry/main.py index b140dd0148..9ef6723305 100755 --- a/docs/examples/opentelemetry/main.py +++ b/docs/examples/opentelemetry/main.py @@ -2,12 +2,11 @@ import time +import redis import uptrace from opentelemetry import trace from opentelemetry.instrumentation.redis import RedisInstrumentor -import redis - tracer = trace.get_tracer("app_or_package_name", "1.0.0") diff --git a/docs/examples/pipeline_examples.ipynb b/docs/examples/pipeline_examples.ipynb index 490d2213a0..4e20375bfa 100644 --- a/docs/examples/pipeline_examples.ipynb +++ b/docs/examples/pipeline_examples.ipynb @@ -123,7 +123,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The responses of the three commands are stored in a list. In the above example, the two first boolean indicates that the the `set` commands were successfull and the last element of the list is the result of the `get(\"a\")` comand." + "The responses of the three commands are stored in a list. In the above example, the first two booleans indicate that the `set` commands were successful, and the last element of the list is the result of the `get(\"a\")` command." ] }, { diff --git a/docs/examples/redis-stream-example.ipynb b/docs/examples/redis-stream-example.ipynb index 9303b527ca..a84bf19cb6 100644 --- a/docs/examples/redis-stream-example.ipynb +++ b/docs/examples/redis-stream-example.ipynb @@ -313,7 +313,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# stream groups\n", + "# Stream groups\n", "With the groups is possible track, for many consumers, and at the Redis side, which message have been already consumed.\n", "## add some data to streams\n", "Creating 2 streams with 10 messages each." ] }, { diff --git a/docs/examples/search_vector_similarity_examples.ipynb b/docs/examples/search_vector_similarity_examples.ipynb index 2b0261097c..bd1df3c1ef 100644 --- a/docs/examples/search_vector_similarity_examples.ipynb +++ b/docs/examples/search_vector_similarity_examples.ipynb @@ -1,81 +1,643 @@ { "cells": [ { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Vector Similarity\n", - "## Adding Vector Fields" + "**Vectors** (also called \"embeddings\") represent an AI model's impression (or understanding) of a piece of unstructured data like text, images, audio, videos, etc. Vector Similarity Search (VSS) is the process of finding vectors in the vector database that are similar to a given query vector. Popular VSS uses include recommendation systems, image and video search, document retrieval, and question answering.\n", + "\n", + "## Index Creation\n", + "Before doing vector search, first define the schema and create an index."
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, + "outputs": [], + "source": [ + "import redis\n", + "from redis.commands.search.field import TagField, VectorField\n", + "from redis.commands.search.indexDefinition import IndexDefinition, IndexType\n", + "from redis.commands.search.query import Query\n", + "\n", + "r = redis.Redis(host=\"localhost\", port=6379)\n", + "\n", + "INDEX_NAME = \"index\" # Vector Index Name\n", + "DOC_PREFIX = \"doc:\" # RediSearch Key Prefix for the Index\n", + "\n", + "def create_index(vector_dimensions: int):\n", + " try:\n", + " # check to see if index exists\n", + " r.ft(INDEX_NAME).info()\n", + " print(\"Index already exists!\")\n", + " except:\n", + " # schema\n", + " schema = (\n", + " TagField(\"tag\"), # Tag Field Name\n", + " VectorField(\"vector\", # Vector Field Name\n", + " \"FLAT\", { # Vector Index Type: FLAT or HNSW\n", + " \"TYPE\": \"FLOAT32\", # FLOAT32 or FLOAT64\n", + " \"DIM\": vector_dimensions, # Number of Vector Dimensions\n", + " \"DISTANCE_METRIC\": \"COSINE\", # Vector Search Distance Metric\n", + " }\n", + " ),\n", + " )\n", + "\n", + " # index Definition\n", + " definition = IndexDefinition(prefix=[DOC_PREFIX], index_type=IndexType.HASH)\n", + "\n", + " # create Index\n", + " r.ft(INDEX_NAME).create_index(fields=schema, definition=definition)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We'll start by working with vectors that have 1536 dimensions." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# define vector dimensions\n", + "VECTOR_DIMENSIONS = 1536\n", + "\n", + "# create the index\n", + "create_index(vector_dimensions=VECTOR_DIMENSIONS)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Adding Vectors to Redis\n", + "\n", + "Next, we add vectors (dummy data) to Redis using `hset`. The search index listens to keyspace notifications and will include any written HASH objects prefixed by `DOC_PREFIX`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install numpy" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# instantiate a redis pipeline\n", + "pipe = r.pipeline()\n", + "\n", + "# define some dummy data\n", + "objects = [\n", + " {\"name\": \"a\", \"tag\": \"foo\"},\n", + " {\"name\": \"b\", \"tag\": \"foo\"},\n", + " {\"name\": \"c\", \"tag\": \"bar\"},\n", + "]\n", + "\n", + "# write data\n", + "for obj in objects:\n", + " # define key\n", + " key = f\"doc:{obj['name']}\"\n", + " # create a random \"dummy\" vector\n", + " obj[\"vector\"] = np.random.rand(VECTOR_DIMENSIONS).astype(np.float32).tobytes()\n", + " # HSET\n", + " pipe.hset(key, mapping=obj)\n", + "\n", + "res = pipe.execute()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Searching\n", + "You can use VSS queries with the `.ft(...).search(...)` query command. To use a VSS query, you must specify the option `.dialect(2)`.\n", + "\n", + "There are two supported types of vector queries in Redis: `KNN` and `Range`. `Hybrid` queries can work in both settings and combine elements of traditional search and VSS." 
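One detail applies to every query example below: the query vector bound via `query_params` must be serialized to the same binary layout as the indexed vectors. For a `FLOAT32` index that means a packed float32 buffer, e.g. (a sketch reusing `VECTOR_DIMENSIONS` from above):

```python
import numpy as np

# The bytes bound to $vec must match the index's TYPE (FLOAT32 here);
# a float64 buffer would have the wrong length and match nothing.
query_vec = np.random.rand(VECTOR_DIMENSIONS).astype(np.float32).tobytes()
```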
+ ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### KNN Queries\n", + "KNN queries are for finding the topK most similar vectors given a query vector." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [ { "data": { "text/plain": [ - "b'OK'" + "[Document {'id': 'doc:b', 'payload': None, 'score': '0.2376562953'},\n", + " Document {'id': 'doc:c', 'payload': None, 'score': '0.240063905716'}]" ] }, - "execution_count": 1, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "import redis\n", - "from redis.commands.search.field import VectorField\n", - "from redis.commands.search.query import Query\n", + "query = (\n", + " Query(\"*=>[KNN 2 @vector $vec as score]\")\n", + " .sort_by(\"score\")\n", + " .return_fields(\"id\", \"score\")\n", + " .paging(0, 2)\n", + " .dialect(2)\n", + ")\n", "\n", - "r = redis.Redis(host='localhost', port=36379)\n", + "query_params = {\n", + " \"vec\": np.random.rand(VECTOR_DIMENSIONS).astype(np.float32).tobytes()\n", + "}\n", + "r.ft(INDEX_NAME).search(query, query_params).docs" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Range Queries\n", + "Range queries provide a way to filter results by the distance between a vector field in Redis and a query vector based on some pre-defined threshold (radius)." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Document {'id': 'doc:a', 'payload': None, 'score': '0.243115246296'},\n", + " Document {'id': 'doc:c', 'payload': None, 'score': '0.24981123209'},\n", + " Document {'id': 'doc:b', 'payload': None, 'score': '0.251443207264'}]" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "query = (\n", + " Query(\"@vector:[VECTOR_RANGE $radius $vec]=>{$YIELD_DISTANCE_AS: score}\")\n", + " .sort_by(\"score\")\n", + " .return_fields(\"id\", \"score\")\n", + " .paging(0, 3)\n", + " .dialect(2)\n", + ")\n", "\n", - "schema = (VectorField(\"v\", \"HNSW\", {\"TYPE\": \"FLOAT32\", \"DIM\": 2, \"DISTANCE_METRIC\": \"L2\"}),)\n", - "r.ft().create_index(schema)" + "# Find all vectors within 0.8 of the query vector\n", + "query_params = {\n", + " \"radius\": 0.8,\n", + " \"vec\": np.random.rand(VECTOR_DIMENSIONS).astype(np.float32).tobytes()\n", + "}\n", + "r.ft(INDEX_NAME).search(query, query_params).docs" ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ - "## Searching" + "See additional Range Query examples in [this Jupyter notebook](https://github.com/RediSearch/RediSearch/blob/master/docs/docs/vecsim-range_queries_examples.ipynb)." ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ - "### Querying vector fields" + "### Hybrid Queries\n", + "Hybrid queries contain both traditional filters (numeric, tags, text) and VSS in one single Redis command." 
] }, { "cell_type": "code", - "execution_count": 2, - "metadata": { - "pycharm": { - "name": "#%%\n" + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Document {'id': 'doc:b', 'payload': None, 'score': '0.24422544241', 'tag': 'foo'},\n", + " Document {'id': 'doc:a', 'payload': None, 'score': '0.259926855564', 'tag': 'foo'}]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" } - }, + ], + "source": [ + "query = (\n", + " Query(\"(@tag:{ foo })=>[KNN 2 @vector $vec as score]\")\n", + " .sort_by(\"score\")\n", + " .return_fields(\"id\", \"tag\", \"score\")\n", + " .paging(0, 2)\n", + " .dialect(2)\n", + ")\n", + "\n", + "query_params = {\n", + " \"vec\": np.random.rand(VECTOR_DIMENSIONS).astype(np.float32).tobytes()\n", + "}\n", + "r.ft(INDEX_NAME).search(query, query_params).docs" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "See additional Hybrid Query examples in [this Jupyter notebook](https://github.com/RediSearch/RediSearch/blob/master/docs/docs/vecsim-hybrid_queries_examples.ipynb)." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Vector Creation and Storage Examples\n", + "The above examples use dummy data as vectors. However, in reality, most use cases leverage production-grade AI models for creating embeddings. Below we will take some sample text data, pass it to the OpenAI and Cohere API's respectively, and then write them to Redis." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "texts = [\n", + " \"Today is a really great day!\",\n", + " \"The dog next door barks really loudly.\",\n", + " \"My cat escaped and got out before I could close the door.\",\n", + " \"It's supposed to rain and thunder tomorrow.\"\n", + "]" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### OpenAI Embeddings\n", + "Before working with OpenAI Embeddings, we clean up our existing search index and create a new one." 
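Note that the index dimension must match the embedding model's output size: `text-embedding-ada-002` produces 1536-dimensional vectors, which is why `VECTOR_DIMENSIONS` was set to 1536 earlier. A quick guard (a sketch) catches a mismatch before any data is written:

```python
# text-embedding-ada-002 returns 1536-dimensional embeddings; the index
# DIM must agree or stored vectors will never match a query.
assert VECTOR_DIMENSIONS == 1536
```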
+ ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "# delete index\n", + "r.ft(INDEX_NAME).dropindex(delete_documents=True)\n", + "\n", + "# make a new one\n", + "create_index(vector_dimensions=VECTOR_DIMENSIONS)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install openai" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "import openai\n", + "\n", + "# set your OpenAI API key - get one at https://platform.openai.com\n", + "openai.api_key = \"YOUR OPENAI API KEY\"" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "# Create Embeddings with OpenAI text-embedding-ada-002\n", + "# https://openai.com/blog/new-and-improved-embedding-model\n", + "response = openai.Embedding.create(input=texts, engine=\"text-embedding-ada-002\")\n", + "embeddings = np.array([r[\"embedding\"] for r in response[\"data\"]], dtype=np.float32)\n", + "\n", + "# Write to Redis\n", + "pipe = r.pipeline()\n", + "for i, embedding in enumerate(embeddings):\n", + " pipe.hset(f\"doc:{i}\", mapping = {\n", + " \"vector\": embedding.tobytes(),\n", + " \"content\": texts[i],\n", + " \"tag\": \"openai\"\n", + " })\n", + "res = pipe.execute()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[ 0.00509819, 0.0010873 , -0.00228475, ..., -0.00457579,\n", + " 0.01329307, -0.03167175],\n", + " [-0.00357223, -0.00550784, -0.01314328, ..., -0.02915693,\n", + " 0.01470436, -0.01367203],\n", + " [-0.01284631, 0.0034875 , -0.01719686, ..., -0.01537451,\n", + " 0.01953256, -0.05048691],\n", + " [-0.01145045, -0.00785481, 0.00206323, ..., -0.02070181,\n", + " -0.01629098, -0.00300795]], dtype=float32)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "embeddings" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Search with OpenAI Embeddings\n", + "\n", + "Now that we've created embeddings with OpenAI, we can also perform a search to find relevant documents to some input text.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ 0.00062901, -0.0070723 , -0.00148926, ..., -0.01904645,\n", + " -0.00436092, -0.01117944], dtype=float32)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "text = \"animals\"\n", + "\n", + "# create query embedding\n", + "response = openai.Embedding.create(input=[text], engine=\"text-embedding-ada-002\")\n", + "query_embedding = np.array([r[\"embedding\"] for r in response[\"data\"]], dtype=np.float32)[0]\n", + "\n", + "query_embedding" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "Result{2 total, docs: [Document {'id': 'a', 'payload': None, '__v_score': '0'}, Document {'id': 'b', 'payload': None, '__v_score': '3.09485009821e+26'}]}" + "[Document {'id': 'doc:1', 'payload': None, 'score': '0.214349985123', 'content': 'The dog next door barks really loudly.', 'tag': 'openai'},\n", + " Document {'id': 'doc:2', 'payload': None, 'score': '0.237052619457', 'content': 'My cat escaped and got out before I could close the door.', 'tag': 'openai'}]" ] 
}, - "execution_count": 2, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "r.hset(\"a\", \"v\", \"aaaaaaaa\")\n", - "r.hset(\"b\", \"v\", \"aaaabaaa\")\n", - "r.hset(\"c\", \"v\", \"aaaaabaa\")\n", + "# query for similar documents that have the openai tag\n", + "query = (\n", + " Query(\"(@tag:{ openai })=>[KNN 2 @vector $vec as score]\")\n", + " .sort_by(\"score\")\n", + " .return_fields(\"content\", \"tag\", \"score\")\n", + " .paging(0, 2)\n", + " .dialect(2)\n", + ")\n", + "\n", + "query_params = {\"vec\": query_embedding.tobytes()}\n", + "r.ft(INDEX_NAME).search(query, query_params).docs\n", "\n", - "q = Query(\"*=>[KNN 2 @v $vec]\").return_field(\"__v_score\").dialect(2)\n", - "r.ft().search(q, query_params={\"vec\": \"aaaaaaaa\"})" + "# the two pieces of content related to animals are returned" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Cohere Embeddings\n", + "Before working with Cohere Embeddings, we clean up our existing search index and create a new one." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "# delete index\n", + "r.ft(INDEX_NAME).dropindex(delete_documents=True)\n", + "\n", + "# make a new one for cohere embeddings (1024 dimensions)\n", + "VECTOR_DIMENSIONS = 1024\n", + "create_index(vector_dimensions=VECTOR_DIMENSIONS)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install cohere" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "import cohere\n", + "\n", + "co = cohere.Client(\"YOUR COHERE API KEY\")" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "# Create Embeddings with Cohere\n", + "# https://docs.cohere.ai/docs/embeddings\n", + "response = co.embed(texts=texts, model=\"small\")\n", + "embeddings = np.array(response.embeddings, dtype=np.float32)\n", + "\n", + "# Write to Redis\n", + "for i, embedding in enumerate(embeddings):\n", + " r.hset(f\"doc:{i}\", mapping = {\n", + " \"vector\": embedding.tobytes(),\n", + " \"content\": texts[i],\n", + " \"tag\": \"cohere\"\n", + " })" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[-0.3034668 , -0.71533203, -0.2836914 , ..., 0.81152344,\n", + " 1.0253906 , -0.8095703 ],\n", + " [-0.02560425, -1.4912109 , 0.24267578, ..., -0.89746094,\n", + " 0.15625 , -3.203125 ],\n", + " [ 0.10125732, 0.7246094 , -0.29516602, ..., -1.9638672 ,\n", + " 1.6630859 , -0.23291016],\n", + " [-2.09375 , 0.8588867 , -0.23352051, ..., -0.01541138,\n", + " 0.17053223, -3.4042969 ]], dtype=float32)" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "embeddings" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Search with Cohere Embeddings\n", + "\n", + "Now that we've created embeddings with Cohere, we can also perform a search to find relevant documents to some input text." 
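Since the index was just recreated with 1024 dimensions, it is worth confirming the embeddings actually have that shape before searching (a sketch using the variables above):

```python
# response.embeddings holds one 1024-dim vector per input text for the
# "small" model, so the array shape should be (len(texts), VECTOR_DIMENSIONS).
assert embeddings.shape == (len(texts), VECTOR_DIMENSIONS)
```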
+ ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([-0.49682617, 1.7070312 , 0.3466797 , ..., 0.58984375,\n", + " 0.1060791 , -2.9023438 ], dtype=float32)" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "text = \"animals\"\n", + "\n", + "# create query embedding\n", + "response = co.embed(texts=[text], model=\"small\")\n", + "query_embedding = np.array(response.embeddings[0], dtype=np.float32)\n", + "\n", + "query_embedding" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Document {'id': 'doc:1', 'payload': None, 'score': '0.658673524857', 'content': 'The dog next door barks really loudly.', 'tag': 'cohere'},\n", + " Document {'id': 'doc:2', 'payload': None, 'score': '0.662699103355', 'content': 'My cat escaped and got out before I could close the door.', 'tag': 'cohere'}]" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# query for similar documents that have the cohere tag\n", + "query = (\n", + " Query(\"(@tag:{ cohere })=>[KNN 2 @vector $vec as score]\")\n", + " .sort_by(\"score\")\n", + " .return_fields(\"content\", \"tag\", \"score\")\n", + " .paging(0, 2)\n", + " .dialect(2)\n", + ")\n", + "\n", + "query_params = {\"vec\": query_embedding.tobytes()}\n", + "r.ft(INDEX_NAME).search(query, query_params).docs\n", + "\n", + "# the two pieces of content related to animals are returned" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Find more example apps, tutorials, and projects using Redis Vector Similarity Search [in this GitHub organization](https://github.com/RedisVentures)." 
] } ], @@ -98,7 +660,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.2" + "version": "3.9.12" }, "orig_nbformat": 4 }, diff --git a/docs/examples/ssl_connection_examples.ipynb b/docs/examples/ssl_connection_examples.ipynb index 386e4af452..ab3b4415ae 100644 --- a/docs/examples/ssl_connection_examples.ipynb +++ b/docs/examples/ssl_connection_examples.ipynb @@ -55,6 +55,27 @@ "url_connection.ping()" ] }, + { + "cell_type": "markdown", + "id": "04e70233", + "metadata": {}, + "source": [ + "## Connecting to a Redis instance using ConnectionPool" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2903de26", + "metadata": {}, + "outputs": [], + "source": [ + "import redis\n", + "redis_pool = redis.ConnectionPool(host=\"localhost\", port=6666, connection_class=redis.SSLConnection)\n", + "ssl_connection = redis.StrictRedis(connection_pool=redis_pool) \n", + "ssl_connection.ping()" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/docs/examples/timeseries_examples.ipynb b/docs/examples/timeseries_examples.ipynb index fefc0c8f37..691e13350a 100644 --- a/docs/examples/timeseries_examples.ipynb +++ b/docs/examples/timeseries_examples.ipynb @@ -599,6 +599,47 @@ "source": [ "ts.range(\"ts_key_incr\", \"-\", \"+\")" ] + }, + { + "cell_type": "markdown", + "source": [ + "## How to execute multi-key commands on Open Source Redis Cluster" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 4, + "outputs": [ + { + "data": { + "text/plain": "[{'ts_key1': [{}, 1670927124746, 2.0]}, {'ts_key2': [{}, 1670927124748, 10.0]}]" + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import redis\n", + "\n", + "r = redis.RedisCluster(host=\"localhost\", port=46379)\n", + "\n", + "# This command should be executed on all cluster nodes after creation and any re-sharding\n", + "# Please note that this command is internal and will be deprecated in the future\n", + "r.execute_command(\"timeseries.REFRESHCLUSTER\", target_nodes=\"primaries\")\n", + "\n", + "# Now multi-key commands can be executed\n", + "ts = r.ts()\n", + "ts.add(\"ts_key1\", \"*\", 2, labels={\"label1\": 1, \"label2\": 2})\n", + "ts.add(\"ts_key2\", \"*\", 10, labels={\"label1\": 1, \"label2\": 2})\n", + "ts.mget([\"label1=1\"])" + ], + "metadata": { + "collapsed": false + } } ], "metadata": { diff --git a/docs/index.rst b/docs/index.rst index a6ee05e917..2c0557cbbe 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -64,15 +64,16 @@ Module Documentation .. toctree:: :maxdepth: 1 - backoff connections + clustering exceptions + backoff lock retry - advanced_features - clustering lua_scripting opentelemetry + resp3_features + advanced_features examples Contributing @@ -86,4 +87,4 @@ Contributing License ******* -This projectis licensed under the `MIT license `_. +This project is licensed under the `MIT license `_. diff --git a/docs/lua_scripting.rst b/docs/lua_scripting.rst index 8276dad051..0edb6b6723 100644 --- a/docs/lua_scripting.rst +++ b/docs/lua_scripting.rst @@ -24,7 +24,7 @@ The following trivial Lua script accepts two parameters: the name of a key and a multiplier value. The script fetches the value stored in the key, multiplies it with the multiplier value and returns the result. -.. code:: pycon +.. code:: python >>> r = redis.Redis() >>> lua = """ @@ -47,7 +47,7 @@ function. 
Script instances accept the following optional arguments: Continuing the example from above: -.. code:: pycon +.. code:: python >>> r.set('foo', 2) >>> multiply(keys=['foo'], args=[5]) @@ -60,7 +60,7 @@ executes the script and returns the result, 10. Script instances can be executed using a different client instance, even one that points to a completely different Redis server. -.. code:: pycon +.. code:: python >>> r2 = redis.Redis('redis2.example.com') >>> r2.set('foo', 3) @@ -79,7 +79,7 @@ should be passed as the client argument when calling the script. Care is taken to ensure that the script is registered in Redis's script cache just prior to pipeline execution. -.. code:: pycon +.. code:: python >>> pipe = r.pipeline() >>> pipe.set('foo', 5) diff --git a/docs/opentelemetry.rst b/docs/opentelemetry.rst index 96781028e3..d006a60461 100644 --- a/docs/opentelemetry.rst +++ b/docs/opentelemetry.rst @@ -4,7 +4,7 @@ Integrating OpenTelemetry What is OpenTelemetry? ---------------------- -`OpenTelemetry `_ is an open-source observability framework for traces, metrics, and logs. +`OpenTelemetry `_ is an open-source observability framework for traces, metrics, and logs. It is a merger of OpenCensus and OpenTracing projects hosted by Cloud Native Computing Foundation. OpenTelemetry allows developers to collect and export telemetry data in a vendor agnostic way. With OpenTelemetry, you can instrument your application once and then add or change vendors without changing the instrumentation, for example, here is a list of `popular DataDog competitors `_ that support OpenTelemetry. @@ -97,7 +97,7 @@ See `OpenTelemetry Python Tracing API `_ that supports distributed tracing, metrics, and logs. You can use it to monitor applications and set up automatic alerts to receive notifications via email, Slack, Telegram, and more. +Uptrace is an `open source APM `_ that supports distributed tracing, metrics, and logs. You can use it to monitor applications and set up automatic alerts to receive notifications via email, Slack, Telegram, and more. You can use Uptrace to monitor redis-py using this `GitHub example `_ as a starting point. @@ -111,9 +111,9 @@ Monitoring Redis Server performance In addition to monitoring redis-py client, you can also monitor Redis Server performance using OpenTelemetry Collector Agent. -OpenTelemetry Collector is a proxy/middleman between your application and a `distributed tracing tool `_ such as Uptrace or Jaeger. Collector receives telemetry data, processes it, and then exports the data to APM tools that can store it permanently. +OpenTelemetry Collector is a proxy/middleman between your application and a `distributed tracing tool `_ such as Uptrace or Jaeger. Collector receives telemetry data, processes it, and then exports the data to APM tools that can store it permanently. -For example, you can use the Redis receiver provided by Otel Collector to `monitor Redis performance `_: +For example, you can use the `OpenTelemetry Redis receiver ` provided by Otel Collector to monitor Redis performance: .. image:: images/opentelemetry/redis-metrics.png :alt: Redis metrics @@ -123,55 +123,50 @@ See introduction to `OpenTelemetry Collector `_ using alerting rules. For example, the following rule uses the group by node expression to create an alert whenever an individual Redis shard is down: +Uptrace also allows you to monitor `OpenTelemetry metrics `_ using alerting rules. 
For example, the following monitor uses the group by node expression to create an alert whenever an individual Redis shard is down: .. code-block:: python - # /etc/uptrace/uptrace.yml - - alerting: - rules: - - name: Redis shard is down - metrics: - - redis_up as $redis_up - query: - - group by cluster # monitor each cluster, - - group by bdb # each database, - - group by node # and each shard - - $redis_up == 0 - # shard should be down for 5 minutes to trigger an alert - for: 5m + monitors: + - name: Redis shard is down + metrics: + - redis_up as $redis_up + query: + - group by cluster # monitor each cluster, + - group by bdb # each database, + - group by node # and each shard + - $redis_up + min_allowed_value: 1 + # shard should be down for 5 minutes to trigger an alert + for_duration: 5m You can also create queries with more complex expressions. For example, the following rule creates an alert when the keyspace hit rate is lower than 75%: .. code-block:: python - # /etc/uptrace/uptrace.yml - - alerting: - rules: - - name: Redis read hit rate < 75% - metrics: - - redis_keyspace_read_hits as $hits - - redis_keyspace_read_misses as $misses - query: - - group by cluster - - group by bdb - - group by node - - $hits / ($hits + $misses) < 0.75 - for: 5m + monitors: + - name: Redis read hit rate < 75% + metrics: + - redis_keyspace_read_hits as $hits + - redis_keyspace_read_misses as $misses + query: + - group by cluster + - group by bdb + - group by node + - $hits / ($hits + $misses) as hit_rate + min_allowed_value: 0.75 + for_duration: 5m See `Alerting and Notifications `_ for details. What's next? ------------ -Next, you can learn how to configure `uptrace-python `_ to export spans, metrics, and logs to Uptrace. +Next, you can learn how to configure `uptrace-python `_ to export spans, metrics, and logs to Uptrace. You may also be interested in the following guides: - `OpenTelemetry Django `_ - `OpenTelemetry Flask `_ - `OpenTelemetry FastAPI `_ - `OpenTelemetry SQLAlchemy `_ - `OpenTelemetry instrumentations `_ +- `OpenTelemetry Django `_ +- `OpenTelemetry Flask `_ +- `OpenTelemetry FastAPI `_ +- `OpenTelemetry SQLAlchemy `_
diff --git a/docs/redismodules.rst b/docs/redismodules.rst index 2b0b3c6533..27757cb692 100644 --- a/docs/redismodules.rst +++ b/docs/redismodules.rst @@ -44,7 +44,7 @@ These are the commands for interacting with the `RedisBloom module `_. Practically, this means that a client using RESP 3 will be faster and more performant, as fewer type translations occur in the client. It also means new response types like doubles, true simple strings, maps, and booleans are available. + +Connecting +----------- + +Enabling RESP3 is no different from making any other connection in redis-py. In all cases, RESP3 is enabled by setting `protocol=3` when the connection is created. The following are some base examples illustrating how to enable a RESP 3 connection. + +Connect with a standard connection, specifying RESP 3: + +.. code:: python + + >>> import redis + >>> r = redis.Redis(host='localhost', port=6379, protocol=3) + >>> r.ping() + +Or using the URL scheme: + +.. code:: python + + >>> import redis + >>> r = redis.from_url("redis://localhost:6379?protocol=3") + >>> r.ping() + +Connect with the async client, specifying RESP 3: + +.. code:: python + + >>> import redis.asyncio as redis + >>> r = redis.Redis(host='localhost', port=6379, protocol=3) + >>> await r.ping() + +The URL scheme with the async client: + +.. code:: python + + >>> import redis.asyncio as redis + >>> r = redis.from_url("redis://localhost:6379?protocol=3") + >>> await r.ping() + +Connecting to an OSS Redis Cluster with RESP 3: + +.. code:: python + + >>> from redis.cluster import RedisCluster, ClusterNode + >>> r = RedisCluster(startup_nodes=[ClusterNode('localhost', 6379), ClusterNode('localhost', 6380)], protocol=3) + >>> r.ping() + +Push notifications +------------------ + +Push notifications are a way for Redis to send out-of-band data. The RESP 3 protocol includes a `push type `_ that allows our client to intercept these out-of-band messages. By default, clients will log simple messages, but redis-py includes the ability to bring your own function processor. + +This means that if you want to act on a given push notification, you can specify a handler function when the connection is created, as in this example: + +.. code:: python + + >>> from redis import Redis + >>> + >>> def our_func(message): + ...     if "This special thing happened" in message: + ...         raise IOError("This was the message: \n" + message) + >>> + >>> r = Redis(protocol=3) + >>> p = r.pubsub(push_handler_func=our_func) + +In the example above, when a push notification containing the specific text arrives, an IOError is raised instead of the message simply being logged. This highlights how one could start implementing a customized message handler.
diff --git a/doctests/README.md b/doctests/README.md index d4b00e059f..9dd6eaeb5d 100644 --- a/doctests/README.md +++ b/doctests/README.md @@ -2,8 +2,8 @@ ## How to add an example -Create regular python file in the current folder with meaningful name. It makes sense prefix example files with -command category (e.g. string, set, list, hash, etc) to make navigation in the folder easier. Files ending in *.py* +Create a regular Python file in the current folder with a meaningful name. It makes sense to prefix example files with +the command category (e.g. string, set, list, hash, etc.) to make navigation in the folder easier. Files ending in *.py* are automatically run by the test suite. ### Special markup @@ -18,14 +18,15 @@ Examples are standalone python scripts, committed to the *doctests* directory. T ```bash pip install -r requirements.txt pip install -r dev_requirements.txt +pip install -r doctests/requirements.txt ``` -Note - the CI process, runs the basic ```black``` and ```isort``` linters against the examples. Assuming -the requirements above have been installed you can run ```black yourfile.py``` and ```isort yourfile.py``` +Note: the CI process runs the basic ```black``` and ```isort``` linters against the examples. Assuming +the requirements above have been installed, you can run ```black yourfile.py``` and ```isort yourfile.py``` locally to validate the linting, prior to CI. Just include necessary assertions in the example file and run ```bash sh doctests/run_examples.sh ``` -to test all examples in the current folder. \ No newline at end of file +to test all examples in the current folder.
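For orientation, a minimal example file in this style might look like the following (a sketch: the filename is hypothetical, and the STEP/REMOVE markers mirror the ones used by `doctests/search_vss.py` below):

```python
# doctests/string_set_get.py (hypothetical filename)
# EXAMPLE: string_set_get
# STEP_START set_and_get
import redis

r = redis.Redis(host="localhost", port=6379, decode_responses=True)
r.set("foo", "bar")
res = r.get("foo")
# >>> 'bar'
# STEP_END
# REMOVE_START
assert res == "bar"  # assertions run in CI but are stripped from rendered docs
# REMOVE_END
```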
diff --git a/doctests/requirements.txt b/doctests/requirements.txt new file mode 100644 index 0000000000..209d87b9c8 --- /dev/null +++ b/doctests/requirements.txt @@ -0,0 +1,5 @@ +numpy +pandas +requests +sentence_transformers +tabulate diff --git a/doctests/search_quickstart.py b/doctests/search_quickstart.py index 243735a697..586bb74f06 100644 --- a/doctests/search_quickstart.py +++ b/doctests/search_quickstart.py @@ -287,8 +287,10 @@ # STEP_START query_fuzzy_matching res = index.search( Query( - "@description:%analitics%" # Note the typo in the word "analytics" - ).dialect(2) + "@description:%analitics%" + ).dialect( # Note the typo in the word "analytics" + 2 + ) ) print(res) # >>> Result{1 total, docs: [ @@ -311,8 +313,10 @@ # STEP_START query_fuzzy_matching_level2 res = index.search( Query( - "@description:%%analitycs%%" # Note 2 typos in the word "analytics" - ).dialect(2) + "@description:%%analitycs%%" + ).dialect( # Note 2 typos in the word "analytics" + 2 + ) ) print(res) # >>> Result{1 total, docs: [ diff --git a/doctests/search_vss.py b/doctests/search_vss.py new file mode 100644 index 0000000000..efc4638a0a --- /dev/null +++ b/doctests/search_vss.py @@ -0,0 +1,308 @@ +# EXAMPLE: search_vss +# STEP_START imports +import json +import time + +import numpy as np +import pandas as pd +import redis +import requests +from redis.commands.search.field import ( + NumericField, + TagField, + TextField, + VectorField, +) +from redis.commands.search.indexDefinition import IndexDefinition, IndexType +from redis.commands.search.query import Query +from sentence_transformers import SentenceTransformer + +# STEP_END + +# STEP_START get_data +url = "https://raw.githubusercontent.com/bsbodden/redis_vss_getting_started/main/data/bikes.json" +response = requests.get(url) +bikes = response.json() +# STEP_END +# REMOVE_START +assert bikes[0]["model"] == "Jigger" +# REMOVE_END + +# STEP_START dump_data +json.dumps(bikes[0], indent=2) +# STEP_END + +# STEP_START connect +client = redis.Redis(host="localhost", port=6379, decode_responses=True) +# STEP_END + +# STEP_START connection_test +res = client.ping() +# >>> True +# STEP_END +# REMOVE_START +assert res == True +# REMOVE_END + +# STEP_START load_data +pipeline = client.pipeline() +for i, bike in enumerate(bikes, start=1): + redis_key = f"bikes:{i:03}" + pipeline.json().set(redis_key, "$", bike) +res = pipeline.execute() +# >>> [True, True, True, True, True, True, True, True, True, True, True] +# STEP_END +# REMOVE_START +assert res == [True, True, True, True, True, True, True, True, True, True, True] +# REMOVE_END + +# STEP_START get +res = client.json().get("bikes:010", "$.model") +# >>> ['Summit'] +# STEP_END +# REMOVE_START +assert res == ["Summit"] +# REMOVE_END + +# STEP_START get_keys +keys = sorted(client.keys("bikes:*")) +# >>> ['bikes:001', 'bikes:002', ..., 'bikes:011'] +# STEP_END +# REMOVE_START +assert keys[0] == "bikes:001" +# REMOVE_END + +# STEP_START generate_embeddings +descriptions = client.json().mget(keys, "$.description") +descriptions = [item for sublist in descriptions for item in sublist] +embedder = SentenceTransformer("msmarco-distilbert-base-v4") +embeddings = embedder.encode(descriptions).astype(np.float32).tolist() +VECTOR_DIMENSION = len(embeddings[0]) +# >>> 768 +# STEP_END +# REMOVE_START +assert VECTOR_DIMENSION == 768 +# REMOVE_END + +# STEP_START load_embeddings +pipeline = client.pipeline() +for key, embedding in zip(keys, embeddings): + pipeline.json().set(key, "$.description_embeddings", embedding) 
+pipeline.execute() +# >>> [True, True, True, True, True, True, True, True, True, True, True] +# STEP_END + +# STEP_START dump_example +res = client.json().get("bikes:010") +# >>> +# { +# "model": "Summit", +# "brand": "nHill", +# "price": 1200, +# "type": "Mountain Bike", +# "specs": { +# "material": "alloy", +# "weight": "11.3" +# }, +# "description": "This budget mountain bike from nHill performs well..." +# "description_embeddings": [ +# -0.538114607334137, +# -0.49465855956077576, +# -0.025176964700222015, +# ... +# ] +# } +# STEP_END +# REMOVE_START +assert len(res["description_embeddings"]) == 768 +# REMOVE_END + +# STEP_START create_index +schema = ( + TextField("$.model", no_stem=True, as_name="model"), + TextField("$.brand", no_stem=True, as_name="brand"), + NumericField("$.price", as_name="price"), + TagField("$.type", as_name="type"), + TextField("$.description", as_name="description"), + VectorField( + "$.description_embeddings", + "FLAT", + { + "TYPE": "FLOAT32", + "DIM": VECTOR_DIMENSION, + "DISTANCE_METRIC": "COSINE", + }, + as_name="vector", + ), +) +definition = IndexDefinition(prefix=["bikes:"], index_type=IndexType.JSON) +res = client.ft("idx:bikes_vss").create_index( + fields=schema, definition=definition +) +# >>> 'OK' +# STEP_END +# REMOVE_START +assert res == "OK" +time.sleep(2) +# REMOVE_END + +# STEP_START validate_index +info = client.ft("idx:bikes_vss").info() +num_docs = info["num_docs"] +indexing_failures = info["hash_indexing_failures"] +# print(f"{num_docs} documents indexed with {indexing_failures} failures") +# >>> 11 documents indexed with 0 failures +# STEP_END +# REMOVE_START +assert (num_docs == "11") and (indexing_failures == "0") +# REMOVE_END + +# STEP_START simple_query_1 +query = Query("@brand:Peaknetic") +res = client.ft("idx:bikes_vss").search(query).docs +# print(res) +# >>> [Document {'id': 'bikes:008', 'payload': None, 'brand': 'Peaknetic', 'model': 'Soothe Electric bike', 'price': '1950', 'description_embeddings': ... 
+# STEP_END +# REMOVE_START + +assert all( + item in [x.__dict__["id"] for x in res] + for item in ["bikes:008", "bikes:009"] +) +# REMOVE_END + +# STEP_START simple_query_2 +query = Query("@brand:Peaknetic").return_fields("id", "brand", "model", "price") +res = client.ft("idx:bikes_vss").search(query).docs +# print(res) +# >>> [Document {'id': 'bikes:008', 'payload': None, 'brand': 'Peaknetic', 'model': 'Soothe Electric bike', 'price': '1950'}, Document {'id': 'bikes:009', 'payload': None, 'brand': 'Peaknetic', 'model': 'Secto', 'price': '430'}] +# STEP_END +# REMOVE_START +assert all( + item in [x.__dict__["id"] for x in res] + for item in ["bikes:008", "bikes:009"] +) +# REMOVE_END + +# STEP_START simple_query_3 +query = Query("@brand:Peaknetic @price:[0 1000]").return_fields( + "id", "brand", "model", "price" +) +res = client.ft("idx:bikes_vss").search(query).docs +# print(res) +# >>> [Document {'id': 'bikes:009', 'payload': None, 'brand': 'Peaknetic', 'model': 'Secto', 'price': '430'}] +# STEP_END +# REMOVE_START +assert all(item in [x.__dict__["id"] for x in res] for item in ["bikes:009"]) +# REMOVE_END + +# STEP_START def_bulk_queries +queries = [ + "Bike for small kids", + "Best Mountain bikes for kids", + "Cheap Mountain bike for kids", + "Female specific mountain bike", + "Road bike for beginners", + "Commuter bike for people over 60", + "Comfortable commuter bike", + "Good bike for college students", + "Mountain bike for beginners", + "Vintage bike", + "Comfortable city bike", +] +# STEP_END + +# STEP_START enc_bulk_queries +encoded_queries = embedder.encode(queries) +len(encoded_queries) +# >>> 11 +# STEP_END +# REMOVE_START +assert len(encoded_queries) == 11 +# REMOVE_END + + +# STEP_START define_bulk_query +def create_query_table(query, queries, encoded_queries, extra_params={}): + results_list = [] + for i, encoded_query in enumerate(encoded_queries): + result_docs = ( + client.ft("idx:bikes_vss") + .search( + query, + { + "query_vector": np.array( + encoded_query, dtype=np.float32 + ).tobytes() + } + | extra_params, + ) + .docs + ) + for doc in result_docs: + vector_score = round(1 - float(doc.vector_score), 2) + results_list.append( + { + "query": queries[i], + "score": vector_score, + "id": doc.id, + "brand": doc.brand, + "model": doc.model, + "description": doc.description, + } + ) + + # Optional: convert the table to Markdown using Pandas + queries_table = pd.DataFrame(results_list) + queries_table.sort_values( + by=["query", "score"], ascending=[True, False], inplace=True + ) + queries_table["query"] = queries_table.groupby("query")["query"].transform( + lambda x: [x.iloc[0]] + [""] * (len(x) - 1) + ) + queries_table["description"] = queries_table["description"].apply( + lambda x: (x[:497] + "...") if len(x) > 500 else x + ) + queries_table.to_markdown(index=False) + + +# STEP_END + +# STEP_START run_knn_query +query = ( + Query("(*)=>[KNN 3 @vector $query_vector AS vector_score]") + .sort_by("vector_score") + .return_fields("vector_score", "id", "brand", "model", "description") + .dialect(2) +) + +create_query_table(query, queries, encoded_queries) +# >>> | Best Mountain bikes for kids | 0.54 | bikes:003... 
(+ 32 more results) +# STEP_END + +# STEP_START run_hybrid_query +hybrid_query = ( + Query("(@brand:Peaknetic)=>[KNN 3 @vector $query_vector AS vector_score]") + .sort_by("vector_score") + .return_fields("vector_score", "id", "brand", "model", "description") + .dialect(2) +) +create_query_table(hybrid_query, queries, encoded_queries) +# >>> | Best Mountain bikes for kids | 0.3 | bikes:008... (+22 more results) +# STEP_END + +# STEP_START run_range_query +range_query = ( + Query( + "@vector:[VECTOR_RANGE $range $query_vector]=>{$YIELD_DISTANCE_AS: vector_score}" + ) + .sort_by("vector_score") + .return_fields("vector_score", "id", "brand", "model", "description") + .paging(0, 4) + .dialect(2) +) +create_query_table( + range_query, queries[:1], encoded_queries[:1], {"range": 0.55} +) +# >>> | Bike for small kids | 0.52 | bikes:001 | Velorim |... (+1 more result) +# STEP_END diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000000..f1b716ae96 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,13 @@ +[pytest] +addopts = -s +markers = + redismod: run only the redis module tests + pipeline: pipeline tests + onlycluster: marks tests to be run only with cluster mode redis + onlynoncluster: marks tests to be run only with standalone redis + ssl: marker for only the ssl tests + asyncio: marker for async tests + replica: replica tests + experimental: run only experimental tests +asyncio_mode = auto +timeout = 30 diff --git a/redis/__init__.py b/redis/__init__.py index 3ab697065e..495d2d99bb 100644 --- a/redis/__init__.py +++ b/redis/__init__.py @@ -1,5 +1,6 @@ import sys +from redis import asyncio # noqa from redis.backoff import default_backoff from redis.client import Redis, StrictRedis from redis.cluster import RedisCluster @@ -19,6 +20,7 @@ ConnectionError, DataError, InvalidResponse, + OutOfMemoryError, PubSubError, ReadOnlyError, RedisError, @@ -56,7 +58,7 @@ def int_or_str(value): try: VERSION = tuple(map(int_or_str, __version__.split("."))) except AttributeError: - VERSION = tuple(99, 99, 99) + VERSION = tuple([99, 99, 99]) __all__ = [ "AuthenticationError", @@ -72,6 +74,7 @@ def int_or_str(value): "from_url", "default_backoff", "InvalidResponse", + "OutOfMemoryError", "PubSubError", "ReadOnlyError", "Redis", diff --git a/redis/_parsers/__init__.py b/redis/_parsers/__init__.py new file mode 100644 index 0000000000..6cc32e3cae --- /dev/null +++ b/redis/_parsers/__init__.py @@ -0,0 +1,20 @@ +from .base import BaseParser, _AsyncRESPBase +from .commands import AsyncCommandsParser, CommandsParser +from .encoders import Encoder +from .hiredis import _AsyncHiredisParser, _HiredisParser +from .resp2 import _AsyncRESP2Parser, _RESP2Parser +from .resp3 import _AsyncRESP3Parser, _RESP3Parser + +__all__ = [ + "AsyncCommandsParser", + "_AsyncHiredisParser", + "_AsyncRESPBase", + "_AsyncRESP2Parser", + "_AsyncRESP3Parser", + "CommandsParser", + "Encoder", + "BaseParser", + "_HiredisParser", + "_RESP2Parser", + "_RESP3Parser", +] diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py new file mode 100644 index 0000000000..f77296df6a --- /dev/null +++ b/redis/_parsers/base.py @@ -0,0 +1,232 @@ +import sys +from abc import ABC +from asyncio import IncompleteReadError, StreamReader, TimeoutError +from typing import List, Optional, Union + +if sys.version_info.major >= 3 and sys.version_info.minor >= 11: + from asyncio import timeout as async_timeout +else: + from async_timeout import timeout as async_timeout + +from ..exceptions import ( + AuthenticationError, + 
AuthenticationWrongNumberOfArgsError, + BusyLoadingError, + ConnectionError, + ExecAbortError, + ModuleError, + NoPermissionError, + NoScriptError, + OutOfMemoryError, + ReadOnlyError, + RedisError, + ResponseError, +) +from ..typing import EncodableT +from .encoders import Encoder +from .socket import SERVER_CLOSED_CONNECTION_ERROR, SocketBuffer + +MODULE_LOAD_ERROR = "Error loading the extension. " "Please check the server logs." +NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name" +MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not " "possible." +MODULE_EXPORTS_DATA_TYPES_ERROR = ( + "Error unloading module: the module " + "exports one or more module-side data " + "types, can't unload" +) +# user send an AUTH cmd to a server without authorization configured +NO_AUTH_SET_ERROR = { + # Redis >= 6.0 + "AUTH called without any password " + "configured for the default user. Are you sure " + "your configuration is correct?": AuthenticationError, + # Redis < 6.0 + "Client sent AUTH, but no password is set": AuthenticationError, +} + + +class BaseParser(ABC): + + EXCEPTION_CLASSES = { + "ERR": { + "max number of clients reached": ConnectionError, + "invalid password": AuthenticationError, + # some Redis server versions report invalid command syntax + # in lowercase + "wrong number of arguments " + "for 'auth' command": AuthenticationWrongNumberOfArgsError, + # some Redis server versions report invalid command syntax + # in uppercase + "wrong number of arguments " + "for 'AUTH' command": AuthenticationWrongNumberOfArgsError, + MODULE_LOAD_ERROR: ModuleError, + MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError, + NO_SUCH_MODULE_ERROR: ModuleError, + MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError, + **NO_AUTH_SET_ERROR, + }, + "OOM": OutOfMemoryError, + "WRONGPASS": AuthenticationError, + "EXECABORT": ExecAbortError, + "LOADING": BusyLoadingError, + "NOSCRIPT": NoScriptError, + "READONLY": ReadOnlyError, + "NOAUTH": AuthenticationError, + "NOPERM": NoPermissionError, + } + + @classmethod + def parse_error(cls, response): + "Parse an error response" + error_code = response.split(" ")[0] + if error_code in cls.EXCEPTION_CLASSES: + response = response[len(error_code) + 1 :] + exception_class = cls.EXCEPTION_CLASSES[error_code] + if isinstance(exception_class, dict): + exception_class = exception_class.get(response, ResponseError) + return exception_class(response) + return ResponseError(response) + + def on_disconnect(self): + raise NotImplementedError() + + def on_connect(self, connection): + raise NotImplementedError() + + +class _RESPBase(BaseParser): + """Base class for sync-based resp parsing""" + + def __init__(self, socket_read_size): + self.socket_read_size = socket_read_size + self.encoder = None + self._sock = None + self._buffer = None + + def __del__(self): + try: + self.on_disconnect() + except Exception: + pass + + def on_connect(self, connection): + "Called when the socket connects" + self._sock = connection._sock + self._buffer = SocketBuffer( + self._sock, self.socket_read_size, connection.socket_timeout + ) + self.encoder = connection.encoder + + def on_disconnect(self): + "Called when the socket disconnects" + self._sock = None + if self._buffer is not None: + self._buffer.close() + self._buffer = None + self.encoder = None + + def can_read(self, timeout): + return self._buffer and self._buffer.can_read(timeout) + + +class AsyncBaseParser(BaseParser): + """Base parsing class for the python-backed async parser""" + + __slots__ = "_stream", 
"_read_size" + + def __init__(self, socket_read_size: int): + self._stream: Optional[StreamReader] = None + self._read_size = socket_read_size + + def __del__(self): + try: + self.on_disconnect() + except Exception: + pass + + async def can_read_destructive(self) -> bool: + raise NotImplementedError() + + async def read_response( + self, disable_decoding: bool = False + ) -> Union[EncodableT, ResponseError, None, List[EncodableT]]: + raise NotImplementedError() + + +class _AsyncRESPBase(AsyncBaseParser): + """Base class for async resp parsing""" + + __slots__ = AsyncBaseParser.__slots__ + ("encoder", "_buffer", "_pos", "_chunks") + + def __init__(self, socket_read_size: int): + super().__init__(socket_read_size) + self.encoder: Optional[Encoder] = None + self._buffer = b"" + self._chunks = [] + self._pos = 0 + + def _clear(self): + self._buffer = b"" + self._chunks.clear() + + def on_connect(self, connection): + """Called when the stream connects""" + self._stream = connection._reader + if self._stream is None: + raise RedisError("Buffer is closed.") + self.encoder = connection.encoder + self._clear() + self._connected = True + + def on_disconnect(self): + """Called when the stream disconnects""" + self._connected = False + + async def can_read_destructive(self) -> bool: + if not self._connected: + raise RedisError("Buffer is closed.") + if self._buffer: + return True + try: + async with async_timeout(0): + return await self._stream.read(1) + except TimeoutError: + return False + + async def _read(self, length: int) -> bytes: + """ + Read `length` bytes of data. These are assumed to be followed + by a '\r\n' terminator which is subsequently discarded. + """ + want = length + 2 + end = self._pos + want + if len(self._buffer) >= end: + result = self._buffer[self._pos : end - 2] + else: + tail = self._buffer[self._pos :] + try: + data = await self._stream.readexactly(want - len(tail)) + except IncompleteReadError as error: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error + result = (tail + data)[:-2] + self._chunks.append(data) + self._pos += want + return result + + async def _readline(self) -> bytes: + """ + read an unknown number of bytes up to the next '\r\n' + line separator, which is discarded. + """ + found = self._buffer.find(b"\r\n", self._pos) + if found >= 0: + result = self._buffer[self._pos : found] + else: + tail = self._buffer[self._pos :] + data = await self._stream.readline() + if not data.endswith(b"\r\n"): + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + result = (tail + data)[:-2] + self._chunks.append(data) + self._pos += len(result) + 2 + return result diff --git a/redis/commands/parser.py b/redis/_parsers/commands.py similarity index 56% rename from redis/commands/parser.py rename to redis/_parsers/commands.py index 115230a9d2..b5109252ae 100644 --- a/redis/commands/parser.py +++ b/redis/_parsers/commands.py @@ -1,8 +1,59 @@ +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union + from redis.exceptions import RedisError, ResponseError from redis.utils import str_if_bytes +if TYPE_CHECKING: + from redis.asyncio.cluster import ClusterNode + + +class AbstractCommandsParser: + def _get_pubsub_keys(self, *args): + """ + Get the keys from pubsub command. 
+ Although PubSub commands have predetermined key locations, they are not + supported in the 'COMMAND's output, so the key positions are hardcoded + in this method + """ + if len(args) < 2: + # The command has no keys in it + return None + args = [str_if_bytes(arg) for arg in args] + command = args[0].upper() + keys = None + if command == "PUBSUB": + # the second argument is a part of the command name, e.g. + # ['PUBSUB', 'NUMSUB', 'foo']. + pubsub_type = args[1].upper() + if pubsub_type in ["CHANNELS", "NUMSUB", "SHARDCHANNELS", "SHARDNUMSUB"]: + keys = args[2:] + elif command in ["SUBSCRIBE", "PSUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE"]: + # format example: + # SUBSCRIBE channel [channel ...] + keys = list(args[1:]) + elif command in ["PUBLISH", "SPUBLISH"]: + # format example: + # PUBLISH channel message + keys = [args[1]] + return keys + + def parse_subcommand(self, command, **options): + cmd_dict = {} + cmd_name = str_if_bytes(command[0]) + cmd_dict["name"] = cmd_name + cmd_dict["arity"] = int(command[1]) + cmd_dict["flags"] = [str_if_bytes(flag) for flag in command[2]] + cmd_dict["first_key_pos"] = command[3] + cmd_dict["last_key_pos"] = command[4] + cmd_dict["step_count"] = command[5] + if len(command) > 7: + cmd_dict["tips"] = command[7] + cmd_dict["key_specifications"] = command[8] + cmd_dict["subcommands"] = command[9] + return cmd_dict + -class CommandsParser: +class CommandsParser(AbstractCommandsParser): """ Parses Redis commands to get command keys. COMMAND output is used to determine key locations. @@ -16,7 +67,7 @@ def __init__(self, redis_connection): self.initialize(redis_connection) def initialize(self, r): - commands = r.execute_command("COMMAND") + commands = r.command() uppercase_commands = [] for cmd in commands: if any(x.isupper() for x in cmd): @@ -25,21 +76,6 @@ def initialize(self, r): commands[cmd.lower()] = commands.pop(cmd) self.commands = commands - def parse_subcommand(self, command, **options): - cmd_dict = {} - cmd_name = str_if_bytes(command[0]) - cmd_dict["name"] = cmd_name - cmd_dict["arity"] = int(command[1]) - cmd_dict["flags"] = [str_if_bytes(flag) for flag in command[2]] - cmd_dict["first_key_pos"] = command[3] - cmd_dict["last_key_pos"] = command[4] - cmd_dict["step_count"] = command[5] - if len(command) > 7: - cmd_dict["tips"] = command[7] - cmd_dict["key_specifications"] = command[8] - cmd_dict["subcommands"] = command[9] - return cmd_dict - # As soon as this PR is merged into Redis, we should reimplement # our logic to use COMMAND INFO changes to determine the key positions # https://github.com/redis/redis/pull/8324 @@ -117,12 +153,9 @@ def _get_moveable_keys(self, redis_conn, *args): So, don't use this function with EVAL or EVALSHA. """ - pieces = [] - cmd_name = args[0] # The command name should be splitted into separate arguments, # e.g. 'MEMORY USAGE' will be splitted into ['MEMORY', 'USAGE'] - pieces = pieces + cmd_name.split() - pieces = pieces + list(args[1:]) + pieces = args[0].split() + list(args[1:]) try: keys = redis_conn.execute_command("COMMAND GETKEYS", *pieces) except ResponseError as e: @@ -136,31 +169,113 @@ def _get_moveable_keys(self, redis_conn, *args): raise e return keys - def _get_pubsub_keys(self, *args): + +class AsyncCommandsParser(AbstractCommandsParser): + """ + Parses Redis commands to get command keys. + + COMMAND output is used to determine key locations. + Commands that do not have a predefined key location are flagged with 'movablekeys', + and these commands' keys are determined by the command 'COMMAND GETKEYS'. 
+ + NOTE: Due to a bug in redis<7.0, this does not work properly + for EVAL or EVALSHA when the `numkeys` arg is 0. + - issue: https://github.com/redis/redis/issues/9493 + - fix: https://github.com/redis/redis/pull/9733 + + So, don't use this with EVAL or EVALSHA. + """ + + __slots__ = ("commands", "node") + + def __init__(self) -> None: + self.commands: Dict[str, Union[int, Dict[str, Any]]] = {} + + async def initialize(self, node: Optional["ClusterNode"] = None) -> None: + if node: + self.node = node + + commands = await self.node.execute_command("COMMAND") + self.commands = {cmd.lower(): command for cmd, command in commands.items()} + + # As soon as this PR is merged into Redis, we should reimplement + # our logic to use COMMAND INFO changes to determine the key positions + # https://github.com/redis/redis/pull/8324 + async def get_keys(self, *args: Any) -> Optional[Tuple[str, ...]]: """ - Get the keys from pubsub command. - Although PubSub commands have predetermined key locations, they are not - supported in the 'COMMAND's output, so the key positions are hardcoded - in this method + Get the keys from the passed command. + + NOTE: Due to a bug in redis<7.0, this function does not work properly + for EVAL or EVALSHA when the `numkeys` arg is 0. + - issue: https://github.com/redis/redis/issues/9493 + - fix: https://github.com/redis/redis/pull/9733 + + So, don't use this function with EVAL or EVALSHA. """ if len(args) < 2: # The command has no keys in it return None - args = [str_if_bytes(arg) for arg in args] - command = args[0].upper() - keys = None - if command == "PUBSUB": - # the second argument is a part of the command name, e.g. - # ['PUBSUB', 'NUMSUB', 'foo']. - pubsub_type = args[1].upper() - if pubsub_type in ["CHANNELS", "NUMSUB"]: - keys = args[2:] - elif command in ["SUBSCRIBE", "PSUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE"]: - # format example: - # SUBSCRIBE channel [channel ...] - keys = list(args[1:]) - elif command == "PUBLISH": - # format example: - # PUBLISH channel message - keys = [args[1]] + + cmd_name = args[0].lower() + if cmd_name not in self.commands: + # try to split the command name and to take only the main command, + # e.g. 
'memory' for 'memory usage' + cmd_name_split = cmd_name.split() + cmd_name = cmd_name_split[0] + if cmd_name in self.commands: + # save the splitted command to args + args = cmd_name_split + list(args[1:]) + else: + # We'll try to reinitialize the commands cache, if the engine + # version has changed, the commands may not be current + await self.initialize() + if cmd_name not in self.commands: + raise RedisError( + f"{cmd_name.upper()} command doesn't exist in Redis commands" + ) + + command = self.commands.get(cmd_name) + if "movablekeys" in command["flags"]: + keys = await self._get_moveable_keys(*args) + elif "pubsub" in command["flags"] or command["name"] == "pubsub": + keys = self._get_pubsub_keys(*args) + else: + if ( + command["step_count"] == 0 + and command["first_key_pos"] == 0 + and command["last_key_pos"] == 0 + ): + is_subcmd = False + if "subcommands" in command: + subcmd_name = f"{cmd_name}|{args[1].lower()}" + for subcmd in command["subcommands"]: + if str_if_bytes(subcmd[0]) == subcmd_name: + command = self.parse_subcommand(subcmd) + is_subcmd = True + + # The command doesn't have keys in it + if not is_subcmd: + return None + last_key_pos = command["last_key_pos"] + if last_key_pos < 0: + last_key_pos = len(args) - abs(last_key_pos) + keys_pos = list( + range(command["first_key_pos"], last_key_pos + 1, command["step_count"]) + ) + keys = [args[pos] for pos in keys_pos] + + return keys + + async def _get_moveable_keys(self, *args: Any) -> Optional[Tuple[str, ...]]: + try: + keys = await self.node.execute_command("COMMAND GETKEYS", *args) + except ResponseError as e: + message = e.__str__() + if ( + "Invalid arguments" in message + or "The command has no key arguments" in message + ): + return None + else: + raise e return keys diff --git a/redis/_parsers/encoders.py b/redis/_parsers/encoders.py new file mode 100644 index 0000000000..6fdf0ad882 --- /dev/null +++ b/redis/_parsers/encoders.py @@ -0,0 +1,44 @@ +from ..exceptions import DataError + + +class Encoder: + "Encode strings to bytes-like and decode bytes-like to strings" + + __slots__ = "encoding", "encoding_errors", "decode_responses" + + def __init__(self, encoding, encoding_errors, decode_responses): + self.encoding = encoding + self.encoding_errors = encoding_errors + self.decode_responses = decode_responses + + def encode(self, value): + "Return a bytestring or bytes-like representation of the value" + if isinstance(value, (bytes, memoryview)): + return value + elif isinstance(value, bool): + # special case bool since it is a subclass of int + raise DataError( + "Invalid input of type: 'bool'. Convert to a " + "bytes, string, int or float first." + ) + elif isinstance(value, (int, float)): + value = repr(value).encode() + elif not isinstance(value, str): + # a value we don't know how to deal with. throw an error + typename = type(value).__name__ + raise DataError( + f"Invalid input of type: '{typename}'. " + f"Convert to a bytes, string, int or float first." 
+ ) + if isinstance(value, str): + value = value.encode(self.encoding, self.encoding_errors) + return value + + def decode(self, value, force=False): + "Return a unicode string from the bytes-like representation" + if self.decode_responses or force: + if isinstance(value, memoryview): + value = value.tobytes() + if isinstance(value, bytes): + value = value.decode(self.encoding, self.encoding_errors) + return value diff --git a/redis/_parsers/helpers.py b/redis/_parsers/helpers.py new file mode 100644 index 0000000000..fb5da831fe --- /dev/null +++ b/redis/_parsers/helpers.py @@ -0,0 +1,852 @@ +import datetime + +from redis.utils import str_if_bytes + + +def timestamp_to_datetime(response): + "Converts a unix timestamp to a Python datetime object" + if not response: + return None + try: + response = int(response) + except ValueError: + return None + return datetime.datetime.fromtimestamp(response) + + +def parse_debug_object(response): + "Parse the results of Redis's DEBUG OBJECT command into a Python dict" + # The 'type' of the object is the first item in the response, but isn't + # prefixed with a name + response = str_if_bytes(response) + response = "type:" + response + response = dict(kv.split(":") for kv in response.split()) + + # parse some expected int values from the string response + # note: this cmd isn't spec'd so these may not appear in all redis versions + int_fields = ("refcount", "serializedlength", "lru", "lru_seconds_idle") + for field in int_fields: + if field in response: + response[field] = int(response[field]) + + return response + + +def parse_info(response): + """Parse the result of Redis's INFO command into a Python dict""" + info = {} + response = str_if_bytes(response) + + def get_value(value): + if "," not in value or "=" not in value: + try: + if "." in value: + return float(value) + else: + return int(value) + except ValueError: + return value + else: + sub_dict = {} + for item in value.split(","): + k, v = item.rsplit("=", 1) + sub_dict[k] = get_value(v) + return sub_dict + + for line in response.splitlines(): + if line and not line.startswith("#"): + if line.find(":") != -1: + # Split, the info fields keys and values. + # Note that the value may contain ':'. 
but the 'host:' + # pseudo-command is the only case where the key contains ':' + key, value = line.split(":", 1) + if key == "cmdstat_host": + key, value = line.rsplit(":", 1) + + if key == "module": + # Hardcode a list for key 'modules' since there could be + # multiple lines that started with 'module' + info.setdefault("modules", []).append(get_value(value)) + else: + info[key] = get_value(value) + else: + # if the line isn't splittable, append it to the "__raw__" key + info.setdefault("__raw__", []).append(line) + + return info + + +def parse_memory_stats(response, **kwargs): + """Parse the results of MEMORY STATS""" + stats = pairs_to_dict(response, decode_keys=True, decode_string_values=True) + for key, value in stats.items(): + if key.startswith("db."): + stats[key] = pairs_to_dict( + value, decode_keys=True, decode_string_values=True + ) + return stats + + +SENTINEL_STATE_TYPES = { + "can-failover-its-master": int, + "config-epoch": int, + "down-after-milliseconds": int, + "failover-timeout": int, + "info-refresh": int, + "last-hello-message": int, + "last-ok-ping-reply": int, + "last-ping-reply": int, + "last-ping-sent": int, + "master-link-down-time": int, + "master-port": int, + "num-other-sentinels": int, + "num-slaves": int, + "o-down-time": int, + "pending-commands": int, + "parallel-syncs": int, + "port": int, + "quorum": int, + "role-reported-time": int, + "s-down-time": int, + "slave-priority": int, + "slave-repl-offset": int, + "voted-leader-epoch": int, +} + + +def parse_sentinel_state(item): + result = pairs_to_dict_typed(item, SENTINEL_STATE_TYPES) + flags = set(result["flags"].split(",")) + for name, flag in ( + ("is_master", "master"), + ("is_slave", "slave"), + ("is_sdown", "s_down"), + ("is_odown", "o_down"), + ("is_sentinel", "sentinel"), + ("is_disconnected", "disconnected"), + ("is_master_down", "master_down"), + ): + result[name] = flag in flags + return result + + +def parse_sentinel_master(response): + return parse_sentinel_state(map(str_if_bytes, response)) + + +def parse_sentinel_state_resp3(response): + result = {} + for key in response: + try: + value = SENTINEL_STATE_TYPES[key](str_if_bytes(response[key])) + result[str_if_bytes(key)] = value + except Exception: + result[str_if_bytes(key)] = response[str_if_bytes(key)] + flags = set(result["flags"].split(",")) + result["flags"] = flags + return result + + +def parse_sentinel_masters(response): + result = {} + for item in response: + state = parse_sentinel_state(map(str_if_bytes, item)) + result[state["name"]] = state + return result + + +def parse_sentinel_masters_resp3(response): + return [parse_sentinel_state(master) for master in response] + + +def parse_sentinel_slaves_and_sentinels(response): + return [parse_sentinel_state(map(str_if_bytes, item)) for item in response] + + +def parse_sentinel_slaves_and_sentinels_resp3(response): + return [parse_sentinel_state_resp3(item) for item in response] + + +def parse_sentinel_get_master(response): + return response and (response[0], int(response[1])) or None + + +def pairs_to_dict(response, decode_keys=False, decode_string_values=False): + """Create a dict given a list of key/value pairs""" + if response is None: + return {} + if decode_keys or decode_string_values: + # the iter form is faster, but I don't know how to make that work + # with a str_if_bytes() map + keys = response[::2] + if decode_keys: + keys = map(str_if_bytes, keys) + values = response[1::2] + if decode_string_values: + values = map(str_if_bytes, values) + return dict(zip(keys, values)) + 
else: + it = iter(response) + return dict(zip(it, it)) + + +def pairs_to_dict_typed(response, type_info): + it = iter(response) + result = {} + for key, value in zip(it, it): + if key in type_info: + try: + value = type_info[key](value) + except Exception: + # if for some reason the value can't be coerced, just use + # the string value + pass + result[key] = value + return result + + +def zset_score_pairs(response, **options): + """ + If ``withscores`` is specified in the options, return the response as + a list of (value, score) pairs + """ + if not response or not options.get("withscores"): + return response + score_cast_func = options.get("score_cast_func", float) + it = iter(response) + return list(zip(it, map(score_cast_func, it))) + + +def sort_return_tuples(response, **options): + """ + If ``groups`` is specified, return the response as a list of + n-element tuples with n being the value found in options['groups'] + """ + if not response or not options.get("groups"): + return response + n = options["groups"] + return list(zip(*[response[i::n] for i in range(n)])) + + +def parse_stream_list(response): + if response is None: + return None + data = [] + for r in response: + if r is not None: + data.append((r[0], pairs_to_dict(r[1]))) + else: + data.append((None, None)) + return data + + +def pairs_to_dict_with_str_keys(response): + return pairs_to_dict(response, decode_keys=True) + + +def parse_list_of_dicts(response): + return list(map(pairs_to_dict_with_str_keys, response)) + + +def parse_xclaim(response, **options): + if options.get("parse_justid", False): + return response + return parse_stream_list(response) + + +def parse_xautoclaim(response, **options): + if options.get("parse_justid", False): + return response[1] + response[1] = parse_stream_list(response[1]) + return response + + +def parse_xinfo_stream(response, **options): + if isinstance(response, list): + data = pairs_to_dict(response, decode_keys=True) + else: + data = {str_if_bytes(k): v for k, v in response.items()} + if not options.get("full", False): + first = data.get("first-entry") + if first is not None: + data["first-entry"] = (first[0], pairs_to_dict(first[1])) + last = data["last-entry"] + if last is not None: + data["last-entry"] = (last[0], pairs_to_dict(last[1])) + else: + data["entries"] = {_id: pairs_to_dict(entry) for _id, entry in data["entries"]} + if isinstance(data["groups"][0], list): + data["groups"] = [ + pairs_to_dict(group, decode_keys=True) for group in data["groups"] + ] + else: + data["groups"] = [ + {str_if_bytes(k): v for k, v in group.items()} + for group in data["groups"] + ] + return data + + +def parse_xread(response): + if response is None: + return [] + return [[r[0], parse_stream_list(r[1])] for r in response] + + +def parse_xread_resp3(response): + if response is None: + return {} + return {key: [parse_stream_list(value)] for key, value in response.items()} + + +def parse_xpending(response, **options): + if options.get("parse_detail", False): + return parse_xpending_range(response) + consumers = [{"name": n, "pending": int(p)} for n, p in response[3] or []] + return { + "pending": response[0], + "min": response[1], + "max": response[2], + "consumers": consumers, + } + + +def parse_xpending_range(response): + k = ("message_id", "consumer", "time_since_delivered", "times_delivered") + return [dict(zip(k, r)) for r in response] + + +def float_or_none(response): + if response is None: + return None + return float(response) + + +def bool_ok(response): + return str_if_bytes(response) == 
"OK" + + +def parse_zadd(response, **options): + if response is None: + return None + if options.get("as_score"): + return float(response) + return int(response) + + +def parse_client_list(response, **options): + clients = [] + for c in str_if_bytes(response).splitlines(): + # Values might contain '=' + clients.append(dict(pair.split("=", 1) for pair in c.split(" "))) + return clients + + +def parse_config_get(response, **options): + response = [str_if_bytes(i) if i is not None else None for i in response] + return response and pairs_to_dict(response) or {} + + +def parse_scan(response, **options): + cursor, r = response + return int(cursor), r + + +def parse_hscan(response, **options): + cursor, r = response + return int(cursor), r and pairs_to_dict(r) or {} + + +def parse_zscan(response, **options): + score_cast_func = options.get("score_cast_func", float) + cursor, r = response + it = iter(r) + return int(cursor), list(zip(it, map(score_cast_func, it))) + + +def parse_zmscore(response, **options): + # zmscore: list of scores (double precision floating point number) or nil + return [float(score) if score is not None else None for score in response] + + +def parse_slowlog_get(response, **options): + space = " " if options.get("decode_responses", False) else b" " + + def parse_item(item): + result = {"id": item[0], "start_time": int(item[1]), "duration": int(item[2])} + # Redis Enterprise injects another entry at index [3], which has + # the complexity info (i.e. the value N in case the command has + # an O(N) complexity) instead of the command. + if isinstance(item[3], list): + result["command"] = space.join(item[3]) + result["client_address"] = item[4] + result["client_name"] = item[5] + else: + result["complexity"] = item[3] + result["command"] = space.join(item[4]) + result["client_address"] = item[5] + result["client_name"] = item[6] + return result + + return [parse_item(item) for item in response] + + +def parse_stralgo(response, **options): + """ + Parse the response from `STRALGO` command. + Without modifiers the returned value is string. + When LEN is given the command returns the length of the result + (i.e integer). + When IDX is given the command returns a dictionary with the LCS + length and all the ranges in both the strings, start and end + offset for each string, where there are matches. + When WITHMATCHLEN is given, each array representing a match will + also have the length of the match at the beginning of the array. 
+ """ + if options.get("len", False): + return int(response) + if options.get("idx", False): + if options.get("withmatchlen", False): + matches = [ + [(int(match[-1]))] + list(map(tuple, match[:-1])) + for match in response[1] + ] + else: + matches = [list(map(tuple, match)) for match in response[1]] + return { + str_if_bytes(response[0]): matches, + str_if_bytes(response[2]): int(response[3]), + } + return str_if_bytes(response) + + +def parse_cluster_info(response, **options): + response = str_if_bytes(response) + return dict(line.split(":") for line in response.splitlines() if line) + + +def _parse_node_line(line): + line_items = line.split(" ") + node_id, addr, flags, master_id, ping, pong, epoch, connected = line.split(" ")[:8] + addr = addr.split("@")[0] + node_dict = { + "node_id": node_id, + "flags": flags, + "master_id": master_id, + "last_ping_sent": ping, + "last_pong_rcvd": pong, + "epoch": epoch, + "slots": [], + "migrations": [], + "connected": True if connected == "connected" else False, + } + if len(line_items) >= 9: + slots, migrations = _parse_slots(line_items[8:]) + node_dict["slots"], node_dict["migrations"] = slots, migrations + return addr, node_dict + + +def _parse_slots(slot_ranges): + slots, migrations = [], [] + for s_range in slot_ranges: + if "->-" in s_range: + slot_id, dst_node_id = s_range[1:-1].split("->-", 1) + migrations.append( + {"slot": slot_id, "node_id": dst_node_id, "state": "migrating"} + ) + elif "-<-" in s_range: + slot_id, src_node_id = s_range[1:-1].split("-<-", 1) + migrations.append( + {"slot": slot_id, "node_id": src_node_id, "state": "importing"} + ) + else: + s_range = [sl for sl in s_range.split("-")] + slots.append(s_range) + + return slots, migrations + + +def parse_cluster_nodes(response, **options): + """ + @see: https://redis.io/commands/cluster-nodes # string / bytes + @see: https://redis.io/commands/cluster-replicas # list of string / bytes + """ + if isinstance(response, (str, bytes)): + response = response.splitlines() + return dict(_parse_node_line(str_if_bytes(node)) for node in response) + + +def parse_geosearch_generic(response, **options): + """ + Parse the response of 'GEOSEARCH', GEORADIUS' and 'GEORADIUSBYMEMBER' + commands according to 'withdist', 'withhash' and 'withcoord' labels. + """ + try: + if options["store"] or options["store_dist"]: + # `store` and `store_dist` cant be combined + # with other command arguments. + # relevant to 'GEORADIUS' and 'GEORADIUSBYMEMBER' + return response + except KeyError: # it means the command was sent via execute_command + return response + + if type(response) != list: + response_list = [response] + else: + response_list = response + + if not options["withdist"] and not options["withcoord"] and not options["withhash"]: + # just a bunch of places + return response_list + + cast = { + "withdist": float, + "withcoord": lambda ll: (float(ll[0]), float(ll[1])), + "withhash": int, + } + + # zip all output results with each casting function to get + # the properly native Python value. 
+ f = [lambda x: x] + f += [cast[o] for o in ["withdist", "withhash", "withcoord"] if options[o]] + return [list(map(lambda fv: fv[0](fv[1]), zip(f, r))) for r in response_list] + + +def parse_command(response, **options): + commands = {} + for command in response: + cmd_dict = {} + cmd_name = str_if_bytes(command[0]) + cmd_dict["name"] = cmd_name + cmd_dict["arity"] = int(command[1]) + cmd_dict["flags"] = [str_if_bytes(flag) for flag in command[2]] + cmd_dict["first_key_pos"] = command[3] + cmd_dict["last_key_pos"] = command[4] + cmd_dict["step_count"] = command[5] + if len(command) > 7: + cmd_dict["tips"] = command[7] + cmd_dict["key_specifications"] = command[8] + cmd_dict["subcommands"] = command[9] + commands[cmd_name] = cmd_dict + return commands + + +def parse_command_resp3(response, **options): + commands = {} + for command in response: + cmd_dict = {} + cmd_name = str_if_bytes(command[0]) + cmd_dict["name"] = cmd_name + cmd_dict["arity"] = command[1] + cmd_dict["flags"] = {str_if_bytes(flag) for flag in command[2]} + cmd_dict["first_key_pos"] = command[3] + cmd_dict["last_key_pos"] = command[4] + cmd_dict["step_count"] = command[5] + cmd_dict["acl_categories"] = command[6] + if len(command) > 7: + cmd_dict["tips"] = command[7] + cmd_dict["key_specifications"] = command[8] + cmd_dict["subcommands"] = command[9] + + commands[cmd_name] = cmd_dict + return commands + + +def parse_pubsub_numsub(response, **options): + return list(zip(response[0::2], response[1::2])) + + +def parse_client_kill(response, **options): + if isinstance(response, int): + return response + return str_if_bytes(response) == "OK" + + +def parse_acl_getuser(response, **options): + if response is None: + return None + if isinstance(response, list): + data = pairs_to_dict(response, decode_keys=True) + else: + data = {str_if_bytes(key): value for key, value in response.items()} + + # convert everything but user-defined data in 'keys' to native strings + data["flags"] = list(map(str_if_bytes, data["flags"])) + data["passwords"] = list(map(str_if_bytes, data["passwords"])) + data["commands"] = str_if_bytes(data["commands"]) + if isinstance(data["keys"], str) or isinstance(data["keys"], bytes): + data["keys"] = list(str_if_bytes(data["keys"]).split(" ")) + if data["keys"] == [""]: + data["keys"] = [] + if "channels" in data: + if isinstance(data["channels"], str) or isinstance(data["channels"], bytes): + data["channels"] = list(str_if_bytes(data["channels"]).split(" ")) + if data["channels"] == [""]: + data["channels"] = [] + if "selectors" in data: + if data["selectors"] != [] and isinstance(data["selectors"][0], list): + data["selectors"] = [ + list(map(str_if_bytes, selector)) for selector in data["selectors"] + ] + elif data["selectors"] != []: + data["selectors"] = [ + {str_if_bytes(k): str_if_bytes(v) for k, v in selector.items()} + for selector in data["selectors"] + ] + + # split 'commands' into separate 'categories' and 'commands' lists + commands, categories = [], [] + for command in data["commands"].split(" "): + categories.append(command) if "@" in command else commands.append(command) + + data["commands"] = commands + data["categories"] = categories + data["enabled"] = "on" in data["flags"] + return data + + +def parse_acl_log(response, **options): + if response is None: + return None + if isinstance(response, list): + data = [] + for log in response: + log_data = pairs_to_dict(log, True, True) + client_info = log_data.get("client-info", "") + log_data["client-info"] = parse_client_info(client_info) + + # 
float() is lossy comparing to the "double" in C + log_data["age-seconds"] = float(log_data["age-seconds"]) + data.append(log_data) + else: + data = bool_ok(response) + return data + + +def parse_client_info(value): + """ + Parsing client-info in ACL Log in following format. + "key1=value1 key2=value2 key3=value3" + """ + client_info = {} + for info in str_if_bytes(value).strip().split(): + key, value = info.split("=") + client_info[key] = value + + # Those fields are defined as int in networking.c + for int_key in { + "id", + "age", + "idle", + "db", + "sub", + "psub", + "multi", + "qbuf", + "qbuf-free", + "obl", + "argv-mem", + "oll", + "omem", + "tot-mem", + }: + client_info[int_key] = int(client_info[int_key]) + return client_info + + +def parse_set_result(response, **options): + """ + Handle SET result since GET argument is available since Redis 6.2. + Parsing SET result into: + - BOOL + - String when GET argument is used + """ + if options.get("get"): + # Redis will return a getCommand result. + # See `setGenericCommand` in t_string.c + return response + return response and str_if_bytes(response) == "OK" + + +def string_keys_to_dict(key_string, callback): + return dict.fromkeys(key_string.split(), callback) + + +_RedisCallbacks = { + **string_keys_to_dict( + "AUTH COPY EXPIRE EXPIREAT HEXISTS HMSET MOVE MSETNX PERSIST PSETEX " + "PEXPIRE PEXPIREAT RENAMENX SETEX SETNX SMOVE", + bool, + ), + **string_keys_to_dict("HINCRBYFLOAT INCRBYFLOAT", float), + **string_keys_to_dict( + "ASKING FLUSHALL FLUSHDB LSET LTRIM MSET PFMERGE READONLY READWRITE " + "RENAME SAVE SELECT SHUTDOWN SLAVEOF SWAPDB WATCH UNWATCH", + bool_ok, + ), + **string_keys_to_dict("XREAD XREADGROUP", parse_xread), + **string_keys_to_dict( + "GEORADIUS GEORADIUSBYMEMBER GEOSEARCH", + parse_geosearch_generic, + ), + **string_keys_to_dict("XRANGE XREVRANGE", parse_stream_list), + "ACL GETUSER": parse_acl_getuser, + "ACL LOAD": bool_ok, + "ACL LOG": parse_acl_log, + "ACL SETUSER": bool_ok, + "ACL SAVE": bool_ok, + "CLIENT INFO": parse_client_info, + "CLIENT KILL": parse_client_kill, + "CLIENT LIST": parse_client_list, + "CLIENT PAUSE": bool_ok, + "CLIENT SETINFO": bool_ok, + "CLIENT SETNAME": bool_ok, + "CLIENT UNBLOCK": bool, + "CLUSTER ADDSLOTS": bool_ok, + "CLUSTER ADDSLOTSRANGE": bool_ok, + "CLUSTER DELSLOTS": bool_ok, + "CLUSTER DELSLOTSRANGE": bool_ok, + "CLUSTER FAILOVER": bool_ok, + "CLUSTER FORGET": bool_ok, + "CLUSTER INFO": parse_cluster_info, + "CLUSTER MEET": bool_ok, + "CLUSTER NODES": parse_cluster_nodes, + "CLUSTER REPLICAS": parse_cluster_nodes, + "CLUSTER REPLICATE": bool_ok, + "CLUSTER RESET": bool_ok, + "CLUSTER SAVECONFIG": bool_ok, + "CLUSTER SET-CONFIG-EPOCH": bool_ok, + "CLUSTER SETSLOT": bool_ok, + "CLUSTER SLAVES": parse_cluster_nodes, + "COMMAND": parse_command, + "CONFIG RESETSTAT": bool_ok, + "CONFIG SET": bool_ok, + "FUNCTION DELETE": bool_ok, + "FUNCTION FLUSH": bool_ok, + "FUNCTION RESTORE": bool_ok, + "GEODIST": float_or_none, + "HSCAN": parse_hscan, + "INFO": parse_info, + "LASTSAVE": timestamp_to_datetime, + "MEMORY PURGE": bool_ok, + "MODULE LOAD": bool, + "MODULE UNLOAD": bool, + "PING": lambda r: str_if_bytes(r) == "PONG", + "PUBSUB NUMSUB": parse_pubsub_numsub, + "PUBSUB SHARDNUMSUB": parse_pubsub_numsub, + "QUIT": bool_ok, + "SET": parse_set_result, + "SCAN": parse_scan, + "SCRIPT EXISTS": lambda r: list(map(bool, r)), + "SCRIPT FLUSH": bool_ok, + "SCRIPT KILL": bool_ok, + "SCRIPT LOAD": str_if_bytes, + "SENTINEL CKQUORUM": bool_ok, + "SENTINEL FAILOVER": bool_ok, + "SENTINEL 
FLUSHCONFIG": bool_ok, + "SENTINEL GET-MASTER-ADDR-BY-NAME": parse_sentinel_get_master, + "SENTINEL MONITOR": bool_ok, + "SENTINEL RESET": bool_ok, + "SENTINEL REMOVE": bool_ok, + "SENTINEL SET": bool_ok, + "SLOWLOG GET": parse_slowlog_get, + "SLOWLOG RESET": bool_ok, + "SORT": sort_return_tuples, + "SSCAN": parse_scan, + "TIME": lambda x: (int(x[0]), int(x[1])), + "XAUTOCLAIM": parse_xautoclaim, + "XCLAIM": parse_xclaim, + "XGROUP CREATE": bool_ok, + "XGROUP DESTROY": bool, + "XGROUP SETID": bool_ok, + "XINFO STREAM": parse_xinfo_stream, + "XPENDING": parse_xpending, + "ZSCAN": parse_zscan, +} + + +_RedisCallbacksRESP2 = { + **string_keys_to_dict( + "SDIFF SINTER SMEMBERS SUNION", lambda r: r and set(r) or set() + ), + **string_keys_to_dict( + "ZDIFF ZINTER ZPOPMAX ZPOPMIN ZRANGE ZRANGEBYSCORE ZRANK ZREVRANGE " + "ZREVRANGEBYSCORE ZREVRANK ZUNION", + zset_score_pairs, + ), + **string_keys_to_dict("ZINCRBY ZSCORE", float_or_none), + **string_keys_to_dict("BGREWRITEAOF BGSAVE", lambda r: True), + **string_keys_to_dict("BLPOP BRPOP", lambda r: r and tuple(r) or None), + **string_keys_to_dict( + "BZPOPMAX BZPOPMIN", lambda r: r and (r[0], r[1], float(r[2])) or None + ), + "ACL CAT": lambda r: list(map(str_if_bytes, r)), + "ACL GENPASS": str_if_bytes, + "ACL HELP": lambda r: list(map(str_if_bytes, r)), + "ACL LIST": lambda r: list(map(str_if_bytes, r)), + "ACL USERS": lambda r: list(map(str_if_bytes, r)), + "ACL WHOAMI": str_if_bytes, + "CLIENT GETNAME": str_if_bytes, + "CLIENT TRACKINGINFO": lambda r: list(map(str_if_bytes, r)), + "CLUSTER GETKEYSINSLOT": lambda r: list(map(str_if_bytes, r)), + "COMMAND GETKEYS": lambda r: list(map(str_if_bytes, r)), + "CONFIG GET": parse_config_get, + "DEBUG OBJECT": parse_debug_object, + "GEOHASH": lambda r: list(map(str_if_bytes, r)), + "GEOPOS": lambda r: list( + map(lambda ll: (float(ll[0]), float(ll[1])) if ll is not None else None, r) + ), + "HGETALL": lambda r: r and pairs_to_dict(r) or {}, + "MEMORY STATS": parse_memory_stats, + "MODULE LIST": lambda r: [pairs_to_dict(m) for m in r], + "RESET": str_if_bytes, + "SENTINEL MASTER": parse_sentinel_master, + "SENTINEL MASTERS": parse_sentinel_masters, + "SENTINEL SENTINELS": parse_sentinel_slaves_and_sentinels, + "SENTINEL SLAVES": parse_sentinel_slaves_and_sentinels, + "STRALGO": parse_stralgo, + "XINFO CONSUMERS": parse_list_of_dicts, + "XINFO GROUPS": parse_list_of_dicts, + "ZADD": parse_zadd, + "ZMSCORE": parse_zmscore, +} + + +_RedisCallbacksRESP3 = { + **string_keys_to_dict( + "ZRANGE ZINTER ZPOPMAX ZPOPMIN ZRANGEBYSCORE ZREVRANGE ZREVRANGEBYSCORE " + "ZUNION HGETALL XREADGROUP", + lambda r, **kwargs: r, + ), + **string_keys_to_dict("XREAD XREADGROUP", parse_xread_resp3), + "ACL LOG": lambda r: [ + {str_if_bytes(key): str_if_bytes(value) for key, value in x.items()} for x in r + ] + if isinstance(r, list) + else bool_ok(r), + "COMMAND": parse_command_resp3, + "CONFIG GET": lambda r: { + str_if_bytes(key) + if key is not None + else None: str_if_bytes(value) + if value is not None + else None + for key, value in r.items() + }, + "MEMORY STATS": lambda r: {str_if_bytes(key): value for key, value in r.items()}, + "SENTINEL MASTER": parse_sentinel_state_resp3, + "SENTINEL MASTERS": parse_sentinel_masters_resp3, + "SENTINEL SENTINELS": parse_sentinel_slaves_and_sentinels_resp3, + "SENTINEL SLAVES": parse_sentinel_slaves_and_sentinels_resp3, + "STRALGO": lambda r, **options: { + str_if_bytes(key): str_if_bytes(value) for key, value in r.items() + } + if isinstance(r, dict) + else str_if_bytes(r), + "XINFO 
CONSUMERS": lambda r: [ + {str_if_bytes(key): value for key, value in x.items()} for x in r + ], + "XINFO GROUPS": lambda r: [ + {str_if_bytes(key): value for key, value in d.items()} for d in r + ], +} diff --git a/redis/_parsers/hiredis.py b/redis/_parsers/hiredis.py new file mode 100644 index 0000000000..b3247b71ec --- /dev/null +++ b/redis/_parsers/hiredis.py @@ -0,0 +1,217 @@ +import asyncio +import socket +import sys +from typing import Callable, List, Optional, Union + +if sys.version_info.major >= 3 and sys.version_info.minor >= 11: + from asyncio import timeout as async_timeout +else: + from async_timeout import timeout as async_timeout + +from redis.compat import TypedDict + +from ..exceptions import ConnectionError, InvalidResponse, RedisError +from ..typing import EncodableT +from ..utils import HIREDIS_AVAILABLE +from .base import AsyncBaseParser, BaseParser +from .socket import ( + NONBLOCKING_EXCEPTION_ERROR_NUMBERS, + NONBLOCKING_EXCEPTIONS, + SENTINEL, + SERVER_CLOSED_CONNECTION_ERROR, +) + + +class _HiredisReaderArgs(TypedDict, total=False): + protocolError: Callable[[str], Exception] + replyError: Callable[[str], Exception] + encoding: Optional[str] + errors: Optional[str] + + +class _HiredisParser(BaseParser): + "Parser class for connections using Hiredis" + + def __init__(self, socket_read_size): + if not HIREDIS_AVAILABLE: + raise RedisError("Hiredis is not installed") + self.socket_read_size = socket_read_size + self._buffer = bytearray(socket_read_size) + + def __del__(self): + try: + self.on_disconnect() + except Exception: + pass + + def on_connect(self, connection, **kwargs): + import hiredis + + self._sock = connection._sock + self._socket_timeout = connection.socket_timeout + kwargs = { + "protocolError": InvalidResponse, + "replyError": self.parse_error, + "errors": connection.encoder.encoding_errors, + } + + if connection.encoder.decode_responses: + kwargs["encoding"] = connection.encoder.encoding + self._reader = hiredis.Reader(**kwargs) + self._next_response = False + + def on_disconnect(self): + self._sock = None + self._reader = None + self._next_response = False + + def can_read(self, timeout): + if not self._reader: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + + if self._next_response is False: + self._next_response = self._reader.gets() + if self._next_response is False: + return self.read_from_socket(timeout=timeout, raise_on_timeout=False) + return True + + def read_from_socket(self, timeout=SENTINEL, raise_on_timeout=True): + sock = self._sock + custom_timeout = timeout is not SENTINEL + try: + if custom_timeout: + sock.settimeout(timeout) + bufflen = self._sock.recv_into(self._buffer) + if bufflen == 0: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + self._reader.feed(self._buffer, 0, bufflen) + # data was read from the socket and added to the buffer. + # return True to indicate that data was read. + return True + except socket.timeout: + if raise_on_timeout: + raise TimeoutError("Timeout reading from socket") + return False + except NONBLOCKING_EXCEPTIONS as ex: + # if we're in nonblocking mode and the recv raises a + # blocking error, simply return False indicating that + # there's no data to be read. otherwise raise the + # original exception. 
+ allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1) + if not raise_on_timeout and ex.errno == allowed: + return False + raise ConnectionError(f"Error while reading from socket: {ex.args}") + finally: + if custom_timeout: + sock.settimeout(self._socket_timeout) + + def read_response(self, disable_decoding=False): + if not self._reader: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + + # _next_response might be cached from a can_read() call + if self._next_response is not False: + response = self._next_response + self._next_response = False + return response + + if disable_decoding: + response = self._reader.gets(False) + else: + response = self._reader.gets() + + while response is False: + self.read_from_socket() + if disable_decoding: + response = self._reader.gets(False) + else: + response = self._reader.gets() + # if the response is a ConnectionError or the response is a list and + # the first item is a ConnectionError, raise it as something bad + # happened + if isinstance(response, ConnectionError): + raise response + elif ( + isinstance(response, list) + and response + and isinstance(response[0], ConnectionError) + ): + raise response[0] + return response + + +class _AsyncHiredisParser(AsyncBaseParser): + """Async implementation of parser class for connections using Hiredis""" + + __slots__ = ("_reader",) + + def __init__(self, socket_read_size: int): + if not HIREDIS_AVAILABLE: + raise RedisError("Hiredis is not available.") + super().__init__(socket_read_size=socket_read_size) + self._reader = None + + def on_connect(self, connection): + import hiredis + + self._stream = connection._reader + kwargs: _HiredisReaderArgs = { + "protocolError": InvalidResponse, + "replyError": self.parse_error, + } + if connection.encoder.decode_responses: + kwargs["encoding"] = connection.encoder.encoding + kwargs["errors"] = connection.encoder.encoding_errors + + self._reader = hiredis.Reader(**kwargs) + self._connected = True + + def on_disconnect(self): + self._connected = False + + async def can_read_destructive(self): + if not self._connected: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + if self._reader.gets(): + return True + try: + async with async_timeout(0): + return await self.read_from_socket() + except asyncio.TimeoutError: + return False + + async def read_from_socket(self): + buffer = await self._stream.read(self._read_size) + if not buffer or not isinstance(buffer, bytes): + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None + self._reader.feed(buffer) + # data was read from the socket and added to the buffer. + # return True to indicate that data was read. + return True + + async def read_response( + self, disable_decoding: bool = False + ) -> Union[EncodableT, List[EncodableT]]: + # If `on_disconnect()` has been called, prohibit any more reads + # even if they could happen because data might be present. 
+ # We still allow reads in progress to finish + if not self._connected: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None + + response = self._reader.gets() + while response is False: + await self.read_from_socket() + response = self._reader.gets() + + # if the response is a ConnectionError or the response is a list and + # the first item is a ConnectionError, raise it as something bad + # happened + if isinstance(response, ConnectionError): + raise response + elif ( + isinstance(response, list) + and response + and isinstance(response[0], ConnectionError) + ): + raise response[0] + return response diff --git a/redis/_parsers/resp2.py b/redis/_parsers/resp2.py new file mode 100644 index 0000000000..d5adc1a898 --- /dev/null +++ b/redis/_parsers/resp2.py @@ -0,0 +1,132 @@ +from typing import Any, Union + +from ..exceptions import ConnectionError, InvalidResponse, ResponseError +from ..typing import EncodableT +from .base import _AsyncRESPBase, _RESPBase +from .socket import SERVER_CLOSED_CONNECTION_ERROR + + +class _RESP2Parser(_RESPBase): + """RESP2 protocol implementation""" + + def read_response(self, disable_decoding=False): + pos = self._buffer.get_pos() if self._buffer else None + try: + result = self._read_response(disable_decoding=disable_decoding) + except BaseException: + if self._buffer: + self._buffer.rewind(pos) + raise + else: + self._buffer.purge() + return result + + def _read_response(self, disable_decoding=False): + raw = self._buffer.readline() + if not raw: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + + byte, response = raw[:1], raw[1:] + + # server returned an error + if byte == b"-": + response = response.decode("utf-8", errors="replace") + error = self.parse_error(response) + # if the error is a ConnectionError, raise immediately so the user + # is notified + if isinstance(error, ConnectionError): + raise error + # otherwise, we're dealing with a ResponseError that might belong + # inside a pipeline response. the connection's read_response() + # and/or the pipeline's execute() will raise this error if + # necessary, so just return the exception instance here. 
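# Illustrative sketch (not part of this patch): given
#   pipe = r.pipeline()
#   pipe.set("a", "x")
#   pipe.incr("a")
#   results = pipe.execute(raise_on_error=False)
# the failing INCR surfaces as a ResponseError instance inside `results`
# instead of aborting the read of the remaining replies.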
+ return error + # single value + elif byte == b"+": + pass + # int value + elif byte == b":": + return int(response) + # bulk response + elif byte == b"$" and response == b"-1": + return None + elif byte == b"$": + response = self._buffer.read(int(response)) + # multi-bulk response + elif byte == b"*" and response == b"-1": + return None + elif byte == b"*": + response = [ + self._read_response(disable_decoding=disable_decoding) + for i in range(int(response)) + ] + else: + raise InvalidResponse(f"Protocol Error: {raw!r}") + + if disable_decoding is False: + response = self.encoder.decode(response) + return response + + +class _AsyncRESP2Parser(_AsyncRESPBase): + """Async class for the RESP2 protocol""" + + async def read_response(self, disable_decoding: bool = False): + if not self._connected: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + if self._chunks: + # augment parsing buffer with previously read data + self._buffer += b"".join(self._chunks) + self._chunks.clear() + self._pos = 0 + response = await self._read_response(disable_decoding=disable_decoding) + # Successfully parsing a response allows us to clear our parsing buffer + self._clear() + return response + + async def _read_response( + self, disable_decoding: bool = False + ) -> Union[EncodableT, ResponseError, None]: + raw = await self._readline() + response: Any + byte, response = raw[:1], raw[1:] + + # server returned an error + if byte == b"-": + response = response.decode("utf-8", errors="replace") + error = self.parse_error(response) + # if the error is a ConnectionError, raise immediately so the user + # is notified + if isinstance(error, ConnectionError): + self._clear() # Successful parse + raise error + # otherwise, we're dealing with a ResponseError that might belong + # inside a pipeline response. the connection's read_response() + # and/or the pipeline's execute() will raise this error if + # necessary, so just return the exception instance here. 
+ return error + # single value + elif byte == b"+": + pass + # int value + elif byte == b":": + return int(response) + # bulk response + elif byte == b"$" and response == b"-1": + return None + elif byte == b"$": + response = await self._read(int(response)) + # multi-bulk response + elif byte == b"*" and response == b"-1": + return None + elif byte == b"*": + response = [ + (await self._read_response(disable_decoding)) + for _ in range(int(response)) # noqa + ] + else: + raise InvalidResponse(f"Protocol Error: {raw!r}") + + if disable_decoding is False: + response = self.encoder.decode(response) + return response diff --git a/redis/_parsers/resp3.py b/redis/_parsers/resp3.py new file mode 100644 index 0000000000..1275686710 --- /dev/null +++ b/redis/_parsers/resp3.py @@ -0,0 +1,261 @@ +from logging import getLogger +from typing import Any, Union + +from ..exceptions import ConnectionError, InvalidResponse, ResponseError +from ..typing import EncodableT +from .base import _AsyncRESPBase, _RESPBase +from .socket import SERVER_CLOSED_CONNECTION_ERROR + + +class _RESP3Parser(_RESPBase): + """RESP3 protocol implementation""" + + def __init__(self, socket_read_size): + super().__init__(socket_read_size) + self.push_handler_func = self.handle_push_response + + def handle_push_response(self, response): + logger = getLogger("push_response") + logger.info("Push response: " + str(response)) + return response + + def read_response(self, disable_decoding=False, push_request=False): + pos = self._buffer.get_pos() if self._buffer else None + try: + result = self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + except BaseException: + if self._buffer: + self._buffer.rewind(pos) + raise + else: + self._buffer.purge() + return result + + def _read_response(self, disable_decoding=False, push_request=False): + raw = self._buffer.readline() + if not raw: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + + byte, response = raw[:1], raw[1:] + + # server returned an error + if byte in (b"-", b"!"): + if byte == b"!": + response = self._buffer.read(int(response)) + response = response.decode("utf-8", errors="replace") + error = self.parse_error(response) + # if the error is a ConnectionError, raise immediately so the user + # is notified + if isinstance(error, ConnectionError): + raise error + # otherwise, we're dealing with a ResponseError that might belong + # inside a pipeline response. the connection's read_response() + # and/or the pipeline's execute() will raise this error if + # necessary, so just return the exception instance here. 
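# RESP3 wire sketch (illustrative): a simple error arrives as
#   -ERR unknown command\r\n
# while a blob error ("!") carries a byte-length prefix, e.g.
#   !21\r\nSYNTAX invalid syntax\r\n
# both forms are routed through parse_error() above.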
+ return error + # single value + elif byte == b"+": + pass + # null value + elif byte == b"_": + return None + # int and big int values + elif byte in (b":", b"("): + return int(response) + # double value + elif byte == b",": + return float(response) + # bool value + elif byte == b"#": + return response == b"t" + # bulk response + elif byte == b"$": + response = self._buffer.read(int(response)) + # verbatim string response + elif byte == b"=": + response = self._buffer.read(int(response))[4:] + # array response + elif byte == b"*": + response = [ + self._read_response(disable_decoding=disable_decoding) + for _ in range(int(response)) + ] + # set response + elif byte == b"~": + # redis can return unhashable types (like dict) in a set, + # so we need to first convert to a list, and then try to convert it to a set + response = [ + self._read_response(disable_decoding=disable_decoding) + for _ in range(int(response)) + ] + try: + response = set(response) + except TypeError: + pass + # map response + elif byte == b"%": + # we use this approach and not dict comprehension here + # because this dict comprehension fails in python 3.7 + resp_dict = {} + for _ in range(int(response)): + key = self._read_response(disable_decoding=disable_decoding) + resp_dict[key] = self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + response = resp_dict + # push response + elif byte == b">": + response = [ + self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + for _ in range(int(response)) + ] + res = self.push_handler_func(response) + if not push_request: + return self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + else: + return res + else: + raise InvalidResponse(f"Protocol Error: {raw!r}") + + if isinstance(response, bytes) and disable_decoding is False: + response = self.encoder.decode(response) + return response + + def set_push_handler(self, push_handler_func): + self.push_handler_func = push_handler_func + + +class _AsyncRESP3Parser(_AsyncRESPBase): + def __init__(self, socket_read_size): + super().__init__(socket_read_size) + self.push_handler_func = self.handle_push_response + + def handle_push_response(self, response): + logger = getLogger("push_response") + logger.info("Push response: " + str(response)) + return response + + async def read_response( + self, disable_decoding: bool = False, push_request: bool = False + ): + if self._chunks: + # augment parsing buffer with previously read data + self._buffer += b"".join(self._chunks) + self._chunks.clear() + self._pos = 0 + response = await self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + # Successfully parsing a response allows us to clear our parsing buffer + self._clear() + return response + + async def _read_response( + self, disable_decoding: bool = False, push_request: bool = False + ) -> Union[EncodableT, ResponseError, None]: + if not self._stream or not self.encoder: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + raw = await self._readline() + response: Any + byte, response = raw[:1], raw[1:] + + # if byte not in (b"-", b"+", b":", b"$", b"*"): + # raise InvalidResponse(f"Protocol Error: {raw!r}") + + # server returned an error + if byte in (b"-", b"!"): + if byte == b"!": + response = await self._read(int(response)) + response = response.decode("utf-8", errors="replace") + error = self.parse_error(response) + # if the error is a ConnectionError, raise immediately so the user + # is 
notified + if isinstance(error, ConnectionError): + self._clear() # Successful parse + raise error + # otherwise, we're dealing with a ResponseError that might belong + # inside a pipeline response. the connection's read_response() + # and/or the pipeline's execute() will raise this error if + # necessary, so just return the exception instance here. + return error + # single value + elif byte == b"+": + pass + # null value + elif byte == b"_": + return None + # int and big int values + elif byte in (b":", b"("): + return int(response) + # double value + elif byte == b",": + return float(response) + # bool value + elif byte == b"#": + return response == b"t" + # bulk response + elif byte == b"$": + response = await self._read(int(response)) + # verbatim string response + elif byte == b"=": + response = (await self._read(int(response)))[4:] + # array response + elif byte == b"*": + response = [ + (await self._read_response(disable_decoding=disable_decoding)) + for _ in range(int(response)) + ] + # set response + elif byte == b"~": + # redis can return unhashable types (like dict) in a set, + # so we need to first convert to a list, and then try to convert it to a set + response = [ + (await self._read_response(disable_decoding=disable_decoding)) + for _ in range(int(response)) + ] + try: + response = set(response) + except TypeError: + pass + # map response + elif byte == b"%": + response = { + (await self._read_response(disable_decoding=disable_decoding)): ( + await self._read_response(disable_decoding=disable_decoding) + ) + for _ in range(int(response)) + } + # push response + elif byte == b">": + response = [ + ( + await self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + ) + for _ in range(int(response)) + ] + res = self.push_handler_func(response) + if not push_request: + return await ( + self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + ) + else: + return res + else: + raise InvalidResponse(f"Protocol Error: {raw!r}") + + if isinstance(response, bytes) and disable_decoding is False: + response = self.encoder.decode(response) + return response + + def set_push_handler(self, push_handler_func): + self.push_handler_func = push_handler_func diff --git a/redis/_parsers/socket.py b/redis/_parsers/socket.py new file mode 100644 index 0000000000..8147243bba --- /dev/null +++ b/redis/_parsers/socket.py @@ -0,0 +1,162 @@ +import errno +import io +import socket +from io import SEEK_END +from typing import Optional, Union + +from ..exceptions import ConnectionError, TimeoutError +from ..utils import SSL_AVAILABLE + +NONBLOCKING_EXCEPTION_ERROR_NUMBERS = {BlockingIOError: errno.EWOULDBLOCK} + +if SSL_AVAILABLE: + import ssl + + if hasattr(ssl, "SSLWantReadError"): + NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantReadError] = 2 + NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantWriteError] = 2 + else: + NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLError] = 2 + +NONBLOCKING_EXCEPTIONS = tuple(NONBLOCKING_EXCEPTION_ERROR_NUMBERS.keys()) + +SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server." 
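# A minimal usage sketch (the helper below is hypothetical; the parsers
# inline this check): the mapping above lets callers distinguish "no data
# yet" from a dead peer when a nonblocking recv() fails.
#
#     def would_block(ex: OSError) -> bool:
#         # hypothetical helper, assuming the mapping defined above
#         return ex.errno == NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(type(ex), -1)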
+SENTINEL = object() + +SYM_CRLF = b"\r\n" + + +class SocketBuffer: + def __init__( + self, socket: socket.socket, socket_read_size: int, socket_timeout: float + ): + self._sock = socket + self.socket_read_size = socket_read_size + self.socket_timeout = socket_timeout + self._buffer = io.BytesIO() + + def unread_bytes(self) -> int: + """ + Remaining unread length of buffer + """ + pos = self._buffer.tell() + end = self._buffer.seek(0, SEEK_END) + self._buffer.seek(pos) + return end - pos + + def _read_from_socket( + self, + length: Optional[int] = None, + timeout: Union[float, object] = SENTINEL, + raise_on_timeout: Optional[bool] = True, + ) -> bool: + sock = self._sock + socket_read_size = self.socket_read_size + marker = 0 + custom_timeout = timeout is not SENTINEL + + buf = self._buffer + current_pos = buf.tell() + buf.seek(0, SEEK_END) + if custom_timeout: + sock.settimeout(timeout) + try: + while True: + data = self._sock.recv(socket_read_size) + # an empty string indicates the server shutdown the socket + if isinstance(data, bytes) and len(data) == 0: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + buf.write(data) + data_length = len(data) + marker += data_length + + if length is not None and length > marker: + continue + return True + except socket.timeout: + if raise_on_timeout: + raise TimeoutError("Timeout reading from socket") + return False + except NONBLOCKING_EXCEPTIONS as ex: + # if we're in nonblocking mode and the recv raises a + # blocking error, simply return False indicating that + # there's no data to be read. otherwise raise the + # original exception. + allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1) + if not raise_on_timeout and ex.errno == allowed: + return False + raise ConnectionError(f"Error while reading from socket: {ex.args}") + finally: + buf.seek(current_pos) + if custom_timeout: + sock.settimeout(self.socket_timeout) + + def can_read(self, timeout: float) -> bool: + return bool(self.unread_bytes()) or self._read_from_socket( + timeout=timeout, raise_on_timeout=False + ) + + def read(self, length: int) -> bytes: + length = length + 2 # make sure to read the \r\n terminator + # BufferIO will return less than requested if buffer is short + data = self._buffer.read(length) + missing = length - len(data) + if missing: + # fill up the buffer and read the remainder + self._read_from_socket(missing) + data += self._buffer.read(missing) + return data[:-2] + + def readline(self) -> bytes: + buf = self._buffer + data = buf.readline() + while not data.endswith(SYM_CRLF): + # there's more data in the socket that we need + self._read_from_socket() + data += buf.readline() + + return data[:-2] + + def get_pos(self) -> int: + """ + Get current read position + """ + return self._buffer.tell() + + def rewind(self, pos: int) -> None: + """ + Rewind the buffer to a specific position, to re-start reading + """ + self._buffer.seek(pos) + + def purge(self) -> None: + """ + After a successful read, purge the read part of buffer + """ + unread = self.unread_bytes() + + # Only if we have read all of the buffer do we truncate, to + # reduce the amount of memory thrashing. This heuristic + # can be changed or removed later. 
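# Illustration: if readline() consumed "+OK\r\n" but ten bytes of a
# pipelined reply remain unread, unread_bytes() > 0 and purge() returns
# without truncating; only a fully drained buffer is reset.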
+ if unread > 0: + return + + if unread > 0: + # move unread data to the front + view = self._buffer.getbuffer() + view[:unread] = view[-unread:] + self._buffer.truncate(unread) + self._buffer.seek(0) + + def close(self) -> None: + try: + self._buffer.close() + except Exception: + # issue #633 suggests the purge/close somehow raised a + # BadFileDescriptor error. Perhaps the client ran out of + # memory or something else? It's probably OK to ignore + # any error being raised from purge/close since we're + # removing the reference to the instance below. + pass + self._buffer = None + self._sock = None diff --git a/redis/asyncio/__init__.py b/redis/asyncio/__init__.py index bf90dde555..3545ab44c2 100644 --- a/redis/asyncio/__init__.py +++ b/redis/asyncio/__init__.py @@ -7,7 +7,6 @@ SSLConnection, UnixDomainSocketConnection, ) -from redis.asyncio.parser import CommandsParser from redis.asyncio.sentinel import ( Sentinel, SentinelConnectionPool, @@ -24,6 +23,7 @@ ConnectionError, DataError, InvalidResponse, + OutOfMemoryError, PubSubError, ReadOnlyError, RedisError, @@ -38,7 +38,6 @@ "BlockingConnectionPool", "BusyLoadingError", "ChildDeadlockedError", - "CommandsParser", "Connection", "ConnectionError", "ConnectionPool", @@ -47,6 +46,7 @@ "default_backoff", "InvalidResponse", "PubSubError", + "OutOfMemoryError", "ReadOnlyError", "Redis", "RedisCluster", diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index e56fd022fc..f0c1ab7536 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -24,6 +24,12 @@ cast, ) +from redis._parsers.helpers import ( + _RedisCallbacks, + _RedisCallbacksRESP2, + _RedisCallbacksRESP3, + bool_ok, +) from redis.asyncio.connection import ( Connection, ConnectionPool, @@ -37,7 +43,6 @@ NEVER_DECODE, AbstractRedis, CaseInsensitiveDict, - bool_ok, ) from redis.commands import ( AsyncCoreCommands, @@ -57,7 +62,13 @@ WatchError, ) from redis.typing import ChannelT, EncodableT, KeyT -from redis.utils import safe_str, str_if_bytes +from redis.utils import ( + HIREDIS_AVAILABLE, + _set_info_logger, + get_lib_version, + safe_str, + str_if_bytes, +) PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]] _KeyT = TypeVar("_KeyT", bound=KeyT) @@ -99,7 +110,13 @@ class Redis( response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT] @classmethod - def from_url(cls, url: str, **kwargs): + def from_url( + cls, + url: str, + single_connection_client: bool = False, + auto_close_connection_pool: bool = True, + **kwargs, + ): """ Return a Redis client object configured from the given URL @@ -123,10 +140,13 @@ def from_url(cls, url: str, **kwargs): There are several ways to specify a database number. The first value found will be used: - 1. A ``db`` querystring option, e.g. redis://localhost?db=0 - 2. If using the redis:// or rediss:// schemes, the path argument + + 1. A ``db`` querystring option, e.g. redis://localhost?db=0 + + 2. If using the redis:// or rediss:// schemes, the path argument of the url, e.g. redis://localhost/0 - 3. A ``db`` keyword argument to this function. + + 3. A ``db`` keyword argument to this function. If none of these options are specified, the default db=0 is used. @@ -140,7 +160,12 @@ class initializer. 
In the case of conflicting arguments, querystring """ connection_pool = ConnectionPool.from_url(url, **kwargs) - return cls(connection_pool=connection_pool) + redis = cls( + connection_pool=connection_pool, + single_connection_client=single_connection_client, + ) + redis.auto_close_connection_pool = auto_close_connection_pool + return redis def __init__( self, @@ -171,11 +196,14 @@ def __init__( single_connection_client: bool = False, health_check_interval: int = 0, client_name: Optional[str] = None, + lib_name: Optional[str] = "redis-py", + lib_version: Optional[str] = get_lib_version(), username: Optional[str] = None, retry: Optional[Retry] = None, auto_close_connection_pool: bool = True, redis_connect_func=None, credential_provider: Optional[CredentialProvider] = None, + protocol: Optional[int] = 2, ): """ Initialize a new Redis client. @@ -212,7 +240,10 @@ def __init__( "max_connections": max_connections, "health_check_interval": health_check_interval, "client_name": client_name, + "lib_name": lib_name, + "lib_version": lib_version, "redis_connect_func": redis_connect_func, + "protocol": protocol, } # based on input, setup appropriate connection args if unix_socket_path is not None: @@ -251,7 +282,17 @@ def __init__( self.single_connection_client = single_connection_client self.connection: Optional[Connection] = None - self.response_callbacks = CaseInsensitiveDict(self.__class__.RESPONSE_CALLBACKS) + self.response_callbacks = CaseInsensitiveDict(_RedisCallbacks) + + if self.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: + self.response_callbacks.update(_RedisCallbacksRESP3) + else: + self.response_callbacks.update(_RedisCallbacksRESP2) + + # If using a single connection client, we need to lock creation-of and use-of + # the client in order to avoid race conditions such as using asyncio.gather + # on a set of redis commands + self._single_conn_lock = asyncio.Lock() def __repr__(self): return f"{self.__class__.__name__}<{self.connection_pool!r}>" @@ -260,8 +301,10 @@ def __await__(self): return self.initialize().__await__() async def initialize(self: _RedisT) -> _RedisT: - if self.single_connection_client and self.connection is None: - self.connection = await self.connection_pool.get_connection("_") + if self.single_connection_client: + async with self._single_conn_lock: + if self.connection is None: + self.connection = await self.connection_pool.get_connection("_") return self def set_response_callback(self, command: str, callback: ResponseCallbackT): @@ -448,7 +491,7 @@ async def __aexit__(self, exc_type, exc_value, traceback): _DEL_MESSAGE = "Unclosed Redis client" def __del__(self, _warnings: Any = warnings) -> None: - if self.connection is not None: + if hasattr(self, "connection") and (self.connection is not None): _warnings.warn( f"Unclosed client session {self!r}", ResourceWarning, source=self ) @@ -501,6 +544,8 @@ async def execute_command(self, *args, **options): command_name = args[0] conn = self.connection or await pool.get_connection(command_name, **options) + if self.single_connection_client: + await self._single_conn_lock.acquire() try: return await conn.retry.call_with_retry( lambda: self._send_command_parse_response( @@ -509,6 +554,8 @@ async def execute_command(self, *args, **options): lambda error: self._disconnect_raise(conn, error), ) finally: + if self.single_connection_client: + self._single_conn_lock.release() if not self.connection: await pool.release(conn) @@ -642,6 +689,7 @@ def __init__( shard_hint: Optional[str] = None, 
ignore_subscribe_messages: bool = False, encoder=None, + push_handler_func: Optional[Callable] = None, ): self.connection_pool = connection_pool self.shard_hint = shard_hint @@ -650,18 +698,21 @@ def __init__( # we need to know the encoding options for this connection in order # to lookup channel and pattern names for callback handlers. self.encoder = encoder + self.push_handler_func = push_handler_func if self.encoder is None: self.encoder = self.connection_pool.get_encoder() if self.encoder.decode_responses: - self.health_check_response: Iterable[Union[str, bytes]] = [ - "pong", + self.health_check_response = [ + ["pong", self.HEALTH_CHECK_MESSAGE], self.HEALTH_CHECK_MESSAGE, ] else: self.health_check_response = [ - b"pong", + [b"pong", self.encoder.encode(self.HEALTH_CHECK_MESSAGE)], self.encoder.encode(self.HEALTH_CHECK_MESSAGE), ] + if self.push_handler_func is None: + _set_info_logger() self.channels = {} self.pending_unsubscribe_channels = set() self.patterns = {} @@ -691,6 +742,11 @@ async def reset(self): self.pending_unsubscribe_patterns = set() def close(self) -> Awaitable[NoReturn]: + # In case a connection property does not yet exist + # (due to a crash earlier in the Redis() constructor), return + # immediately as there is nothing to clean-up. + if not hasattr(self, "connection"): + return return self.reset() async def on_connect(self, connection: Connection): @@ -741,6 +797,8 @@ async def connect(self): self.connection.register_connect_callback(self.on_connect) else: await self.connection.connect() + if self.push_handler_func is not None and not HIREDIS_AVAILABLE: + self.connection._parser.set_push_handler(self.push_handler_func) async def _disconnect_raise_connect(self, conn, error): """ @@ -781,9 +839,15 @@ async def parse_response(self, block: bool = True, timeout: float = 0): await conn.connect() read_timeout = None if block else timeout - response = await self._execute(conn, conn.read_response, timeout=read_timeout) + response = await self._execute( + conn, + conn.read_response, + timeout=read_timeout, + disconnect_on_error=False, + push_request=True, + ) - if conn.health_check_interval and response == self.health_check_response: + if conn.health_check_interval and response in self.health_check_response: # ignore the health check message as user might not expect it return None return response @@ -911,8 +975,8 @@ def ping(self, message=None) -> Awaitable: """ Ping the Redis server """ - message = "" if message is None else message - return self.execute_command("PING", message) + args = ["PING", message] if message is not None else ["PING"] + return self.execute_command(*args) async def handle_message(self, response, ignore_subscribe_messages=False): """ @@ -920,6 +984,10 @@ async def handle_message(self, response, ignore_subscribe_messages=False): with a message handler, the handler is invoked instead of a parsed message being returned. 
""" + if response is None: + return None + if isinstance(response, bytes): + response = [b"pong", response] if response != b"PONG" else [b"pong", b""] message_type = str_if_bytes(response[0]) if message_type == "pmessage": message = { diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 5a2dffdd1d..84407116ed 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -5,27 +5,28 @@ import warnings from typing import ( Any, + Callable, Deque, Dict, Generator, List, Mapping, Optional, + Tuple, Type, TypeVar, Union, ) -from redis.asyncio.client import ResponseCallbackT -from redis.asyncio.connection import ( - Connection, - DefaultParser, - Encoder, - SSLConnection, - parse_url, +from redis._parsers import AsyncCommandsParser, Encoder +from redis._parsers.helpers import ( + _RedisCallbacks, + _RedisCallbacksRESP2, + _RedisCallbacksRESP3, ) +from redis.asyncio.client import ResponseCallbackT +from redis.asyncio.connection import Connection, DefaultParser, SSLConnection, parse_url from redis.asyncio.lock import Lock -from redis.asyncio.parser import CommandsParser from redis.asyncio.retry import Retry from redis.backoff import default_backoff from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis @@ -61,7 +62,7 @@ TryAgainError, ) from redis.typing import AnyKeyT, EncodableT, KeyT -from redis.utils import dict_merge, safe_str, str_if_bytes +from redis.utils import dict_merge, get_lib_version, safe_str, str_if_bytes TargetNodesT = TypeVar( "TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"] @@ -147,6 +148,12 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand maximum number of connections are already created, a :class:`~.MaxConnectionsError` is raised. This error may be retried as defined by :attr:`connection_error_retry_attempts` + :param address_remap: + | An optional callable which, when provided with an internal network + address of a node, e.g. a `(host, port)` tuple, will return the address + where the node is reachable. This can be used to map the addresses at + which the nodes _think_ they are, to addresses at which a client may + reach them, such as when they sit behind a proxy. 
| Rest of the arguments will be passed to the :class:`~redis.asyncio.connection.Connection` instances when created @@ -230,6 +237,8 @@ def __init__( username: Optional[str] = None, password: Optional[str] = None, client_name: Optional[str] = None, + lib_name: Optional[str] = "redis-py", + lib_version: Optional[str] = get_lib_version(), # Encoding related kwargs encoding: str = "utf-8", encoding_errors: str = "strict", @@ -241,7 +250,7 @@ def __init__( socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None, socket_timeout: Optional[float] = None, retry: Optional["Retry"] = None, - retry_on_error: Optional[List[Exception]] = None, + retry_on_error: Optional[List[Type[Exception]]] = None, # SSL related kwargs ssl: bool = False, ssl_ca_certs: Optional[str] = None, @@ -250,6 +259,8 @@ def __init__( ssl_certfile: Optional[str] = None, ssl_check_hostname: bool = False, ssl_keyfile: Optional[str] = None, + protocol: Optional[int] = 2, + address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, ) -> None: if db: raise RedisClusterException( @@ -279,6 +290,8 @@ def __init__( "username": username, "password": password, "client_name": client_name, + "lib_name": lib_name, + "lib_version": lib_version, # Encoding related kwargs "encoding": encoding, "encoding_errors": encoding_errors, @@ -290,6 +303,7 @@ def __init__( "socket_keepalive_options": socket_keepalive_options, "socket_timeout": socket_timeout, "retry": retry, + "protocol": protocol, } if ssl: @@ -322,7 +336,11 @@ def __init__( self.retry.update_supported_errors(retry_on_error) kwargs.update({"retry": self.retry}) - kwargs["response_callbacks"] = self.__class__.RESPONSE_CALLBACKS.copy() + kwargs["response_callbacks"] = _RedisCallbacks.copy() + if kwargs.get("protocol") in ["3", 3]: + kwargs["response_callbacks"].update(_RedisCallbacksRESP3) + else: + kwargs["response_callbacks"].update(_RedisCallbacksRESP2) self.connection_kwargs = kwargs if startup_nodes: @@ -337,14 +355,19 @@ def __init__( if host and port: startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs)) - self.nodes_manager = NodesManager(startup_nodes, require_full_coverage, kwargs) + self.nodes_manager = NodesManager( + startup_nodes, + require_full_coverage, + kwargs, + address_remap=address_remap, + ) self.encoder = Encoder(encoding, encoding_errors, decode_responses) self.read_from_replicas = read_from_replicas self.reinitialize_steps = reinitialize_steps self.cluster_error_retry_attempts = cluster_error_retry_attempts self.connection_error_retry_attempts = connection_error_retry_attempts self.reinitialize_counter = 0 - self.commands_parser = CommandsParser() + self.commands_parser = AsyncCommandsParser() self.node_flags = self.__class__.NODE_FLAGS.copy() self.command_flags = self.__class__.COMMAND_FLAGS.copy() self.response_callbacks = kwargs["response_callbacks"] @@ -1044,6 +1067,7 @@ class NodesManager: "require_full_coverage", "slots_cache", "startup_nodes", + "address_remap", ) def __init__( @@ -1051,10 +1075,12 @@ def __init__( startup_nodes: List["ClusterNode"], require_full_coverage: bool, connection_kwargs: Dict[str, Any], + address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, ) -> None: self.startup_nodes = {node.name: node for node in startup_nodes} self.require_full_coverage = require_full_coverage self.connection_kwargs = connection_kwargs + self.address_remap = address_remap self.default_node: "ClusterNode" = None self.nodes_cache: Dict[str, "ClusterNode"] = {} @@ -1213,6 +1239,7 @@ async def 
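For example, when nodes advertise container-internal addresses but are reachable through forwarded local ports, a small callable can translate them (hosts and ports below are hypothetical):

    import asyncio
    import redis.asyncio as redis

    PORT_MAP = {6379: 16379, 6380: 16380, 6381: 16381}  # internal -> exposed

    def remap(address):
        host, port = address  # the (host, port) tuple the node advertises
        return "127.0.0.1", PORT_MAP.get(port, port)

    async def main():
        rc = redis.RedisCluster(host="127.0.0.1", port=16379, address_remap=remap)
        await rc.set("k", "v")
        await rc.close()

    asyncio.run(main())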
initialize(self) -> None: if host == "": host = startup_node.host port = int(primary_node[1]) + host, port = self.remap_host_port(host, port) target_node = tmp_nodes_cache.get(get_node_name(host, port)) if not target_node: @@ -1231,6 +1258,7 @@ async def initialize(self) -> None: for replica_node in replica_nodes: host = replica_node[0] port = replica_node[1] + host, port = self.remap_host_port(host, port) target_replica_node = tmp_nodes_cache.get( get_node_name(host, port) @@ -1304,6 +1332,16 @@ async def close(self, attr: str = "nodes_cache") -> None: ) ) + def remap_host_port(self, host: str, port: int) -> Tuple[str, int]: + """ + Remap the host and port returned from the cluster to a different + internal value. Useful if the client is not connecting directly + to the cluster. + """ + if self.address_remap: + return self.address_remap((host, port)) + return host, port + class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands): """ diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 2c75d4fcf1..c1cc1d310c 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -5,8 +5,10 @@ import os import socket import ssl +import sys import threading import weakref +from abc import abstractmethod from itertools import chain from types import MappingProxyType from typing import ( @@ -24,35 +26,38 @@ ) from urllib.parse import ParseResult, parse_qs, unquote, urlparse -import async_timeout +# the functionality is available in 3.11.x but has a major issue before +# 3.11.3. See https://github.com/redis/redis-py/issues/2633 +if sys.version_info >= (3, 11, 3): + from asyncio import timeout as async_timeout +else: + from async_timeout import timeout as async_timeout from redis.asyncio.retry import Retry from redis.backoff import NoBackoff from redis.compat import Protocol, TypedDict +from redis.connection import DEFAULT_RESP_VERSION from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider from redis.exceptions import ( AuthenticationError, AuthenticationWrongNumberOfArgsError, - BusyLoadingError, ChildDeadlockedError, ConnectionError, DataError, - ExecAbortError, - InvalidResponse, - ModuleError, - NoPermissionError, - NoScriptError, - ReadOnlyError, RedisError, ResponseError, TimeoutError, ) -from redis.typing import EncodableT, EncodedT -from redis.utils import HIREDIS_AVAILABLE, str_if_bytes - -hiredis = None -if HIREDIS_AVAILABLE: - import hiredis +from redis.typing import EncodableT +from redis.utils import HIREDIS_AVAILABLE, get_lib_version, str_if_bytes + +from .._parsers import ( + BaseParser, + Encoder, + _AsyncHiredisParser, + _AsyncRESP2Parser, + _AsyncRESP3Parser, +) SYM_STAR = b"*" SYM_DOLLAR = b"$" @@ -60,403 +65,48 @@ SYM_LF = b"\n" SYM_EMPTY = b"" -SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server." - class _Sentinel(enum.Enum): sentinel = object() SENTINEL = _Sentinel.sentinel -MODULE_LOAD_ERROR = "Error loading the extension. Please check the server logs." -NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name" -MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not possible." -MODULE_EXPORTS_DATA_TYPES_ERROR = ( - "Error unloading module: the module " - "exports one or more module-side data " - "types, can't unload" -) -# user send an AUTH cmd to a server without authorization configured -NO_AUTH_SET_ERROR = { - # Redis >= 6.0 - "AUTH called without any password " - "configured for the default user. 
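Both `asyncio.timeout` (3.11.3+) and `async_timeout.timeout` expose the same async context manager, so call sites stay identical whichever import wins. The pattern in isolation:

    import asyncio
    import sys

    if sys.version_info >= (3, 11, 3):
        from asyncio import timeout as async_timeout
    else:
        from async_timeout import timeout as async_timeout

    async def read_line(reader: asyncio.StreamReader) -> bytes:
        # Raises TimeoutError if the peer stalls for five seconds.
        async with async_timeout(5.0):
            return await reader.readline()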
Are you sure " - "your configuration is correct?": AuthenticationError, - # Redis < 6.0 - "Client sent AUTH, but no password is set": AuthenticationError, -} - - -class _HiredisReaderArgs(TypedDict, total=False): - protocolError: Callable[[str], Exception] - replyError: Callable[[str], Exception] - encoding: Optional[str] - errors: Optional[str] - - -class Encoder: - """Encode strings to bytes-like and decode bytes-like to strings""" - - __slots__ = "encoding", "encoding_errors", "decode_responses" - - def __init__(self, encoding: str, encoding_errors: str, decode_responses: bool): - self.encoding = encoding - self.encoding_errors = encoding_errors - self.decode_responses = decode_responses - - def encode(self, value: EncodableT) -> EncodedT: - """Return a bytestring or bytes-like representation of the value""" - if isinstance(value, str): - return value.encode(self.encoding, self.encoding_errors) - if isinstance(value, (bytes, memoryview)): - return value - if isinstance(value, (int, float)): - if isinstance(value, bool): - # special case bool since it is a subclass of int - raise DataError( - "Invalid input of type: 'bool'. " - "Convert to a bytes, string, int or float first." - ) - return repr(value).encode() - # a value we don't know how to deal with. throw an error - typename = value.__class__.__name__ - raise DataError( - f"Invalid input of type: {typename!r}. " - "Convert to a bytes, string, int or float first." - ) - - def decode(self, value: EncodableT, force=False) -> EncodableT: - """Return a unicode string from the bytes-like representation""" - if self.decode_responses or force: - if isinstance(value, bytes): - return value.decode(self.encoding, self.encoding_errors) - if isinstance(value, memoryview): - return value.tobytes().decode(self.encoding, self.encoding_errors) - return value - - -ExceptionMappingT = Mapping[str, Union[Type[Exception], Mapping[str, Type[Exception]]]] - - -class BaseParser: - """Plain Python parsing class""" - - __slots__ = "_stream", "_read_size" - - EXCEPTION_CLASSES: ExceptionMappingT = { - "ERR": { - "max number of clients reached": ConnectionError, - "Client sent AUTH, but no password is set": AuthenticationError, - "invalid password": AuthenticationError, - # some Redis server versions report invalid command syntax - # in lowercase - "wrong number of arguments for 'auth' command": AuthenticationWrongNumberOfArgsError, # noqa: E501 - # some Redis server versions report invalid command syntax - # in uppercase - "wrong number of arguments for 'AUTH' command": AuthenticationWrongNumberOfArgsError, # noqa: E501 - MODULE_LOAD_ERROR: ModuleError, - MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError, - NO_SUCH_MODULE_ERROR: ModuleError, - MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError, - **NO_AUTH_SET_ERROR, - }, - "WRONGPASS": AuthenticationError, - "EXECABORT": ExecAbortError, - "LOADING": BusyLoadingError, - "NOSCRIPT": NoScriptError, - "READONLY": ReadOnlyError, - "NOAUTH": AuthenticationError, - "NOPERM": NoPermissionError, - } - - def __init__(self, socket_read_size: int): - self._stream: Optional[asyncio.StreamReader] = None - self._read_size = socket_read_size - - def __del__(self): - try: - self.on_disconnect() - except Exception: - pass - - def parse_error(self, response: str) -> ResponseError: - """Parse an error response""" - error_code = response.split(" ")[0] - if error_code in self.EXCEPTION_CLASSES: - response = response[len(error_code) + 1 :] - exception_class = self.EXCEPTION_CLASSES[error_code] - if isinstance(exception_class, dict): - 
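The `Encoder` deleted here is the same class now imported from `redis._parsers` (see the cluster.py import hunk above); its contract is unchanged. For reference:

    from redis._parsers import Encoder

    enc = Encoder(encoding="utf-8", encoding_errors="strict", decode_responses=True)
    enc.encode("héllo")          # b'h\xc3\xa9llo'
    enc.encode(42)               # b'42' (repr-based for int and float)
    enc.decode(b"h\xc3\xa9llo")  # 'héllo', since decode_responses=True
    try:
        enc.encode(True)         # bool is rejected explicitly
    except Exception as e:
        print(e)                 # DataError: Invalid input of type: 'bool'. ...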
exception_class = exception_class.get(response, ResponseError) - return exception_class(response) - return ResponseError(response) - - def on_disconnect(self): - raise NotImplementedError() - - def on_connect(self, connection: "Connection"): - raise NotImplementedError() - - async def can_read_destructive(self) -> bool: - raise NotImplementedError() - - async def read_response( - self, disable_decoding: bool = False - ) -> Union[EncodableT, ResponseError, None, List[EncodableT]]: - raise NotImplementedError() - - -class PythonParser(BaseParser): - """Plain Python parsing class""" - - __slots__ = BaseParser.__slots__ + ("encoder", "_buffer", "_pos", "_chunks") - - def __init__(self, socket_read_size: int): - super().__init__(socket_read_size) - self.encoder: Optional[Encoder] = None - self._buffer = b"" - self._chunks = [] - self._pos = 0 - - def _clear(self): - self._buffer = b"" - self._chunks.clear() - - def on_connect(self, connection: "Connection"): - """Called when the stream connects""" - self._stream = connection._reader - if self._stream is None: - raise RedisError("Buffer is closed.") - - self.encoder = connection.encoder - - def on_disconnect(self): - """Called when the stream disconnects""" - if self._stream is not None: - self._stream = None - self.encoder = None - self._clear() - - async def can_read_destructive(self) -> bool: - if self._buffer: - return True - if self._stream is None: - raise RedisError("Buffer is closed.") - try: - async with async_timeout.timeout(0): - return await self._stream.read(1) - except asyncio.TimeoutError: - return False - - async def read_response(self, disable_decoding: bool = False): - if self._chunks: - # augment parsing buffer with previously read data - self._buffer += b"".join(self._chunks) - self._chunks.clear() - self._pos = 0 - response = await self._read_response(disable_decoding=disable_decoding) - # Successfully parsing a response allows us to clear our parsing buffer - self._clear() - return response - - async def _read_response( - self, disable_decoding: bool = False - ) -> Union[EncodableT, ResponseError, None]: - if not self._stream or not self.encoder: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - raw = await self._readline() - response: Any - byte, response = raw[:1], raw[1:] - - if byte not in (b"-", b"+", b":", b"$", b"*"): - raise InvalidResponse(f"Protocol Error: {raw!r}") - - # server returned an error - if byte == b"-": - response = response.decode("utf-8", errors="replace") - error = self.parse_error(response) - # if the error is a ConnectionError, raise immediately so the user - # is notified - if isinstance(error, ConnectionError): - self._clear() # Successful parse - raise error - # otherwise, we're dealing with a ResponseError that might belong - # inside a pipeline response. the connection's read_response() - # and/or the pipeline's execute() will raise this error if - # necessary, so just return the exception instance here. 
- return error - # single value - elif byte == b"+": - pass - # int value - elif byte == b":": - response = int(response) - # bulk response - elif byte == b"$": - length = int(response) - if length == -1: - return None - response = await self._read(length) - # multi-bulk response - elif byte == b"*": - length = int(response) - if length == -1: - return None - response = [ - (await self._read_response(disable_decoding)) for _ in range(length) - ] - if isinstance(response, bytes) and disable_decoding is False: - response = self.encoder.decode(response) - return response - - async def _read(self, length: int) -> bytes: - """ - Read `length` bytes of data. These are assumed to be followed - by a '\r\n' terminator which is subsequently discarded. - """ - want = length + 2 - end = self._pos + want - if len(self._buffer) >= end: - result = self._buffer[self._pos : end - 2] - else: - tail = self._buffer[self._pos :] - try: - data = await self._stream.readexactly(want - len(tail)) - except asyncio.IncompleteReadError as error: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error - result = (tail + data)[:-2] - self._chunks.append(data) - self._pos += want - return result - - async def _readline(self) -> bytes: - """ - read an unknown number of bytes up to the next '\r\n' - line separator, which is discarded. - """ - found = self._buffer.find(b"\r\n", self._pos) - if found >= 0: - result = self._buffer[self._pos : found] - else: - tail = self._buffer[self._pos :] - data = await self._stream.readline() - if not data.endswith(b"\r\n"): - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - result = (tail + data)[:-2] - self._chunks.append(data) - self._pos += len(result) + 2 - return result - - -class HiredisParser(BaseParser): - """Parser class for connections using Hiredis""" - - __slots__ = BaseParser.__slots__ + ("_reader",) - - def __init__(self, socket_read_size: int): - if not HIREDIS_AVAILABLE: - raise RedisError("Hiredis is not available.") - super().__init__(socket_read_size=socket_read_size) - self._reader: Optional[hiredis.Reader] = None - - def on_connect(self, connection: "Connection"): - self._stream = connection._reader - kwargs: _HiredisReaderArgs = { - "protocolError": InvalidResponse, - "replyError": self.parse_error, - } - if connection.encoder.decode_responses: - kwargs["encoding"] = connection.encoder.encoding - kwargs["errors"] = connection.encoder.encoding_errors - - self._reader = hiredis.Reader(**kwargs) - - def on_disconnect(self): - self._stream = None - self._reader = None - - async def can_read_destructive(self): - if not self._stream or not self._reader: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - if self._reader.gets(): - return True - try: - async with async_timeout.timeout(0): - return await self.read_from_socket() - except asyncio.TimeoutError: - return False - async def read_from_socket(self): - buffer = await self._stream.read(self._read_size) - if not buffer or not isinstance(buffer, bytes): - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None - self._reader.feed(buffer) - # data was read from the socket and added to the buffer. - # return True to indicate that data was read. 
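The removed `PythonParser` (its successor is `_AsyncRESP2Parser` in `redis._parsers`) dispatches on the first byte of each reply line. A toy, self-contained sketch of that dispatch, not the real parser:

    def parse_frame(readline, read):
        # readline() must return one line without its CRLF;
        # read(n) must return n payload bytes and consume the trailing CRLF.
        raw = readline()
        marker, payload = raw[:1], raw[1:]
        if marker == b"-":                      # error reply
            raise ValueError(payload.decode())
        if marker == b"+":                      # simple string
            return payload
        if marker == b":":                      # integer
            return int(payload)
        if marker == b"$":                      # bulk string, $-1 is nil
            n = int(payload)
            return None if n == -1 else read(n)
        if marker == b"*":                      # array, *-1 is nil
            n = int(payload)
            return None if n == -1 else [
                parse_frame(readline, read) for _ in range(n)
            ]
        raise ValueError(f"Protocol Error: {raw!r}")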
- return True - async def read_response( - self, disable_decoding: bool = False - ) -> Union[EncodableT, List[EncodableT]]: - if not self._stream or not self._reader: - self.on_disconnect() - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None - - response = self._reader.gets() - while response is False: - await self.read_from_socket() - response = self._reader.gets() - - # if the response is a ConnectionError or the response is a list and - # the first item is a ConnectionError, raise it as something bad - # happened - if isinstance(response, ConnectionError): - raise response - elif ( - isinstance(response, list) - and response - and isinstance(response[0], ConnectionError) - ): - raise response[0] - return response - - -DefaultParser: Type[Union[PythonParser, HiredisParser]] +DefaultParser: Type[Union[_AsyncRESP2Parser, _AsyncRESP3Parser, _AsyncHiredisParser]] if HIREDIS_AVAILABLE: - DefaultParser = HiredisParser + DefaultParser = _AsyncHiredisParser else: - DefaultParser = PythonParser + DefaultParser = _AsyncRESP2Parser class ConnectCallbackProtocol(Protocol): - def __call__(self, connection: "Connection"): + def __call__(self, connection: "AbstractConnection"): ... class AsyncConnectCallbackProtocol(Protocol): - async def __call__(self, connection: "Connection"): + async def __call__(self, connection: "AbstractConnection"): ... ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol] -class Connection: - """Manages TCP communication to and from a Redis server""" +class AbstractConnection: + """Manages communication to and from a Redis server""" __slots__ = ( "pid", - "host", - "port", "db", "username", "client_name", + "lib_name", + "lib_version", "credential_provider", "password", "socket_timeout", "socket_connect_timeout", - "socket_keepalive", - "socket_keepalive_options", - "socket_type", "redis_connect_func", "retry_on_timeout", "retry_on_error", @@ -465,6 +115,7 @@ class Connection: "last_active_at", "encoder", "ssl_context", + "protocol", "_reader", "_writer", "_parser", @@ -478,15 +129,10 @@ class Connection: def __init__( self, *, - host: str = "localhost", - port: Union[str, int] = 6379, db: Union[str, int] = 0, password: Optional[str] = None, socket_timeout: Optional[float] = None, socket_connect_timeout: Optional[float] = None, - socket_keepalive: bool = False, - socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None, - socket_type: int = 0, retry_on_timeout: bool = False, retry_on_error: Union[list, _Sentinel] = SENTINEL, encoding: str = "utf-8", @@ -496,11 +142,14 @@ def __init__( socket_read_size: int = 65536, health_check_interval: float = 0, client_name: Optional[str] = None, + lib_name: Optional[str] = "redis-py", + lib_version: Optional[str] = get_lib_version(), username: Optional[str] = None, retry: Optional[Retry] = None, redis_connect_func: Optional[ConnectCallbackT] = None, encoder_class: Type[Encoder] = Encoder, credential_provider: Optional[CredentialProvider] = None, + protocol: Optional[int] = 2, ): if (username or password) and credential_provider is not None: raise DataError( @@ -510,18 +159,17 @@ def __init__( "2. 
'credential_provider'" ) self.pid = os.getpid() - self.host = host - self.port = int(port) self.db = db self.client_name = client_name + self.lib_name = lib_name + self.lib_version = lib_version self.credential_provider = credential_provider self.password = password self.username = username self.socket_timeout = socket_timeout - self.socket_connect_timeout = socket_connect_timeout or socket_timeout or None - self.socket_keepalive = socket_keepalive - self.socket_keepalive_options = socket_keepalive_options or {} - self.socket_type = socket_type + if socket_connect_timeout is None: + socket_connect_timeout = socket_timeout + self.socket_connect_timeout = socket_connect_timeout self.retry_on_timeout = retry_on_timeout if retry_on_error is SENTINEL: retry_on_error = [] @@ -542,7 +190,6 @@ def __init__( self.retry = Retry(NoBackoff(), 0) self.health_check_interval = health_check_interval self.next_health_check: float = -1 - self.ssl_context: Optional[RedisSSLContext] = None self.encoder = encoder_class(encoding, encoding_errors, decode_responses) self.redis_connect_func = redis_connect_func self._reader: Optional[asyncio.StreamReader] = None @@ -551,28 +198,24 @@ def __init__( self.set_parser(parser_class) self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = [] self._buffer_cutoff = 6000 + try: + p = int(protocol) + except TypeError: + p = DEFAULT_RESP_VERSION + except ValueError: + raise ConnectionError("protocol must be an integer") + finally: + if p < 2 or p > 3: + raise ConnectionError("protocol must be either 2 or 3") + self.protocol = protocol def __repr__(self): repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces())) return f"{self.__class__.__name__}<{repr_args}>" + @abstractmethod def repr_pieces(self): - pieces = [("host", self.host), ("port", self.port), ("db", self.db)] - if self.client_name: - pieces.append(("client_name", self.client_name)) - return pieces - - def __del__(self): - try: - if self.is_connected: - loop = asyncio.get_running_loop() - coro = self.disconnect() - if loop.is_running(): - loop.create_task(coro) - else: - loop.run_until_complete(coro) - except Exception: - pass + pass @property def is_connected(self): @@ -631,56 +274,24 @@ async def connect(self): if task and inspect.isawaitable(task): await task + @abstractmethod async def _connect(self): - """Create a TCP socket connection""" - async with async_timeout.timeout(self.socket_connect_timeout): - reader, writer = await asyncio.open_connection( - host=self.host, - port=self.port, - ssl=self.ssl_context.get() if self.ssl_context else None, - ) - self._reader = reader - self._writer = writer - sock = writer.transport.get_extra_info("socket") - if sock: - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - try: - # TCP_KEEPALIVE - if self.socket_keepalive: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - for k, v in self.socket_keepalive_options.items(): - sock.setsockopt(socket.SOL_TCP, k, v) + pass - except (OSError, TypeError): - # `socket_keepalive_options` might contain invalid options - # causing an error. Do not leave the connection open. - writer.close() - raise + @abstractmethod + def _host_error(self) -> str: + pass - def _error_message(self, exception): - # args for socket.error can either be (errno, "message") - # or just "message" - if not exception.args: - # asyncio has a bug where on Connection reset by peer, the - # exception is not instanciated, so args is empty. This is the - # workaround. 
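With transport details factored out of `AbstractConnection`, a subclass supplies only `repr_pieces`, `_connect`, `_host_error` and `_error_message`. A hypothetical skeleton (not a working transport):

    from redis.asyncio.connection import AbstractConnection

    class PipeConnection(AbstractConnection):
        """Hypothetical transport over a pre-opened stream pair."""

        def __init__(self, *, reader, writer, **kwargs):
            self._pipe = (reader, writer)
            super().__init__(**kwargs)

        def repr_pieces(self):
            return [("db", self.db)]

        def _host_error(self) -> str:
            return "pipe"  # shows up in timeout/error messages

        def _error_message(self, exception: BaseException) -> str:
            return f"Error on pipe connection: {exception!r}"

        async def _connect(self):
            # The parser reads from whatever streams we install here.
            self._reader, self._writer = self._pipe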
- # See: https://github.com/redis/redis-py/issues/2237 - # See: https://github.com/python/cpython/issues/94061 - return ( - f"Error connecting to {self.host}:{self.port}. Connection reset by peer" - ) - elif len(exception.args) == 1: - return f"Error connecting to {self.host}:{self.port}. {exception.args[0]}." - else: - return ( - f"Error {exception.args[0]} connecting to {self.host}:{self.port}. " - f"{exception.args[0]}." - ) + @abstractmethod + def _error_message(self, exception: BaseException) -> str: + pass async def on_connect(self) -> None: """Initialize the connection, authenticate and select a database""" self._parser.on_connect(self) + parser = self._parser + auth_args = None # if credential provider or username and/or password are set, authenticate if self.credential_provider or (self.username or self.password): cred_provider = ( @@ -688,8 +299,25 @@ async def on_connect(self) -> None: or UsernamePasswordCredentialProvider(self.username, self.password) ) auth_args = cred_provider.get_credentials() - # avoid checking health here -- PING will fail if we try - # to check the health prior to the AUTH + # if resp version is specified and we have auth args, + # we need to send them via HELLO + if auth_args and self.protocol not in [2, "2"]: + if isinstance(self._parser, _AsyncRESP2Parser): + self.set_parser(_AsyncRESP3Parser) + # update cluster exception classes + self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES + self._parser.on_connect(self) + if len(auth_args) == 1: + auth_args = ["default", auth_args[0]] + await self.send_command("HELLO", self.protocol, "AUTH", *auth_args) + response = await self.read_response() + if response.get(b"proto") != int(self.protocol) and response.get( + "proto" + ) != int(self.protocol): + raise ConnectionError("Invalid RESP version") + # avoid checking health here -- PING will fail if we try + # to check the health prior to the AUTH + elif auth_args: await self.send_command("AUTH", *auth_args, check_health=False) try: @@ -705,22 +333,50 @@ async def on_connect(self) -> None: if str_if_bytes(auth_response) != "OK": raise AuthenticationError("Invalid Username or Password") + # if resp version is specified, switch to it + elif self.protocol not in [2, "2"]: + if isinstance(self._parser, _AsyncRESP2Parser): + self.set_parser(_AsyncRESP3Parser) + # update cluster exception classes + self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES + self._parser.on_connect(self) + await self.send_command("HELLO", self.protocol) + response = await self.read_response() + # if response.get(b"proto") != self.protocol and response.get( + # "proto" + # ) != self.protocol: + # raise ConnectionError("Invalid RESP version") + # if a client_name is given, set it if self.client_name: await self.send_command("CLIENT", "SETNAME", self.client_name) if str_if_bytes(await self.read_response()) != "OK": raise ConnectionError("Error setting client name") - # if a database is specified, switch to it + # set the library name and version, pipeline for lower startup latency + if self.lib_name: + await self.send_command("CLIENT", "SETINFO", "LIB-NAME", self.lib_name) + if self.lib_version: + await self.send_command("CLIENT", "SETINFO", "LIB-VER", self.lib_version) + # if a database is specified, switch to it. 
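From the caller's side the handshake is transparent: with `protocol=3` and credentials, authentication travels inside `HELLO 3 AUTH <user> <pass>` instead of a separate `AUTH`. Sketch against a Redis 6+ server (credentials hypothetical):

    import asyncio
    import redis.asyncio as redis

    async def main():
        r = redis.Redis(protocol=3, username="default", password="secret")
        print(await r.ping())  # connects, HELLO-authenticates, then pings
        await r.close()

    asyncio.run(main())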
Also pipeline this if self.db: await self.send_command("SELECT", self.db) + + # read responses from pipeline + for _ in (sent for sent in (self.lib_name, self.lib_version) if sent): + try: + await self.read_response() + except ResponseError: + pass + + if self.db: if str_if_bytes(await self.read_response()) != "OK": raise ConnectionError("Invalid Database") async def disconnect(self, nowait: bool = False) -> None: """Disconnects from the Redis server""" try: - async with async_timeout.timeout(self.socket_connect_timeout): + async with async_timeout(self.socket_connect_timeout): self._parser.on_disconnect() if not self.is_connected: return @@ -796,7 +452,11 @@ async def send_packed_command( raise ConnectionError( f"Error {err_no} while writing to socket. {errmsg}." ) from e - except Exception: + except BaseException: + # BaseExceptions can be raised when a socket send operation is not + # finished, e.g. due to a timeout. Ideally, a caller could then re-try + # to send un-sent data. However, the send_packed_command() API + # does not support it so there is no point in keeping the connection open. await self.disconnect(nowait=True) raise @@ -812,45 +472,61 @@ async def can_read_destructive(self): return await self._parser.can_read_destructive() except OSError as e: await self.disconnect(nowait=True) - raise ConnectionError( - f"Error while reading from {self.host}:{self.port}: {e.args}" - ) + host_error = self._host_error() + raise ConnectionError(f"Error while reading from {host_error}: {e.args}") async def read_response( self, disable_decoding: bool = False, timeout: Optional[float] = None, + *, + disconnect_on_error: bool = True, + push_request: Optional[bool] = False, ): """Read the response from a previously sent command""" read_timeout = timeout if timeout is not None else self.socket_timeout + host_error = self._host_error() try: - if read_timeout is not None: - async with async_timeout.timeout(read_timeout): + if ( + read_timeout is not None + and self.protocol in ["3", 3] + and not HIREDIS_AVAILABLE + ): + async with async_timeout(read_timeout): + response = await self._parser.read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + elif read_timeout is not None: + async with async_timeout(read_timeout): response = await self._parser.read_response( disable_decoding=disable_decoding ) + elif self.protocol in ["3", 3] and not HIREDIS_AVAILABLE: + response = await self._parser.read_response( + disable_decoding=disable_decoding, push_request=push_request + ) else: response = await self._parser.read_response( disable_decoding=disable_decoding ) except asyncio.TimeoutError: if timeout is not None: - # user requested timeout, return None + # user requested timeout, return None. Operation can be retried return None # it was a self.socket_timeout error. 
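`CLIENT SETINFO` replies are read back leniently (a `ResponseError` from an older server is swallowed), and the calls are pipelined together with `SELECT` to save round trips. Overriding or suppressing the identifiers:

    import redis.asyncio as redis

    r = redis.Redis(lib_name="my-service", lib_version="1.2.3")
    quiet = redis.Redis(lib_name=None, lib_version=None)  # skips CLIENT SETINFO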
- await self.disconnect(nowait=True) - raise TimeoutError(f"Timeout reading from {self.host}:{self.port}") + if disconnect_on_error: + await self.disconnect(nowait=True) + raise TimeoutError(f"Timeout reading from {host_error}") except OSError as e: - await self.disconnect(nowait=True) - raise ConnectionError( - f"Error while reading from {self.host}:{self.port} : {e.args}" - ) - except asyncio.CancelledError: - # need this check for 3.7, where CancelledError - # is subclass of Exception, not BaseException - raise - except Exception: - await self.disconnect(nowait=True) + if disconnect_on_error: + await self.disconnect(nowait=True) + raise ConnectionError(f"Error while reading from {host_error} : {e.args}") + except BaseException: + # Also by default close in case of BaseException. A lot of code + # relies on this behaviour when doing Command/Response pairs. + # See #1128. + if disconnect_on_error: + await self.disconnect(nowait=True) raise if self.health_check_interval: @@ -922,7 +598,8 @@ def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes] or chunklen > buffer_cutoff or isinstance(chunk, memoryview) ): - output.append(SYM_EMPTY.join(pieces)) + if pieces: + output.append(SYM_EMPTY.join(pieces)) buffer_length = 0 pieces = [] @@ -937,7 +614,90 @@ def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes] return output +class Connection(AbstractConnection): + "Manages TCP communication to and from a Redis server" + + def __init__( + self, + *, + host: str = "localhost", + port: Union[str, int] = 6379, + socket_keepalive: bool = False, + socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None, + socket_type: int = 0, + **kwargs, + ): + self.host = host + self.port = int(port) + self.socket_keepalive = socket_keepalive + self.socket_keepalive_options = socket_keepalive_options or {} + self.socket_type = socket_type + super().__init__(**kwargs) + + def repr_pieces(self): + pieces = [("host", self.host), ("port", self.port), ("db", self.db)] + if self.client_name: + pieces.append(("client_name", self.client_name)) + return pieces + + def _connection_arguments(self) -> Mapping: + return {"host": self.host, "port": self.port} + + async def _connect(self): + """Create a TCP socket connection""" + async with async_timeout(self.socket_connect_timeout): + reader, writer = await asyncio.open_connection( + **self._connection_arguments() + ) + self._reader = reader + self._writer = writer + sock = writer.transport.get_extra_info("socket") + if sock: + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + try: + # TCP_KEEPALIVE + if self.socket_keepalive: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + for k, v in self.socket_keepalive_options.items(): + sock.setsockopt(socket.SOL_TCP, k, v) + + except (OSError, TypeError): + # `socket_keepalive_options` might contain invalid options + # causing an error. Do not leave the connection open. + writer.close() + raise + + def _host_error(self) -> str: + return f"{self.host}:{self.port}" + + def _error_message(self, exception: BaseException) -> str: + # args for socket.error can either be (errno, "message") + # or just "message" + + host_error = self._host_error() + + if not exception.args: + # asyncio has a bug where on Connection reset by peer, the + # exception is not instanciated, so args is empty. This is the + # workaround. 
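The TCP-specific options now live on the `Connection` subclass and are applied to the raw socket in `_connect()`; invalid keepalive options close the connection rather than leaking it. An example with Linux-only constants (values illustrative):

    import socket
    import redis.asyncio as redis

    r = redis.Redis(
        host="localhost",
        socket_keepalive=True,
        socket_keepalive_options={
            socket.TCP_KEEPIDLE: 60,   # start probing after 60s idle
            socket.TCP_KEEPINTVL: 10,  # probe every 10s
            socket.TCP_KEEPCNT: 3,     # give up after 3 failed probes
        },
    )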
+ # See: https://github.com/redis/redis-py/issues/2237 + # See: https://github.com/python/cpython/issues/94061 + return f"Error connecting to {host_error}. Connection reset by peer" + elif len(exception.args) == 1: + return f"Error connecting to {host_error}. {exception.args[0]}." + else: + return ( + f"Error {exception.args[0]} connecting to {host_error}. " + f"{exception.args[0]}." + ) + + class SSLConnection(Connection): + """Manages SSL connections to and from the Redis server(s). + This class extends the Connection class, adding SSL functionality, and making + use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext) + """ + def __init__( self, ssl_keyfile: Optional[str] = None, @@ -948,7 +708,6 @@ def __init__( ssl_check_hostname: bool = False, **kwargs, ): - super().__init__(**kwargs) self.ssl_context: RedisSSLContext = RedisSSLContext( keyfile=ssl_keyfile, certfile=ssl_certfile, @@ -957,6 +716,12 @@ def __init__( ca_data=ssl_ca_data, check_hostname=ssl_check_hostname, ) + super().__init__(**kwargs) + + def _connection_arguments(self) -> Mapping: + kwargs = super()._connection_arguments() + kwargs["ssl"] = self.ssl_context.get() + return kwargs @property def keyfile(self): @@ -1036,77 +801,12 @@ def get(self) -> ssl.SSLContext: return self.context -class UnixDomainSocketConnection(Connection): # lgtm [py/missing-call-to-init] - def __init__( - self, - *, - path: str = "", - db: Union[str, int] = 0, - username: Optional[str] = None, - password: Optional[str] = None, - socket_timeout: Optional[float] = None, - socket_connect_timeout: Optional[float] = None, - encoding: str = "utf-8", - encoding_errors: str = "strict", - decode_responses: bool = False, - retry_on_timeout: bool = False, - retry_on_error: Union[list, _Sentinel] = SENTINEL, - parser_class: Type[BaseParser] = DefaultParser, - socket_read_size: int = 65536, - health_check_interval: float = 0.0, - client_name: str = None, - retry: Optional[Retry] = None, - redis_connect_func=None, - credential_provider: Optional[CredentialProvider] = None, - ): - """ - Initialize a new UnixDomainSocketConnection. - To specify a retry policy, first set `retry_on_timeout` to `True` - then set `retry` to a valid `Retry` object - """ - if (username or password) and credential_provider is not None: - raise DataError( - "'username' and 'password' cannot be passed along with 'credential_" - "provider'. Please provide only one of the following arguments: \n" - "1. 'password' and (optional) 'username'\n" - "2. 
'credential_provider'" - ) - self.pid = os.getpid() +class UnixDomainSocketConnection(AbstractConnection): + "Manages UDS communication to and from a Redis server" + + def __init__(self, *, path: str = "", **kwargs): self.path = path - self.db = db - self.client_name = client_name - self.credential_provider = credential_provider - self.password = password - self.username = username - self.socket_timeout = socket_timeout - self.socket_connect_timeout = socket_connect_timeout or socket_timeout or None - self.retry_on_timeout = retry_on_timeout - if retry_on_error is SENTINEL: - retry_on_error = [] - if retry_on_timeout: - retry_on_error.append(TimeoutError) - self.retry_on_error = retry_on_error - if retry_on_error: - if retry is None: - self.retry = Retry(NoBackoff(), 1) - else: - # deep-copy the Retry object as it is mutable - self.retry = copy.deepcopy(retry) - # Update the retry's supported errors with the specified errors - self.retry.update_supported_errors(retry_on_error) - else: - self.retry = Retry(NoBackoff(), 0) - self.health_check_interval = health_check_interval - self.next_health_check = -1 - self.redis_connect_func = redis_connect_func - self.encoder = Encoder(encoding, encoding_errors, decode_responses) - self._sock = None - self._reader = None - self._writer = None - self._socket_read_size = socket_read_size - self.set_parser(parser_class) - self._connect_callbacks = [] - self._buffer_cutoff = 6000 + super().__init__(**kwargs) def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]: pieces = [("path", self.path), ("db", self.db)] @@ -1115,21 +815,27 @@ def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]: return pieces async def _connect(self): - async with async_timeout.timeout(self.socket_connect_timeout): + async with async_timeout(self.socket_connect_timeout): reader, writer = await asyncio.open_unix_connection(path=self.path) self._reader = reader self._writer = writer await self.on_connect() - def _error_message(self, exception): + def _host_error(self) -> str: + return self.path + + def _error_message(self, exception: BaseException) -> str: # args for socket.error can either be (errno, "message") # or just "message" + host_error = self._host_error() if len(exception.args) == 1: - return f"Error connecting to unix socket: {self.path}. {exception.args[0]}." + return ( + f"Error connecting to unix socket: {host_error}. {exception.args[0]}." + ) else: return ( f"Error {exception.args[0]} connecting to unix socket: " - f"{self.path}. {exception.args[1]}." + f"{host_error}. {exception.args[1]}." ) @@ -1161,7 +867,7 @@ def to_bool(value) -> Optional[bool]: class ConnectKwargs(TypedDict, total=False): username: str password: str - connection_class: Type[Connection] + connection_class: Type[AbstractConnection] host: str port: int db: int @@ -1262,10 +968,13 @@ def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP: There are several ways to specify a database number. The first value found will be used: - 1. A ``db`` querystring option, e.g. redis://localhost?db=0 - 2. If using the redis:// or rediss:// schemes, the path argument + + 1. A ``db`` querystring option, e.g. redis://localhost?db=0 + + 2. If using the redis:// or rediss:// schemes, the path argument of the url, e.g. redis://localhost/0 - 3. A ``db`` keyword argument to this function. + + 3. A ``db`` keyword argument to this function. If none of these options are specified, the default db=0 is used. @@ -1283,7 +992,7 @@ class initializer. 
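The three URL schemes select the matching connection class, and the re-bulleted precedence rules above resolve conflicting database selections:

    import redis.asyncio as redis

    r1 = redis.Redis.from_url("redis://localhost/1?db=2")  # db 2: querystring wins
    r2 = redis.Redis.from_url("redis://localhost/1")       # db 1: from the path
    r3 = redis.Redis.from_url("redis://localhost", db=3)   # db 3: keyword argument
    r4 = redis.Redis.from_url("rediss://localhost:6380")   # SSLConnection
    r5 = redis.Redis.from_url("unix:///tmp/redis.sock")    # UnixDomainSocketConnection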
In the case of conflicting arguments, querystring def __init__( self, - connection_class: Type[Connection] = Connection, + connection_class: Type[AbstractConnection] = Connection, max_connections: Optional[int] = None, **connection_kwargs, ): @@ -1306,8 +1015,8 @@ def __init__( self._fork_lock = threading.Lock() self._lock = asyncio.Lock() self._created_connections: int - self._available_connections: List[Connection] - self._in_use_connections: Set[Connection] + self._available_connections: List[AbstractConnection] + self._in_use_connections: Set[AbstractConnection] self.reset() # lgtm [py/init-calls-subclass] self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder) @@ -1430,7 +1139,7 @@ def make_connection(self): self._created_connections += 1 return self.connection_class(**self.connection_kwargs) - async def release(self, connection: Connection): + async def release(self, connection: AbstractConnection): """Releases the connection back to the pool""" self._checkpid() async with self._lock: @@ -1451,7 +1160,7 @@ async def release(self, connection: Connection): await connection.disconnect() return - def owns_connection(self, connection: Connection): + def owns_connection(self, connection: AbstractConnection): return connection.pid == self.pid async def disconnect(self, inuse_connections: bool = True): @@ -1465,7 +1174,7 @@ async def disconnect(self, inuse_connections: bool = True): self._checkpid() async with self._lock: if inuse_connections: - connections: Iterable[Connection] = chain( + connections: Iterable[AbstractConnection] = chain( self._available_connections, self._in_use_connections ) else: @@ -1523,14 +1232,14 @@ def __init__( self, max_connections: int = 50, timeout: Optional[int] = 20, - connection_class: Type[Connection] = Connection, + connection_class: Type[AbstractConnection] = Connection, queue_class: Type[asyncio.Queue] = asyncio.LifoQueue, **connection_kwargs, ): self.queue_class = queue_class self.timeout = timeout - self._connections: List[Connection] + self._connections: List[AbstractConnection] super().__init__( connection_class=connection_class, max_connections=max_connections, @@ -1586,7 +1295,7 @@ async def get_connection(self, command_name, *keys, **options): # self.timeout then raise a ``ConnectionError``. connection = None try: - async with async_timeout.timeout(self.timeout): + async with async_timeout(self.timeout): connection = await self.pool.get() except (asyncio.QueueEmpty, asyncio.TimeoutError): # Note that this is not caught by the redis client and will be @@ -1620,7 +1329,7 @@ async def get_connection(self, command_name, *keys, **options): return connection - async def release(self, connection: Connection): + async def release(self, connection: AbstractConnection): """Releases the connection back to the pool.""" # Make sure we haven't changed process. self._checkpid() diff --git a/redis/asyncio/parser.py b/redis/asyncio/parser.py deleted file mode 100644 index 5faf8f8c57..0000000000 --- a/redis/asyncio/parser.py +++ /dev/null @@ -1,94 +0,0 @@ -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union - -from redis.exceptions import RedisError, ResponseError - -if TYPE_CHECKING: - from redis.asyncio.cluster import ClusterNode - - -class CommandsParser: - """ - Parses Redis commands to get command keys. - - COMMAND output is used to determine key locations. - Commands that do not have a predefined key location are flagged with 'movablekeys', - and these commands' keys are determined by the command 'COMMAND GETKEYS'. 
- - NOTE: Due to a bug in redis<7.0, this does not work properly - for EVAL or EVALSHA when the `numkeys` arg is 0. - - issue: https://github.com/redis/redis/issues/9493 - - fix: https://github.com/redis/redis/pull/9733 - - So, don't use this with EVAL or EVALSHA. - """ - - __slots__ = ("commands", "node") - - def __init__(self) -> None: - self.commands: Dict[str, Union[int, Dict[str, Any]]] = {} - - async def initialize(self, node: Optional["ClusterNode"] = None) -> None: - if node: - self.node = node - - commands = await self.node.execute_command("COMMAND") - for cmd, command in commands.items(): - if "movablekeys" in command["flags"]: - commands[cmd] = -1 - elif command["first_key_pos"] == 0 and command["last_key_pos"] == 0: - commands[cmd] = 0 - elif command["first_key_pos"] == 1 and command["last_key_pos"] == 1: - commands[cmd] = 1 - self.commands = {cmd.upper(): command for cmd, command in commands.items()} - - # As soon as this PR is merged into Redis, we should reimplement - # our logic to use COMMAND INFO changes to determine the key positions - # https://github.com/redis/redis/pull/8324 - async def get_keys(self, *args: Any) -> Optional[Tuple[str, ...]]: - if len(args) < 2: - # The command has no keys in it - return None - - try: - command = self.commands[args[0]] - except KeyError: - # try to split the command name and to take only the main command - # e.g. 'memory' for 'memory usage' - args = args[0].split() + list(args[1:]) - cmd_name = args[0].upper() - if cmd_name not in self.commands: - # We'll try to reinitialize the commands cache, if the engine - # version has changed, the commands may not be current - await self.initialize() - if cmd_name not in self.commands: - raise RedisError( - f"{cmd_name} command doesn't exist in Redis commands" - ) - - command = self.commands[cmd_name] - - if command == 1: - return (args[1],) - if command == 0: - return None - if command == -1: - return await self._get_moveable_keys(*args) - - last_key_pos = command["last_key_pos"] - if last_key_pos < 0: - last_key_pos = len(args) + last_key_pos - return args[command["first_key_pos"] : last_key_pos + 1 : command["step_count"]] - - async def _get_moveable_keys(self, *args: Any) -> Optional[Tuple[str, ...]]: - try: - keys = await self.node.execute_command("COMMAND GETKEYS", *args) - except ResponseError as e: - message = e.__str__() - if ( - "Invalid arguments" in message - or "The command has no key arguments" in message - ): - return None - else: - raise e - return keys diff --git a/redis/asyncio/sentinel.py b/redis/asyncio/sentinel.py index ec17886fc6..501e234c3c 100644 --- a/redis/asyncio/sentinel.py +++ b/redis/asyncio/sentinel.py @@ -67,11 +67,14 @@ async def read_response( self, disable_decoding: bool = False, timeout: Optional[float] = None, + *, + disconnect_on_error: Optional[float] = True, ): try: return await super().read_response( disable_decoding=disable_decoding, timeout=timeout, + disconnect_on_error=disconnect_on_error, ) except ReadOnlyError: if self.connection_pool.is_master: @@ -220,13 +223,13 @@ async def execute_command(self, *args, **kwargs): kwargs.pop("once") if once: + await random.choice(self.sentinels).execute_command(*args, **kwargs) + else: tasks = [ asyncio.Task(sentinel.execute_command(*args, **kwargs)) for sentinel in self.sentinels ] await asyncio.gather(*tasks) - else: - await random.choice(self.sentinels).execute_command(*args, **kwargs) return True def __repr__(self): @@ -254,10 +257,12 @@ async def discover_master(self, service_name: str): Returns a pair 
(address, port) or raises MasterNotFoundError if no master is found. """ + collected_errors = list() for sentinel_no, sentinel in enumerate(self.sentinels): try: masters = await sentinel.sentinel_masters() - except (ConnectionError, TimeoutError): + except (ConnectionError, TimeoutError) as e: + collected_errors.append(f"{sentinel} - {e!r}") continue state = masters.get(service_name) if state and self.check_master_state(state, service_name): @@ -267,7 +272,11 @@ async def discover_master(self, service_name: str): self.sentinels[0], ) return state["ip"], state["port"] - raise MasterNotFoundError(f"No master found for {service_name!r}") + + error_info = "" + if len(collected_errors) > 0: + error_info = f" : {', '.join(collected_errors)}" + raise MasterNotFoundError(f"No master found for {service_name!r}{error_info}") def filter_slaves( self, slaves: Iterable[Mapping] diff --git a/redis/client.py b/redis/client.py index 1a9b96b83d..f695cef534 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1,5 +1,4 @@ import copy -import datetime import re import threading import time @@ -7,6 +6,12 @@ from itertools import chain from typing import Optional +from redis._parsers.helpers import ( + _RedisCallbacks, + _RedisCallbacksRESP2, + _RedisCallbacksRESP3, + bool_ok, +) from redis.commands import ( CoreCommands, RedisModuleCommands, @@ -18,7 +23,6 @@ from redis.exceptions import ( ConnectionError, ExecAbortError, - ModuleError, PubSubError, RedisError, ResponseError, @@ -27,7 +31,13 @@ ) from redis.lock import Lock from redis.retry import Retry -from redis.utils import safe_str, str_if_bytes +from redis.utils import ( + HIREDIS_AVAILABLE, + _set_info_logger, + get_lib_version, + safe_str, + str_if_bytes, +) SYM_EMPTY = b"" EMPTY_RESPONSE = "EMPTY_RESPONSE" @@ -36,21 +46,6 @@ NEVER_DECODE = "NEVER_DECODE" -def timestamp_to_datetime(response): - "Converts a unix timestamp to a Python datetime object" - if not response: - return None - try: - response = int(response) - except ValueError: - return None - return datetime.datetime.fromtimestamp(response) - - -def string_keys_to_dict(key_string, callback): - return dict.fromkeys(key_string.split(), callback) - - class CaseInsensitiveDict(dict): "Case insensitive dict implementation. Assumes string keys only." @@ -78,771 +73,11 @@ def update(self, data): super().update(data) -def parse_debug_object(response): - "Parse the results of Redis's DEBUG OBJECT command into a Python dict" - # The 'type' of the object is the first item in the response, but isn't - # prefixed with a name - response = str_if_bytes(response) - response = "type:" + response - response = dict(kv.split(":") for kv in response.split()) - - # parse some expected int values from the string response - # note: this cmd isn't spec'd so these may not appear in all redis versions - int_fields = ("refcount", "serializedlength", "lru", "lru_seconds_idle") - for field in int_fields: - if field in response: - response[field] = int(response[field]) - - return response - - -def parse_object(response, infotype): - """Parse the results of an OBJECT command""" - if infotype in ("idletime", "refcount"): - return int_or_none(response) - return response - - -def parse_info(response): - """Parse the result of Redis's INFO command into a Python dict""" - info = {} - response = str_if_bytes(response) - - def get_value(value): - if "," not in value or "=" not in value: - try: - if "." 
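Two behavioural fixes land here: `once=True` now really targets a single random sentinel (the branches were inverted), and `MasterNotFoundError` carries whatever each unreachable sentinel reported. Sketch, assuming sentinels on the conventional ports:

    import asyncio
    from redis.asyncio.sentinel import Sentinel, MasterNotFoundError

    async def main():
        sentinel = Sentinel([("localhost", 26379), ("localhost", 26380)])
        try:
            host, port = await sentinel.discover_master("mymaster")
            print(host, port)
        except MasterNotFoundError as e:
            print(e)  # message now lists the per-sentinel errors

    asyncio.run(main())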
in value: - return float(value) - else: - return int(value) - except ValueError: - return value - else: - sub_dict = {} - for item in value.split(","): - k, v = item.rsplit("=", 1) - sub_dict[k] = get_value(v) - return sub_dict - - for line in response.splitlines(): - if line and not line.startswith("#"): - if line.find(":") != -1: - # Split, the info fields keys and values. - # Note that the value may contain ':'. but the 'host:' - # pseudo-command is the only case where the key contains ':' - key, value = line.split(":", 1) - if key == "cmdstat_host": - key, value = line.rsplit(":", 1) - - if key == "module": - # Hardcode a list for key 'modules' since there could be - # multiple lines that started with 'module' - info.setdefault("modules", []).append(get_value(value)) - else: - info[key] = get_value(value) - else: - # if the line isn't splittable, append it to the "__raw__" key - info.setdefault("__raw__", []).append(line) - - return info - - -def parse_memory_stats(response, **kwargs): - """Parse the results of MEMORY STATS""" - stats = pairs_to_dict(response, decode_keys=True, decode_string_values=True) - for key, value in stats.items(): - if key.startswith("db."): - stats[key] = pairs_to_dict( - value, decode_keys=True, decode_string_values=True - ) - return stats - - -SENTINEL_STATE_TYPES = { - "can-failover-its-master": int, - "config-epoch": int, - "down-after-milliseconds": int, - "failover-timeout": int, - "info-refresh": int, - "last-hello-message": int, - "last-ok-ping-reply": int, - "last-ping-reply": int, - "last-ping-sent": int, - "master-link-down-time": int, - "master-port": int, - "num-other-sentinels": int, - "num-slaves": int, - "o-down-time": int, - "pending-commands": int, - "parallel-syncs": int, - "port": int, - "quorum": int, - "role-reported-time": int, - "s-down-time": int, - "slave-priority": int, - "slave-repl-offset": int, - "voted-leader-epoch": int, -} - - -def parse_sentinel_state(item): - result = pairs_to_dict_typed(item, SENTINEL_STATE_TYPES) - flags = set(result["flags"].split(",")) - for name, flag in ( - ("is_master", "master"), - ("is_slave", "slave"), - ("is_sdown", "s_down"), - ("is_odown", "o_down"), - ("is_sentinel", "sentinel"), - ("is_disconnected", "disconnected"), - ("is_master_down", "master_down"), - ): - result[name] = flag in flags - return result - - -def parse_sentinel_master(response): - return parse_sentinel_state(map(str_if_bytes, response)) - - -def parse_sentinel_masters(response): - result = {} - for item in response: - state = parse_sentinel_state(map(str_if_bytes, item)) - result[state["name"]] = state - return result - - -def parse_sentinel_slaves_and_sentinels(response): - return [parse_sentinel_state(map(str_if_bytes, item)) for item in response] - - -def parse_sentinel_get_master(response): - return response and (response[0], int(response[1])) or None - - -def pairs_to_dict(response, decode_keys=False, decode_string_values=False): - """Create a dict given a list of key/value pairs""" - if response is None: - return {} - if decode_keys or decode_string_values: - # the iter form is faster, but I don't know how to make that work - # with a str_if_bytes() map - keys = response[::2] - if decode_keys: - keys = map(str_if_bytes, keys) - values = response[1::2] - if decode_string_values: - values = map(str_if_bytes, values) - return dict(zip(keys, values)) - else: - it = iter(response) - return dict(zip(it, it)) - - -def pairs_to_dict_typed(response, type_info): - it = iter(response) - result = {} - for key, value in zip(it, it): - 
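Several of these relocated helpers share one idiom: zip a single iterator with itself so a flat reply becomes pairs. In isolation:

    # pairs_to_dict: flat field/value reply -> dict
    fields = [b"name", b"alice", b"age", b"30"]
    it = iter(fields)
    print(dict(zip(it, it)))              # {b'name': b'alice', b'age': b'30'}

    # zset_score_pairs: flat member/score reply -> (member, float) pairs
    scored = [b"a", b"1", b"b", b"2.5"]
    it = iter(scored)
    print(list(zip(it, map(float, it))))  # [(b'a', 1.0), (b'b', 2.5)]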
if key in type_info: - try: - value = type_info[key](value) - except Exception: - # if for some reason the value can't be coerced, just use - # the string value - pass - result[key] = value - return result - - -def zset_score_pairs(response, **options): - """ - If ``withscores`` is specified in the options, return the response as - a list of (value, score) pairs - """ - if not response or not options.get("withscores"): - return response - score_cast_func = options.get("score_cast_func", float) - it = iter(response) - return list(zip(it, map(score_cast_func, it))) - - -def sort_return_tuples(response, **options): - """ - If ``groups`` is specified, return the response as a list of - n-element tuples with n being the value found in options['groups'] - """ - if not response or not options.get("groups"): - return response - n = options["groups"] - return list(zip(*[response[i::n] for i in range(n)])) - - -def int_or_none(response): - if response is None: - return None - return int(response) - - -def parse_stream_list(response): - if response is None: - return None - data = [] - for r in response: - if r is not None: - data.append((r[0], pairs_to_dict(r[1]))) - else: - data.append((None, None)) - return data - - -def pairs_to_dict_with_str_keys(response): - return pairs_to_dict(response, decode_keys=True) - - -def parse_list_of_dicts(response): - return list(map(pairs_to_dict_with_str_keys, response)) - - -def parse_xclaim(response, **options): - if options.get("parse_justid", False): - return response - return parse_stream_list(response) - - -def parse_xautoclaim(response, **options): - if options.get("parse_justid", False): - return response[1] - response[1] = parse_stream_list(response[1]) - return response - - -def parse_xinfo_stream(response, **options): - data = pairs_to_dict(response, decode_keys=True) - if not options.get("full", False): - first = data["first-entry"] - if first is not None: - data["first-entry"] = (first[0], pairs_to_dict(first[1])) - last = data["last-entry"] - if last is not None: - data["last-entry"] = (last[0], pairs_to_dict(last[1])) - else: - data["entries"] = {_id: pairs_to_dict(entry) for _id, entry in data["entries"]} - data["groups"] = [ - pairs_to_dict(group, decode_keys=True) for group in data["groups"] - ] - return data - - -def parse_xread(response): - if response is None: - return [] - return [[r[0], parse_stream_list(r[1])] for r in response] - - -def parse_xpending(response, **options): - if options.get("parse_detail", False): - return parse_xpending_range(response) - consumers = [{"name": n, "pending": int(p)} for n, p in response[3] or []] - return { - "pending": response[0], - "min": response[1], - "max": response[2], - "consumers": consumers, - } - - -def parse_xpending_range(response): - k = ("message_id", "consumer", "time_since_delivered", "times_delivered") - return [dict(zip(k, r)) for r in response] - - -def float_or_none(response): - if response is None: - return None - return float(response) - - -def bool_ok(response): - return str_if_bytes(response) == "OK" - - -def parse_zadd(response, **options): - if response is None: - return None - if options.get("as_score"): - return float(response) - return int(response) - - -def parse_client_list(response, **options): - clients = [] - for c in str_if_bytes(response).splitlines(): - # Values might contain '=' - clients.append(dict(pair.split("=", 1) for pair in c.split(" "))) - return clients - - -def parse_config_get(response, **options): - response = [str_if_bytes(i) if i is not None else None for 
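The `groups` trick in `sort_return_tuples` is the stride-slicing variant of the same reshaping:

    # SORT ... GET ... GET ... with groups=2: flat reply -> list of 2-tuples
    response = [b"id1", b"alice", b"id2", b"bob"]
    n = 2
    print(list(zip(*[response[i::n] for i in range(n)])))
    # [(b'id1', b'alice'), (b'id2', b'bob')]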
i in response] - return response and pairs_to_dict(response) or {} - - -def parse_scan(response, **options): - cursor, r = response - return int(cursor), r - - -def parse_hscan(response, **options): - cursor, r = response - return int(cursor), r and pairs_to_dict(r) or {} - - -def parse_zscan(response, **options): - score_cast_func = options.get("score_cast_func", float) - cursor, r = response - it = iter(r) - return int(cursor), list(zip(it, map(score_cast_func, it))) - - -def parse_zmscore(response, **options): - # zmscore: list of scores (double precision floating point number) or nil - return [float(score) if score is not None else None for score in response] - - -def parse_slowlog_get(response, **options): - space = " " if options.get("decode_responses", False) else b" " - - def parse_item(item): - result = {"id": item[0], "start_time": int(item[1]), "duration": int(item[2])} - # Redis Enterprise injects another entry at index [3], which has - # the complexity info (i.e. the value N in case the command has - # an O(N) complexity) instead of the command. - if isinstance(item[3], list): - result["command"] = space.join(item[3]) - else: - result["complexity"] = item[3] - result["command"] = space.join(item[4]) - return result - - return [parse_item(item) for item in response] - - -def parse_stralgo(response, **options): - """ - Parse the response from `STRALGO` command. - Without modifiers the returned value is string. - When LEN is given the command returns the length of the result - (i.e integer). - When IDX is given the command returns a dictionary with the LCS - length and all the ranges in both the strings, start and end - offset for each string, where there are matches. - When WITHMATCHLEN is given, each array representing a match will - also have the length of the match at the beginning of the array. 
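To make the scan-family helpers above concrete, here is a minimal sketch of the raw-to-parsed mapping (the raw replies in the comments are illustrative):

    # parse_scan: (cursor, items), with the cursor coerced to int
    parse_scan((b"17", [b"k1", b"k2"]))            # -> (17, [b"k1", b"k2"])

    # parse_hscan: the flat field/value list becomes a dict
    parse_hscan((b"0", [b"f", b"v"]))              # -> (0, {b"f": b"v"})

    # parse_zscan: members paired with scores via score_cast_func (float by default)
    parse_zscan((b"0", [b"a", b"1", b"b", b"2"]))  # -> (0, [(b"a", 1.0), (b"b", 2.0)])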
- """ - if options.get("len", False): - return int(response) - if options.get("idx", False): - if options.get("withmatchlen", False): - matches = [ - [(int(match[-1]))] + list(map(tuple, match[:-1])) - for match in response[1] - ] - else: - matches = [list(map(tuple, match)) for match in response[1]] - return { - str_if_bytes(response[0]): matches, - str_if_bytes(response[2]): int(response[3]), - } - return str_if_bytes(response) - - -def parse_cluster_info(response, **options): - response = str_if_bytes(response) - return dict(line.split(":") for line in response.splitlines() if line) - - -def _parse_node_line(line): - line_items = line.split(" ") - node_id, addr, flags, master_id, ping, pong, epoch, connected = line.split(" ")[:8] - addr = addr.split("@")[0] - node_dict = { - "node_id": node_id, - "flags": flags, - "master_id": master_id, - "last_ping_sent": ping, - "last_pong_rcvd": pong, - "epoch": epoch, - "slots": [], - "migrations": [], - "connected": True if connected == "connected" else False, - } - if len(line_items) >= 9: - slots, migrations = _parse_slots(line_items[8:]) - node_dict["slots"], node_dict["migrations"] = slots, migrations - return addr, node_dict - - -def _parse_slots(slot_ranges): - slots, migrations = [], [] - for s_range in slot_ranges: - if "->-" in s_range: - slot_id, dst_node_id = s_range[1:-1].split("->-", 1) - migrations.append( - {"slot": slot_id, "node_id": dst_node_id, "state": "migrating"} - ) - elif "-<-" in s_range: - slot_id, src_node_id = s_range[1:-1].split("-<-", 1) - migrations.append( - {"slot": slot_id, "node_id": src_node_id, "state": "importing"} - ) - else: - s_range = [sl for sl in s_range.split("-")] - slots.append(s_range) - - return slots, migrations - - -def parse_cluster_nodes(response, **options): - """ - @see: https://redis.io/commands/cluster-nodes # string / bytes - @see: https://redis.io/commands/cluster-replicas # list of string / bytes - """ - if isinstance(response, (str, bytes)): - response = response.splitlines() - return dict(_parse_node_line(str_if_bytes(node)) for node in response) - - -def parse_geosearch_generic(response, **options): - """ - Parse the response of 'GEOSEARCH', GEORADIUS' and 'GEORADIUSBYMEMBER' - commands according to 'withdist', 'withhash' and 'withcoord' labels. - """ - if options["store"] or options["store_dist"]: - # `store` and `store_dist` cant be combined - # with other command arguments. - # relevant to 'GEORADIUS' and 'GEORADIUSBYMEMBER' - return response - - if type(response) != list: - response_list = [response] - else: - response_list = response - - if not options["withdist"] and not options["withcoord"] and not options["withhash"]: - # just a bunch of places - return response_list - - cast = { - "withdist": float, - "withcoord": lambda ll: (float(ll[0]), float(ll[1])), - "withhash": int, - } - - # zip all output results with each casting function to get - # the properly native Python value. 
- f = [lambda x: x] - f += [cast[o] for o in ["withdist", "withhash", "withcoord"] if options[o]] - return [list(map(lambda fv: fv[0](fv[1]), zip(f, r))) for r in response_list] - - -def parse_command(response, **options): - commands = {} - for command in response: - cmd_dict = {} - cmd_name = str_if_bytes(command[0]) - cmd_dict["name"] = cmd_name - cmd_dict["arity"] = int(command[1]) - cmd_dict["flags"] = [str_if_bytes(flag) for flag in command[2]] - cmd_dict["first_key_pos"] = command[3] - cmd_dict["last_key_pos"] = command[4] - cmd_dict["step_count"] = command[5] - if len(command) > 7: - cmd_dict["tips"] = command[7] - cmd_dict["key_specifications"] = command[8] - cmd_dict["subcommands"] = command[9] - commands[cmd_name] = cmd_dict - return commands - - -def parse_pubsub_numsub(response, **options): - return list(zip(response[0::2], response[1::2])) - - -def parse_client_kill(response, **options): - if isinstance(response, int): - return response - return str_if_bytes(response) == "OK" - - -def parse_acl_getuser(response, **options): - if response is None: - return None - data = pairs_to_dict(response, decode_keys=True) - - # convert everything but user-defined data in 'keys' to native strings - data["flags"] = list(map(str_if_bytes, data["flags"])) - data["passwords"] = list(map(str_if_bytes, data["passwords"])) - data["commands"] = str_if_bytes(data["commands"]) - if isinstance(data["keys"], str) or isinstance(data["keys"], bytes): - data["keys"] = list(str_if_bytes(data["keys"]).split(" ")) - if data["keys"] == [""]: - data["keys"] = [] - if "channels" in data: - if isinstance(data["channels"], str) or isinstance(data["channels"], bytes): - data["channels"] = list(str_if_bytes(data["channels"]).split(" ")) - if data["channels"] == [""]: - data["channels"] = [] - if "selectors" in data: - data["selectors"] = [ - list(map(str_if_bytes, selector)) for selector in data["selectors"] - ] - - # split 'commands' into separate 'categories' and 'commands' lists - commands, categories = [], [] - for command in data["commands"].split(" "): - if "@" in command: - categories.append(command) - else: - commands.append(command) - - data["commands"] = commands - data["categories"] = categories - data["enabled"] = "on" in data["flags"] - return data - - -def parse_acl_log(response, **options): - if response is None: - return None - if isinstance(response, list): - data = [] - for log in response: - log_data = pairs_to_dict(log, True, True) - client_info = log_data.get("client-info", "") - log_data["client-info"] = parse_client_info(client_info) - - # float() is lossy comparing to the "double" in C - log_data["age-seconds"] = float(log_data["age-seconds"]) - data.append(log_data) - else: - data = bool_ok(response) - return data - - -def parse_client_info(value): - """ - Parsing client-info in ACL Log in following format. 
- "key1=value1 key2=value2 key3=value3" - """ - client_info = {} - infos = str_if_bytes(value).split(" ") - for info in infos: - key, value = info.split("=") - client_info[key] = value - - # Those fields are defined as int in networking.c - for int_key in { - "id", - "age", - "idle", - "db", - "sub", - "psub", - "multi", - "qbuf", - "qbuf-free", - "obl", - "argv-mem", - "oll", - "omem", - "tot-mem", - }: - client_info[int_key] = int(client_info[int_key]) - return client_info - - -def parse_module_result(response): - if isinstance(response, ModuleError): - raise response - return True - - -def parse_set_result(response, **options): - """ - Handle SET result since GET argument is available since Redis 6.2. - Parsing SET result into: - - BOOL - - String when GET argument is used - """ - if options.get("get"): - # Redis will return a getCommand result. - # See `setGenericCommand` in t_string.c - return response - return response and str_if_bytes(response) == "OK" +class AbstractRedis: + pass -class AbstractRedis: - RESPONSE_CALLBACKS = { - **string_keys_to_dict( - "AUTH COPY EXPIRE EXPIREAT PEXPIRE PEXPIREAT " - "HEXISTS HMSET MOVE MSETNX PERSIST " - "PSETEX RENAMENX SISMEMBER SMOVE SETEX SETNX", - bool, - ), - **string_keys_to_dict( - "BITCOUNT BITPOS DECRBY DEL EXISTS GEOADD GETBIT HDEL HLEN " - "HSTRLEN INCRBY LINSERT LLEN LPUSHX PFADD PFCOUNT RPUSHX SADD " - "SCARD SDIFFSTORE SETBIT SETRANGE SINTERSTORE SREM STRLEN " - "SUNIONSTORE UNLINK XACK XDEL XLEN XTRIM ZCARD ZLEXCOUNT ZREM " - "ZREMRANGEBYLEX ZREMRANGEBYRANK ZREMRANGEBYSCORE", - int, - ), - **string_keys_to_dict("INCRBYFLOAT HINCRBYFLOAT", float), - **string_keys_to_dict( - # these return OK, or int if redis-server is >=1.3.4 - "LPUSH RPUSH", - lambda r: isinstance(r, int) and r or str_if_bytes(r) == "OK", - ), - **string_keys_to_dict("SORT", sort_return_tuples), - **string_keys_to_dict("ZSCORE ZINCRBY GEODIST", float_or_none), - **string_keys_to_dict( - "FLUSHALL FLUSHDB LSET LTRIM MSET PFMERGE ASKING READONLY READWRITE " - "RENAME SAVE SELECT SHUTDOWN SLAVEOF SWAPDB WATCH UNWATCH ", - bool_ok, - ), - **string_keys_to_dict("BLPOP BRPOP", lambda r: r and tuple(r) or None), - **string_keys_to_dict( - "SDIFF SINTER SMEMBERS SUNION", lambda r: r and set(r) or set() - ), - **string_keys_to_dict( - "ZPOPMAX ZPOPMIN ZINTER ZDIFF ZUNION ZRANGE ZRANGEBYSCORE " - "ZREVRANGE ZREVRANGEBYSCORE", - zset_score_pairs, - ), - **string_keys_to_dict( - "BZPOPMIN BZPOPMAX", lambda r: r and (r[0], r[1], float(r[2])) or None - ), - **string_keys_to_dict("ZRANK ZREVRANK", int_or_none), - **string_keys_to_dict("XREVRANGE XRANGE", parse_stream_list), - **string_keys_to_dict("XREAD XREADGROUP", parse_xread), - **string_keys_to_dict("BGREWRITEAOF BGSAVE", lambda r: True), - "ACL CAT": lambda r: list(map(str_if_bytes, r)), - "ACL DELUSER": int, - "ACL GENPASS": str_if_bytes, - "ACL GETUSER": parse_acl_getuser, - "ACL HELP": lambda r: list(map(str_if_bytes, r)), - "ACL LIST": lambda r: list(map(str_if_bytes, r)), - "ACL LOAD": bool_ok, - "ACL LOG": parse_acl_log, - "ACL SAVE": bool_ok, - "ACL SETUSER": bool_ok, - "ACL USERS": lambda r: list(map(str_if_bytes, r)), - "ACL WHOAMI": str_if_bytes, - "CLIENT GETNAME": str_if_bytes, - "CLIENT ID": int, - "CLIENT KILL": parse_client_kill, - "CLIENT LIST": parse_client_list, - "CLIENT INFO": parse_client_info, - "CLIENT SETNAME": bool_ok, - "CLIENT UNBLOCK": lambda r: r and int(r) == 1 or False, - "CLIENT PAUSE": bool_ok, - "CLIENT GETREDIR": int, - "CLIENT TRACKINGINFO": lambda r: list(map(str_if_bytes, r)), - 
"CLUSTER ADDSLOTS": bool_ok, - "CLUSTER ADDSLOTSRANGE": bool_ok, - "CLUSTER COUNT-FAILURE-REPORTS": lambda x: int(x), - "CLUSTER COUNTKEYSINSLOT": lambda x: int(x), - "CLUSTER DELSLOTS": bool_ok, - "CLUSTER DELSLOTSRANGE": bool_ok, - "CLUSTER FAILOVER": bool_ok, - "CLUSTER FORGET": bool_ok, - "CLUSTER GETKEYSINSLOT": lambda r: list(map(str_if_bytes, r)), - "CLUSTER INFO": parse_cluster_info, - "CLUSTER KEYSLOT": lambda x: int(x), - "CLUSTER MEET": bool_ok, - "CLUSTER NODES": parse_cluster_nodes, - "CLUSTER REPLICAS": parse_cluster_nodes, - "CLUSTER REPLICATE": bool_ok, - "CLUSTER RESET": bool_ok, - "CLUSTER SAVECONFIG": bool_ok, - "CLUSTER SET-CONFIG-EPOCH": bool_ok, - "CLUSTER SETSLOT": bool_ok, - "CLUSTER SLAVES": parse_cluster_nodes, - "COMMAND": parse_command, - "COMMAND COUNT": int, - "COMMAND GETKEYS": lambda r: list(map(str_if_bytes, r)), - "CONFIG GET": parse_config_get, - "CONFIG RESETSTAT": bool_ok, - "CONFIG SET": bool_ok, - "DEBUG OBJECT": parse_debug_object, - "FUNCTION DELETE": bool_ok, - "FUNCTION FLUSH": bool_ok, - "FUNCTION RESTORE": bool_ok, - "GEOHASH": lambda r: list(map(str_if_bytes, r)), - "GEOPOS": lambda r: list( - map(lambda ll: (float(ll[0]), float(ll[1])) if ll is not None else None, r) - ), - "GEOSEARCH": parse_geosearch_generic, - "GEORADIUS": parse_geosearch_generic, - "GEORADIUSBYMEMBER": parse_geosearch_generic, - "HGETALL": lambda r: r and pairs_to_dict(r) or {}, - "HSCAN": parse_hscan, - "INFO": parse_info, - "LASTSAVE": timestamp_to_datetime, - "MEMORY PURGE": bool_ok, - "MEMORY STATS": parse_memory_stats, - "MEMORY USAGE": int_or_none, - "MODULE LOAD": parse_module_result, - "MODULE UNLOAD": parse_module_result, - "MODULE LIST": lambda r: [pairs_to_dict(m) for m in r], - "OBJECT": parse_object, - "PING": lambda r: str_if_bytes(r) == "PONG", - "QUIT": bool_ok, - "STRALGO": parse_stralgo, - "PUBSUB NUMSUB": parse_pubsub_numsub, - "RANDOMKEY": lambda r: r and r or None, - "RESET": str_if_bytes, - "SCAN": parse_scan, - "SCRIPT EXISTS": lambda r: list(map(bool, r)), - "SCRIPT FLUSH": bool_ok, - "SCRIPT KILL": bool_ok, - "SCRIPT LOAD": str_if_bytes, - "SENTINEL CKQUORUM": bool_ok, - "SENTINEL FAILOVER": bool_ok, - "SENTINEL FLUSHCONFIG": bool_ok, - "SENTINEL GET-MASTER-ADDR-BY-NAME": parse_sentinel_get_master, - "SENTINEL MASTER": parse_sentinel_master, - "SENTINEL MASTERS": parse_sentinel_masters, - "SENTINEL MONITOR": bool_ok, - "SENTINEL RESET": bool_ok, - "SENTINEL REMOVE": bool_ok, - "SENTINEL SENTINELS": parse_sentinel_slaves_and_sentinels, - "SENTINEL SET": bool_ok, - "SENTINEL SLAVES": parse_sentinel_slaves_and_sentinels, - "SET": parse_set_result, - "SLOWLOG GET": parse_slowlog_get, - "SLOWLOG LEN": int, - "SLOWLOG RESET": bool_ok, - "SSCAN": parse_scan, - "TIME": lambda x: (int(x[0]), int(x[1])), - "XCLAIM": parse_xclaim, - "XAUTOCLAIM": parse_xautoclaim, - "XGROUP CREATE": bool_ok, - "XGROUP DELCONSUMER": int, - "XGROUP DESTROY": bool, - "XGROUP SETID": bool_ok, - "XINFO CONSUMERS": parse_list_of_dicts, - "XINFO GROUPS": parse_list_of_dicts, - "XINFO STREAM": parse_xinfo_stream, - "XPENDING": parse_xpending, - "ZADD": parse_zadd, - "ZSCAN": parse_zscan, - "ZMSCORE": parse_zmscore, - } - - -class Redis(AbstractRedis, RedisModuleCommands, CoreCommands, SentinelCommands): +class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): """ Implementation of the Redis protocol. @@ -899,8 +134,12 @@ class initializer. In the case of conflicting arguments, querystring arguments always win. 
""" + single_connection_client = kwargs.pop("single_connection_client", False) connection_pool = ConnectionPool.from_url(url, **kwargs) - return cls(connection_pool=connection_pool) + return cls( + connection_pool=connection_pool, + single_connection_client=single_connection_client, + ) def __init__( self, @@ -938,10 +177,13 @@ def __init__( single_connection_client=False, health_check_interval=0, client_name=None, + lib_name="redis-py", + lib_version=get_lib_version(), username=None, retry=None, redis_connect_func=None, credential_provider: Optional[CredentialProvider] = None, + protocol: Optional[int] = 2, ): """ Initialize a new Redis client. @@ -988,8 +230,11 @@ def __init__( "max_connections": max_connections, "health_check_interval": health_check_interval, "client_name": client_name, + "lib_name": lib_name, + "lib_version": lib_version, "redis_connect_func": redis_connect_func, "credential_provider": credential_provider, + "protocol": protocol, } # based on input, setup appropriate connection args if unix_socket_path is not None: @@ -1035,7 +280,12 @@ def __init__( if single_connection_client: self.connection = self.connection_pool.get_connection("_") - self.response_callbacks = CaseInsensitiveDict(self.__class__.RESPONSE_CALLBACKS) + self.response_callbacks = CaseInsensitiveDict(_RedisCallbacks) + + if self.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: + self.response_callbacks.update(_RedisCallbacksRESP3) + else: + self.response_callbacks.update(_RedisCallbacksRESP2) def __repr__(self): return f"{type(self).__name__}<{repr(self.connection_pool)}>" @@ -1365,8 +615,8 @@ class PubSub: will be returned and it's safe to start listening again. """ - PUBLISH_MESSAGE_TYPES = ("message", "pmessage") - UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe") + PUBLISH_MESSAGE_TYPES = ("message", "pmessage", "smessage") + UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe", "sunsubscribe") HEALTH_CHECK_MESSAGE = "redis-py-health-check" def __init__( @@ -1375,6 +625,7 @@ def __init__( shard_hint=None, ignore_subscribe_messages=False, encoder=None, + push_handler_func=None, ): self.connection_pool = connection_pool self.shard_hint = shard_hint @@ -1384,6 +635,7 @@ def __init__( # we need to know the encoding options for this connection in order # to lookup channel and pattern names for callback handlers. self.encoder = encoder + self.push_handler_func = push_handler_func if self.encoder is None: self.encoder = self.connection_pool.get_encoder() self.health_check_response_b = self.encoder.encode(self.HEALTH_CHECK_MESSAGE) @@ -1391,6 +643,8 @@ def __init__( self.health_check_response = ["pong", self.HEALTH_CHECK_MESSAGE] else: self.health_check_response = [b"pong", self.health_check_response_b] + if self.push_handler_func is None: + _set_info_logger() self.reset() def __enter__(self): @@ -1414,9 +668,11 @@ def reset(self): self.connection.clear_connect_callbacks() self.connection_pool.release(self.connection) self.connection = None - self.channels = {} self.health_check_response_counter = 0 + self.channels = {} self.pending_unsubscribe_channels = set() + self.shard_channels = {} + self.pending_unsubscribe_shard_channels = set() self.patterns = {} self.pending_unsubscribe_patterns = set() self.subscribed_event.clear() @@ -1431,16 +687,23 @@ def on_connect(self, connection): # before passing them to [p]subscribe. 
self.pending_unsubscribe_channels.clear() self.pending_unsubscribe_patterns.clear() + self.pending_unsubscribe_shard_channels.clear() if self.channels: - channels = {} - for k, v in self.channels.items(): - channels[self.encoder.decode(k, force=True)] = v + channels = { + self.encoder.decode(k, force=True): v for k, v in self.channels.items() + } self.subscribe(**channels) if self.patterns: - patterns = {} - for k, v in self.patterns.items(): - patterns[self.encoder.decode(k, force=True)] = v + patterns = { + self.encoder.decode(k, force=True): v for k, v in self.patterns.items() + } self.psubscribe(**patterns) + if self.shard_channels: + shard_channels = { + self.encoder.decode(k, force=True): v + for k, v in self.shard_channels.items() + } + self.ssubscribe(**shard_channels) @property def subscribed(self): @@ -1461,6 +724,8 @@ def execute_command(self, *args): # register a callback that re-subscribes to any channels we # were listening to when we were disconnected self.connection.register_connect_callback(self.on_connect) + if self.push_handler_func is not None and not HIREDIS_AVAILABLE: + self.connection._parser.set_push_handler(self.push_handler_func) connection = self.connection kwargs = {"check_health": not self.subscribed} if not self.subscribed: @@ -1526,7 +791,7 @@ def try_read(): return None else: conn.connect() - return conn.read_response() + return conn.read_response(disconnect_on_error=False, push_request=True) response = self._execute(conn, try_read) @@ -1647,6 +912,45 @@ def unsubscribe(self, *args): self.pending_unsubscribe_channels.update(channels) return self.execute_command("UNSUBSCRIBE", *args) + def ssubscribe(self, *args, target_node=None, **kwargs): + """ + Subscribes the client to the specified shard channels. + Channels supplied as keyword arguments expect a channel name as the key + and a callable as the value. A channel's callable will be invoked automatically + when a message is received on that channel rather than producing a message via + ``listen()`` or ``get_sharded_message()``. + """ + if args: + args = list_or_args(args[0], args[1:]) + new_s_channels = dict.fromkeys(args) + new_s_channels.update(kwargs) + ret_val = self.execute_command("SSUBSCRIBE", *new_s_channels.keys()) + # update the s_channels dict AFTER we send the command. we don't want to + # subscribe twice to these channels, once for the command and again + # for the reconnection. + new_s_channels = self._normalize_keys(new_s_channels) + self.shard_channels.update(new_s_channels) + if not self.subscribed: + # Set the subscribed_event flag to True + self.subscribed_event.set() + # Clear the health check counter + self.health_check_response_counter = 0 + self.pending_unsubscribe_shard_channels.difference_update(new_s_channels) + return ret_val + + def sunsubscribe(self, *args, target_node=None): + """ + Unsubscribe from the supplied shard_channels. 
If empty, unsubscribe from + all shard_channels + """ + if args: + args = list_or_args(args[0], args[1:]) + s_channels = self._normalize_keys(dict.fromkeys(args)) + else: + s_channels = self.shard_channels + self.pending_unsubscribe_shard_channels.update(s_channels) + return self.execute_command("SUNSUBSCRIBE", *args) + def listen(self): "Listen for messages on channels this client has been subscribed to" while self.subscribed: @@ -1681,12 +985,14 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0.0): return self.handle_message(response, ignore_subscribe_messages) return None + get_sharded_message = get_message + def ping(self, message=None): """ Ping the Redis server """ - message = "" if message is None else message - return self.execute_command("PING", message) + args = ["PING", message] if message is not None else ["PING"] + return self.execute_command(*args) def handle_message(self, response, ignore_subscribe_messages=False): """ @@ -1696,6 +1002,8 @@ def handle_message(self, response, ignore_subscribe_messages=False): """ if response is None: return None + if isinstance(response, bytes): + response = [b"pong", response] if response != b"PONG" else [b"pong", b""] message_type = str_if_bytes(response[0]) if message_type == "pmessage": message = { @@ -1726,12 +1034,17 @@ def handle_message(self, response, ignore_subscribe_messages=False): if pattern in self.pending_unsubscribe_patterns: self.pending_unsubscribe_patterns.remove(pattern) self.patterns.pop(pattern, None) + elif message_type == "sunsubscribe": + s_channel = response[1] + if s_channel in self.pending_unsubscribe_shard_channels: + self.pending_unsubscribe_shard_channels.remove(s_channel) + self.shard_channels.pop(s_channel, None) else: channel = response[1] if channel in self.pending_unsubscribe_channels: self.pending_unsubscribe_channels.remove(channel) self.channels.pop(channel, None) - if not self.channels and not self.patterns: + if not self.channels and not self.patterns and not self.shard_channels: # There are no subscriptions anymore, set subscribed_event flag # to false self.subscribed_event.clear() @@ -1740,6 +1053,8 @@ def handle_message(self, response, ignore_subscribe_messages=False): # if there's a message handler, invoke it if message_type == "pmessage": handler = self.patterns.get(message["pattern"], None) + elif message_type == "smessage": + handler = self.shard_channels.get(message["channel"], None) else: handler = self.channels.get(message["channel"], None) if handler: @@ -1760,6 +1075,11 @@ def run_in_thread(self, sleep_time=0, daemon=False, exception_handler=None): for pattern, handler in self.patterns.items(): if handler is None: raise PubSubError(f"Pattern: '{pattern}' has no handler registered") + for s_channel, handler in self.shard_channels.items(): + if handler is None: + raise PubSubError( + f"Shard Channel: '{s_channel}' has no handler registered" + ) thread = PubSubWorkerThread( self, sleep_time, daemon=daemon, exception_handler=exception_handler @@ -2069,7 +1389,7 @@ def load_scripts(self): def _disconnect_raise_reset(self, conn, error): """ Close the connection, raise an exception if we were watching, - and raise an exception if retry_on_timeout is not set, + and raise an exception if TimeoutError is not part of retry_on_error, or the error is not a TimeoutError """ conn.disconnect() @@ -2080,11 +1400,13 @@ def _disconnect_raise_reset(self, conn, error): raise WatchError( "A ConnectionError occurred on while watching one or more keys" ) - # if retry_on_timeout is not set, or the 
error is not - # a TimeoutError, raise it - if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): + # if TimeoutError is not part of retry_on_error, or the error + # is not a TimeoutError, raise it + if not ( + TimeoutError in conn.retry_on_error and isinstance(error, TimeoutError) + ): self.reset() - raise + raise error def execute(self, raise_on_error=True): """Execute all the commands in the current pipeline""" diff --git a/redis/cluster.py b/redis/cluster.py index 5f730a8596..1ffa5ff547 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -6,10 +6,13 @@ from collections import OrderedDict from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from redis._parsers import CommandsParser, Encoder +from redis._parsers.helpers import parse_scan from redis.backoff import default_backoff -from redis.client import CaseInsensitiveDict, PubSub, Redis, parse_scan -from redis.commands import READ_COMMANDS, CommandsParser, RedisClusterCommands -from redis.connection import ConnectionPool, DefaultParser, Encoder, parse_url +from redis.client import CaseInsensitiveDict, PubSub, Redis +from redis.commands import READ_COMMANDS, RedisClusterCommands +from redis.commands.helpers import list_or_args +from redis.connection import ConnectionPool, DefaultParser, parse_url from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot from redis.exceptions import ( AskError, @@ -31,6 +34,7 @@ from redis.lock import Lock from redis.retry import Retry from redis.utils import ( + HIREDIS_AVAILABLE, dict_merge, list_keys_to_dict, merge_result, @@ -97,6 +101,8 @@ def parse_cluster_shards(resp, **options): """ Parse CLUSTER SHARDS response. """ + if isinstance(resp[0], dict): + return resp shards = [] for x in resp: shard = {"slots": [], "nodes": []} @@ -113,6 +119,13 @@ def parse_cluster_shards(resp, **options): return shards +def parse_cluster_myshardid(resp, **options): + """ + Parse CLUSTER MYSHARDID response. 
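The pipeline change in the client.py hunk above keys retrying off ``retry_on_error`` instead of the removed ``retry_on_timeout`` check; a minimal opt-in sketch (host and port are placeholders):

    from redis import Redis
    from redis.exceptions import TimeoutError

    # With TimeoutError listed in retry_on_error, a timed-out pipeline
    # execution is retried rather than raised immediately.
    r = Redis(host="localhost", port=6379, retry_on_error=[TimeoutError])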
+ """ + return resp.decode("utf-8") + + PRIMARY = "primary" REPLICA = "replica" SLOT_ID = "slot-id" @@ -130,13 +143,17 @@ def parse_cluster_shards(resp, **options): "encoding_errors", "errors", "host", + "lib_name", + "lib_version", "max_connections", "nodes_flag", "redis_connect_func", "password", "port", + "queue_class", "retry", "retry_on_timeout", + "protocol", "socket_connect_timeout", "socket_keepalive", "socket_keepalive_options", @@ -210,6 +227,7 @@ class AbstractRedisCluster: "ACL WHOAMI", "AUTH", "CLIENT LIST", + "CLIENT SETINFO", "CLIENT SETNAME", "CLIENT GETNAME", "CONFIG SET", @@ -219,6 +237,8 @@ class AbstractRedisCluster: "PUBSUB CHANNELS", "PUBSUB NUMPAT", "PUBSUB NUMSUB", + "PUBSUB SHARDCHANNELS", + "PUBSUB SHARDNUMSUB", "PING", "INFO", "SHUTDOWN", @@ -229,6 +249,7 @@ class AbstractRedisCluster: "SLOWLOG LEN", "SLOWLOG RESET", "WAIT", + "WAITAOF", "SAVE", "MEMORY PURGE", "MEMORY MALLOC-STATS", @@ -244,7 +265,6 @@ class AbstractRedisCluster: "CLIENT INFO", "CLIENT KILL", "READONLY", - "READWRITE", "CLUSTER INFO", "CLUSTER MEET", "CLUSTER NODES", @@ -265,6 +285,11 @@ class AbstractRedisCluster: "READONLY", "READWRITE", "TIME", + "TFUNCTION LOAD", + "TFUNCTION DELETE", + "TFUNCTION LIST", + "TFCALL", + "TFCALLASYNC", "GRAPH.CONFIG", "LATENCY HISTORY", "LATENCY LATEST", @@ -281,6 +306,7 @@ class AbstractRedisCluster: "FUNCTION LIST", "FUNCTION LOAD", "FUNCTION RESTORE", + "REDISGEARS_2.REFRESHCLUSTER", "SCAN", "SCRIPT EXISTS", "SCRIPT FLUSH", @@ -340,14 +366,17 @@ class AbstractRedisCluster: CLUSTER_COMMANDS_RESPONSE_CALLBACKS = { "CLUSTER SLOTS": parse_cluster_slots, "CLUSTER SHARDS": parse_cluster_shards, + "CLUSTER MYSHARDID": parse_cluster_myshardid, } RESULT_CALLBACKS = dict_merge( - list_keys_to_dict(["PUBSUB NUMSUB"], parse_pubsub_numsub), + list_keys_to_dict(["PUBSUB NUMSUB", "PUBSUB SHARDNUMSUB"], parse_pubsub_numsub), list_keys_to_dict( ["PUBSUB NUMPAT"], lambda command, res: sum(list(res.values())) ), - list_keys_to_dict(["KEYS", "PUBSUB CHANNELS"], merge_result), + list_keys_to_dict( + ["KEYS", "PUBSUB CHANNELS", "PUBSUB SHARDCHANNELS"], merge_result + ), list_keys_to_dict( [ "PING", @@ -465,6 +494,7 @@ def __init__( read_from_replicas: bool = False, dynamic_startup_nodes: bool = True, url: Optional[str] = None, + address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, **kwargs, ): """ @@ -513,6 +543,12 @@ def __init__( reinitialize_steps to 1. To avoid reinitializing the cluster on moved errors, set reinitialize_steps to 0. + :param address_remap: + An optional callable which, when provided with an internal network + address of a node, e.g. a `(host, port)` tuple, will return the address + where the node is reachable. This can be used to map the addresses at + which the nodes _think_ they are, to addresses at which a client may + reach them, such as when they sit behind a proxy. 
:**kwargs: Extra arguments that will be sent into Redis instance when created @@ -588,12 +624,12 @@ def __init__( self.read_from_replicas = read_from_replicas self.reinitialize_counter = 0 self.reinitialize_steps = reinitialize_steps - self.nodes_manager = None self.nodes_manager = NodesManager( startup_nodes=startup_nodes, from_url=from_url, require_full_coverage=require_full_coverage, dynamic_startup_nodes=dynamic_startup_nodes, + address_remap=address_remap, **kwargs, ) @@ -1269,6 +1305,7 @@ def __init__( lock=None, dynamic_startup_nodes=True, connection_pool_class=ConnectionPool, + address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, **kwargs, ): self.nodes_cache = {} @@ -1280,6 +1317,7 @@ def __init__( self._require_full_coverage = require_full_coverage self._dynamic_startup_nodes = dynamic_startup_nodes self.connection_pool_class = connection_pool_class + self.address_remap = address_remap self._moved_exception = None self.connection_kwargs = kwargs self.read_load_balancer = LoadBalancer() @@ -1502,6 +1540,7 @@ def initialize(self): if host == "": host = startup_node.host port = int(primary_node[1]) + host, port = self.remap_host_port(host, port) target_node = self._get_or_create_cluster_node( host, port, PRIMARY, tmp_nodes_cache @@ -1518,6 +1557,7 @@ def initialize(self): for replica_node in replica_nodes: host = str_if_bytes(replica_node[0]) port = replica_node[1] + host, port = self.remap_host_port(host, port) target_replica_node = self._get_or_create_cluster_node( host, port, REPLICA, tmp_nodes_cache @@ -1591,6 +1631,16 @@ def reset(self): # The read_load_balancer is None, do nothing pass + def remap_host_port(self, host: str, port: int) -> Tuple[str, int]: + """ + Remap the host and port returned from the cluster to a different + internal value. Useful if the client is not connecting directly + to the cluster. + """ + if self.address_remap: + return self.address_remap((host, port)) + return host, port + class ClusterPubSub(PubSub): """ @@ -1601,7 +1651,15 @@ class ClusterPubSub(PubSub): https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html """ - def __init__(self, redis_cluster, node=None, host=None, port=None, **kwargs): + def __init__( + self, + redis_cluster, + node=None, + host=None, + port=None, + push_handler_func=None, + **kwargs, + ): """ When a pubsub instance is created without specifying a node, a single node will be transparently chosen for the pubsub connection on the @@ -1623,8 +1681,13 @@ def __init__(self, redis_cluster, node=None, host=None, port=None, **kwargs): else redis_cluster.get_redis_connection(self.node).connection_pool ) self.cluster = redis_cluster + self.node_pubsub_mapping = {} + self._pubsubs_generator = self._pubsubs_generator() super().__init__( - **kwargs, connection_pool=connection_pool, encoder=redis_cluster.encoder + connection_pool=connection_pool, + encoder=redis_cluster.encoder, + push_handler_func=push_handler_func, + **kwargs, ) def set_pubsub_node(self, cluster, node=None, host=None, port=None): @@ -1676,9 +1739,9 @@ def _raise_on_invalid_node(self, redis_cluster, node, host, port): f"Node {host}:{port} doesn't exist in the cluster" ) - def execute_command(self, *args, **kwargs): + def execute_command(self, *args): """ - Execute a publish/subscribe command. + Execute a subscribe/unsubscribe command. Taken code from redis-py and tweak to make it work within a cluster. 
""" @@ -1708,9 +1771,94 @@ def execute_command(self, *args, **kwargs): # register a callback that re-subscribes to any channels we # were listening to when we were disconnected self.connection.register_connect_callback(self.on_connect) + if self.push_handler_func is not None and not HIREDIS_AVAILABLE: + self.connection._parser.set_push_handler(self.push_handler_func) connection = self.connection self._execute(connection, connection.send_command, *args) + def _get_node_pubsub(self, node): + try: + return self.node_pubsub_mapping[node.name] + except KeyError: + pubsub = node.redis_connection.pubsub( + push_handler_func=self.push_handler_func + ) + self.node_pubsub_mapping[node.name] = pubsub + return pubsub + + def _sharded_message_generator(self): + for _ in range(len(self.node_pubsub_mapping)): + pubsub = next(self._pubsubs_generator) + message = pubsub.get_message() + if message is not None: + return message + return None + + def _pubsubs_generator(self): + while True: + for pubsub in self.node_pubsub_mapping.values(): + yield pubsub + + def get_sharded_message( + self, ignore_subscribe_messages=False, timeout=0.0, target_node=None + ): + if target_node: + message = self.node_pubsub_mapping[target_node.name].get_message( + ignore_subscribe_messages=ignore_subscribe_messages, timeout=timeout + ) + else: + message = self._sharded_message_generator() + if message is None: + return None + elif str_if_bytes(message["type"]) == "sunsubscribe": + if message["channel"] in self.pending_unsubscribe_shard_channels: + self.pending_unsubscribe_shard_channels.remove(message["channel"]) + self.shard_channels.pop(message["channel"], None) + node = self.cluster.get_node_from_key(message["channel"]) + if self.node_pubsub_mapping[node.name].subscribed is False: + self.node_pubsub_mapping.pop(node.name) + if not self.channels and not self.patterns and not self.shard_channels: + # There are no subscriptions anymore, set subscribed_event flag + # to false + self.subscribed_event.clear() + if self.ignore_subscribe_messages or ignore_subscribe_messages: + return None + return message + + def ssubscribe(self, *args, **kwargs): + if args: + args = list_or_args(args[0], args[1:]) + s_channels = dict.fromkeys(args) + s_channels.update(kwargs) + for s_channel, handler in s_channels.items(): + node = self.cluster.get_node_from_key(s_channel) + pubsub = self._get_node_pubsub(node) + if handler: + pubsub.ssubscribe(**{s_channel: handler}) + else: + pubsub.ssubscribe(s_channel) + self.shard_channels.update(pubsub.shard_channels) + self.pending_unsubscribe_shard_channels.difference_update( + self._normalize_keys({s_channel: None}) + ) + if pubsub.subscribed and not self.subscribed: + self.subscribed_event.set() + self.health_check_response_counter = 0 + + def sunsubscribe(self, *args): + if args: + args = list_or_args(args[0], args[1:]) + else: + args = self.shard_channels + + for s_channel in args: + node = self.cluster.get_node_from_key(s_channel) + p = self._get_node_pubsub(node) + p.sunsubscribe(s_channel) + self.pending_unsubscribe_shard_channels.update( + p.pending_unsubscribe_shard_channels + ) + def get_redis_connection(self): """ Get the Redis connection of the pubsub connected node. @@ -1718,6 +1866,15 @@ def get_redis_connection(self): if self.node is not None: return self.node.redis_connection + def disconnect(self): + """ + Disconnect the pubsub connection. 
+ """ + if self.connection: + self.connection.disconnect() + for pubsub in self.node_pubsub_mapping.values(): + pubsub.connection.disconnect() + class ClusterPipeline(RedisCluster): """ @@ -2136,6 +2293,17 @@ def delete(self, *names): return self.execute_command("DEL", names[0]) + def unlink(self, *names): + """ + "Unlink a key specified by ``names``" + """ + if len(names) != 1: + raise RedisClusterException( + "unlinking multiple keys is not implemented in pipeline command" + ) + + return self.execute_command("UNLINK", names[0]) + def block_pipeline_command(name: str) -> Callable[..., Any]: """ diff --git a/redis/commands/__init__.py b/redis/commands/__init__.py index f3f08286c8..a94d9764a6 100644 --- a/redis/commands/__init__.py +++ b/redis/commands/__init__.py @@ -1,7 +1,6 @@ from .cluster import READ_COMMANDS, AsyncRedisClusterCommands, RedisClusterCommands from .core import AsyncCoreCommands, CoreCommands from .helpers import list_or_args -from .parser import CommandsParser from .redismodules import AsyncRedisModuleCommands, RedisModuleCommands from .sentinel import AsyncSentinelCommands, SentinelCommands @@ -10,7 +9,6 @@ "AsyncRedisClusterCommands", "AsyncRedisModuleCommands", "AsyncSentinelCommands", - "CommandsParser", "CoreCommands", "READ_COMMANDS", "RedisClusterCommands", diff --git a/redis/commands/bf/__init__.py b/redis/commands/bf/__init__.py index 4da060e995..959358f8e8 100644 --- a/redis/commands/bf/__init__.py +++ b/redis/commands/bf/__init__.py @@ -1,6 +1,6 @@ -from redis.client import bool_ok +from redis._parsers.helpers import bool_ok -from ..helpers import parse_to_list +from ..helpers import get_protocol_version, parse_to_list from .commands import * # noqa from .info import BFInfo, CFInfo, CMSInfo, TDigestInfo, TopKInfo @@ -91,20 +91,29 @@ class CMSBloom(CMSCommands, AbstractBloom): def __init__(self, client, **kwargs): """Create a new RedisBloom client.""" # Set the module commands' callbacks - MODULE_CALLBACKS = { + _MODULE_CALLBACKS = { CMS_INITBYDIM: bool_ok, CMS_INITBYPROB: bool_ok, # CMS_INCRBY: spaceHolder, # CMS_QUERY: spaceHolder, CMS_MERGE: bool_ok, + } + + _RESP2_MODULE_CALLBACKS = { CMS_INFO: CMSInfo, } + _RESP3_MODULE_CALLBACKS = {} self.client = client self.commandmixin = CMSCommands self.execute_command = client.execute_command - for k, v in MODULE_CALLBACKS.items(): + if get_protocol_version(self.client) in ["3", 3]: + _MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS) + else: + _MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS) + + for k, v in _MODULE_CALLBACKS.items(): self.client.set_response_callback(k, v) @@ -112,21 +121,30 @@ class TOPKBloom(TOPKCommands, AbstractBloom): def __init__(self, client, **kwargs): """Create a new RedisBloom client.""" # Set the module commands' callbacks - MODULE_CALLBACKS = { + _MODULE_CALLBACKS = { TOPK_RESERVE: bool_ok, - TOPK_ADD: parse_to_list, - TOPK_INCRBY: parse_to_list, # TOPK_QUERY: spaceHolder, # TOPK_COUNT: spaceHolder, - TOPK_LIST: parse_to_list, + } + + _RESP2_MODULE_CALLBACKS = { + TOPK_ADD: parse_to_list, + TOPK_INCRBY: parse_to_list, TOPK_INFO: TopKInfo, + TOPK_LIST: parse_to_list, } + _RESP3_MODULE_CALLBACKS = {} self.client = client self.commandmixin = TOPKCommands self.execute_command = client.execute_command - for k, v in MODULE_CALLBACKS.items(): + if get_protocol_version(self.client) in ["3", 3]: + _MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS) + else: + _MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS) + + for k, v in _MODULE_CALLBACKS.items(): self.client.set_response_callback(k, v) @@ -134,7 
+152,7 @@ class CFBloom(CFCommands, AbstractBloom): def __init__(self, client, **kwargs): """Create a new RedisBloom client.""" # Set the module commands' callbacks - MODULE_CALLBACKS = { + _MODULE_CALLBACKS = { CF_RESERVE: bool_ok, # CF_ADD: spaceHolder, # CF_ADDNX: spaceHolder, @@ -145,14 +163,23 @@ def __init__(self, client, **kwargs): # CF_COUNT: spaceHolder, # CF_SCANDUMP: spaceHolder, # CF_LOADCHUNK: spaceHolder, + } + + _RESP2_MODULE_CALLBACKS = { CF_INFO: CFInfo, } + _RESP3_MODULE_CALLBACKS = {} self.client = client self.commandmixin = CFCommands self.execute_command = client.execute_command - for k, v in MODULE_CALLBACKS.items(): + if get_protocol_version(self.client) in ["3", 3]: + _MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS) + else: + _MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS) + + for k, v in _MODULE_CALLBACKS.items(): self.client.set_response_callback(k, v) @@ -160,28 +187,35 @@ class TDigestBloom(TDigestCommands, AbstractBloom): def __init__(self, client, **kwargs): """Create a new RedisBloom client.""" # Set the module commands' callbacks - MODULE_CALLBACKS = { + _MODULE_CALLBACKS = { TDIGEST_CREATE: bool_ok, # TDIGEST_RESET: bool_ok, # TDIGEST_ADD: spaceHolder, # TDIGEST_MERGE: spaceHolder, + } + + _RESP2_MODULE_CALLBACKS = { + TDIGEST_BYRANK: parse_to_list, + TDIGEST_BYREVRANK: parse_to_list, TDIGEST_CDF: parse_to_list, - TDIGEST_QUANTILE: parse_to_list, + TDIGEST_INFO: TDigestInfo, TDIGEST_MIN: float, TDIGEST_MAX: float, TDIGEST_TRIMMED_MEAN: float, - TDIGEST_INFO: TDigestInfo, - TDIGEST_RANK: parse_to_list, - TDIGEST_REVRANK: parse_to_list, - TDIGEST_BYRANK: parse_to_list, - TDIGEST_BYREVRANK: parse_to_list, + TDIGEST_QUANTILE: parse_to_list, } + _RESP3_MODULE_CALLBACKS = {} self.client = client self.commandmixin = TDigestCommands self.execute_command = client.execute_command - for k, v in MODULE_CALLBACKS.items(): + if get_protocol_version(self.client) in ["3", 3]: + _MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS) + else: + _MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS) + + for k, v in _MODULE_CALLBACKS.items(): self.client.set_response_callback(k, v) @@ -189,7 +223,7 @@ class BFBloom(BFCommands, AbstractBloom): def __init__(self, client, **kwargs): """Create a new RedisBloom client.""" # Set the module commands' callbacks - MODULE_CALLBACKS = { + _MODULE_CALLBACKS = { BF_RESERVE: bool_ok, # BF_ADD: spaceHolder, # BF_MADD: spaceHolder, @@ -199,12 +233,21 @@ def __init__(self, client, **kwargs): # BF_SCANDUMP: spaceHolder, # BF_LOADCHUNK: spaceHolder, # BF_CARD: spaceHolder, + } + + _RESP2_MODULE_CALLBACKS = { BF_INFO: BFInfo, } + _RESP3_MODULE_CALLBACKS = {} self.client = client self.commandmixin = BFCommands self.execute_command = client.execute_command - for k, v in MODULE_CALLBACKS.items(): + if get_protocol_version(self.client) in ["3", 3]: + _MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS) + else: + _MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS) + + for k, v in _MODULE_CALLBACKS.items(): self.client.set_response_callback(k, v) diff --git a/redis/commands/bf/commands.py b/redis/commands/bf/commands.py index c45523c99b..447f844508 100644 --- a/redis/commands/bf/commands.py +++ b/redis/commands/bf/commands.py @@ -60,7 +60,6 @@ class BFCommands: """Bloom Filter commands.""" - # region Bloom Filter Functions def create(self, key, errorRate, capacity, expansion=None, noScale=None): """ Create a new Bloom Filter `key` with desired probability of false positives @@ -178,7 +177,6 @@ def card(self, key): class CFCommands: """Cuckoo Filter commands.""" - # 
region Cuckoo Filter Functions def create( self, key, capacity, expansion=None, bucket_size=None, max_iterations=None ): @@ -488,7 +486,6 @@ def byrevrank(self, key, rank, *ranks): class CMSCommands: """Count-Min Sketch Commands""" - # region Count-Min Sketch Functions def initbydim(self, key, width, depth): """ Initialize a Count-Min Sketch `key` to dimensions (`width`, `depth`) specified by user. diff --git a/redis/commands/bf/info.py b/redis/commands/bf/info.py index c526e6ca4c..e1f0208609 100644 --- a/redis/commands/bf/info.py +++ b/redis/commands/bf/info.py @@ -16,6 +16,15 @@ def __init__(self, args): self.insertedNum = response["Number of items inserted"] self.expansionRate = response["Expansion rate"] + def get(self, item): + try: + return self.__getitem__(item) + except AttributeError: + return None + + def __getitem__(self, item): + return getattr(self, item) + class CFInfo(object): size = None @@ -38,6 +47,15 @@ def __init__(self, args): self.expansionRate = response["Expansion rate"] self.maxIteration = response["Max iterations"] + def get(self, item): + try: + return self.__getitem__(item) + except AttributeError: + return None + + def __getitem__(self, item): + return getattr(self, item) + class CMSInfo(object): width = None @@ -50,6 +68,9 @@ def __init__(self, args): self.depth = response["depth"] self.count = response["count"] + def __getitem__(self, item): + return getattr(self, item) + class TopKInfo(object): k = None @@ -64,6 +85,9 @@ def __init__(self, args): self.depth = response["depth"] self.decay = response["decay"] + def __getitem__(self, item): + return getattr(self, item) + class TDigestInfo(object): compression = None @@ -85,3 +109,12 @@ def __init__(self, args): self.unmerged_weight = response["Unmerged weight"] self.total_compressions = response["Total compressions"] self.memory_usage = response["Memory usage"] + + def get(self, item): + try: + return self.__getitem__(item) + except AttributeError: + return None + + def __getitem__(self, item): + return getattr(self, item) diff --git a/redis/commands/cluster.py b/redis/commands/cluster.py index a23a94a3d3..691cab3def 100644 --- a/redis/commands/cluster.py +++ b/redis/commands/cluster.py @@ -30,10 +30,12 @@ AsyncACLCommands, AsyncDataAccessCommands, AsyncFunctionCommands, + AsyncGearsCommands, AsyncManagementCommands, AsyncScriptCommands, DataAccessCommands, FunctionCommands, + GearsCommands, ManagementCommands, PubSubCommands, ResponseT, @@ -45,7 +47,6 @@ if TYPE_CHECKING: from redis.asyncio.cluster import TargetNodesT - # Not complete, but covers the major ones # https://redis.io/commands READ_COMMANDS = frozenset( @@ -634,6 +635,14 @@ def cluster_shards(self, target_nodes=None): """ return self.execute_command("CLUSTER SHARDS", target_nodes=target_nodes) + def cluster_myshardid(self, target_nodes=None): + """ + Returns the shard ID of the node. + + For more information see https://redis.io/commands/cluster-myshardid/ + """ + return self.execute_command("CLUSTER MYSHARDID", target_nodes=target_nodes) + def cluster_links(self, target_node: "TargetNodesT") -> ResponseT: """ Each node in a Redis Cluster maintains a pair of long-lived TCP link with each @@ -682,6 +691,12 @@ def readwrite(self, target_nodes: Optional["TargetNodesT"] = None) -> ResponseT: self.read_from_replicas = False return self.execute_command("READWRITE", target_nodes=target_nodes) + def gears_refresh_cluster(self, **kwargs) -> ResponseT: + """ + On an OSS cluster, before executing any gears function, you must call this command. 
# noqa + """ + return self.execute_command("REDISGEARS_2.REFRESHCLUSTER", **kwargs) + class AsyncClusterManagementCommands( ClusterManagementCommands, AsyncManagementCommands @@ -857,6 +872,7 @@ class RedisClusterCommands( ClusterDataAccessCommands, ScriptCommands, FunctionCommands, + GearsCommands, RedisModuleCommands, ): """ @@ -886,6 +902,7 @@ class AsyncRedisClusterCommands( AsyncClusterDataAccessCommands, AsyncScriptCommands, AsyncFunctionCommands, + AsyncGearsCommands, ): """ A class for all Redis Cluster commands diff --git a/redis/commands/core.py b/redis/commands/core.py index 3278c571f9..031781d75d 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -517,6 +517,7 @@ def client_list( """ Returns a list of currently connected clients. If type of client specified, only that type will be returned. + :param _type: optional. one of the client types (normal, master, replica, pubsub) :param client_id: optional. a list of client ids @@ -559,16 +560,17 @@ def client_reply( ) -> ResponseT: """ Enable and disable redis server replies. + ``reply`` Must be ON OFF or SKIP, - ON - The default most with server replies to commands - OFF - Disable server responses to commands - SKIP - Skip the response of the immediately following command. + ON - The default mode with server replies to commands + OFF - Disable server responses to commands + SKIP - Skip the response of the immediately following command. Note: When setting OFF or SKIP replies, you will need a client object with a timeout specified in seconds, and will need to catch the TimeoutError. - The test_client_reply unit test illustrates this, and - conftest.py has a client with a timeout. + The test_client_reply unit test illustrates this, and + conftest.py has a client with a timeout. See https://redis.io/commands/client-reply """ @@ -706,6 +708,13 @@ def client_setname(self, name: str, **kwargs) -> ResponseT: """ return self.execute_command("CLIENT SETNAME", name, **kwargs) + def client_setinfo(self, attr: str, value: str, **kwargs) -> ResponseT: + """ + Sets the current connection library name or version. + For more information see https://redis.io/commands/client-setinfo + """ + return self.execute_command("CLIENT SETINFO", attr, value, **kwargs) + def client_unblock( self, client_id: int, error: bool = False, **kwargs ) -> ResponseT: @@ -724,19 +733,21 @@ def client_unblock( def client_pause(self, timeout: int, all: bool = True, **kwargs) -> ResponseT: """ - Suspend all the Redis clients for the specified amount of time - :param timeout: milliseconds to pause clients + Suspend all the Redis clients for the specified amount of time. + For more information see https://redis.io/commands/client-pause + + :param timeout: milliseconds to pause clients :param all: If true (default) all client commands are blocked. - otherwise, clients are only blocked if they attempt to execute - a write command. - For the WRITE mode, some commands have special behavior: - EVAL/EVALSHA: Will block client for all scripts. - PUBLISH: Will block client. - PFCOUNT: Will block client. - WAIT: Acknowledgments will be delayed, so this command will - appear blocked. + otherwise, clients are only blocked if they attempt to execute + a write command. + For the WRITE mode, some commands have special behavior: + EVAL/EVALSHA: Will block client for all scripts. + PUBLISH: Will block client. + PFCOUNT: Will block client. + WAIT: Acknowledgments will be delayed, so this command will + appear blocked. 
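As a usage sketch of the client commands above (the client id is hypothetical):

    r.client_setinfo("lib-name", "my-app")  # CLIENT SETINFO LIB-NAME my-app
    r.client_pause(1000)                    # pause all clients for 1000 ms
    r.client_unblock(42, error=True)        # unblock client 42 with an error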
""" args = ["CLIENT PAUSE", str(timeout)] if not isinstance(timeout, int): @@ -761,6 +772,17 @@ def client_no_evict(self, mode: str) -> Union[Awaitable[str], str]: """ return self.execute_command("CLIENT NO-EVICT", mode) + def client_no_touch(self, mode: str) -> Union[Awaitable[str], str]: + """ + # The command controls whether commands sent by the client will alter + # the LRU/LFU of the keys they access. + # When turned on, the current client will not change LFU/LRU stats, + # unless it sends the TOUCH command. + + For more information see https://redis.io/commands/client-no-touch + """ + return self.execute_command("CLIENT NO-TOUCH", mode) + def command(self, **kwargs): """ Returns dict reply of details about all Redis commands. @@ -1204,9 +1226,11 @@ def quit(self, **kwargs) -> ResponseT: def replicaof(self, *args, **kwargs) -> ResponseT: """ Update the replication settings of a redis replica, on the fly. + Examples of valid arguments include: - NO ONE (set no replication) - host port (set to the host and port of a redis server) + + NO ONE (set no replication) + host port (set to the host and port of a redis server) For more information see https://redis.io/commands/replicaof """ @@ -1330,6 +1354,21 @@ def wait(self, num_replicas: int, timeout: int, **kwargs) -> ResponseT: """ return self.execute_command("WAIT", num_replicas, timeout, **kwargs) + def waitaof( + self, num_local: int, num_replicas: int, timeout: int, **kwargs + ) -> ResponseT: + """ + This command blocks the current client until all previous write + commands by that client are acknowledged as having been fsynced + to the AOF of the local Redis and/or at least the specified number + of replicas. + + For more information see https://redis.io/commands/waitaof + """ + return self.execute_command( + "WAITAOF", num_local, num_replicas, timeout, **kwargs + ) + def hello(self): """ This function throws a NotImplementedError since it is intentionally @@ -2667,7 +2706,11 @@ def llen(self, name: str) -> Union[Awaitable[int], int]: """ return self.execute_command("LLEN", name) - def lpop(self, name: str, count: Optional[int] = None) -> Union[str, List, None]: + def lpop( + self, + name: str, + count: Optional[int] = None, + ) -> Union[Awaitable[Union[str, List, None]], Union[str, List, None]]: """ Removes and returns the first elements of the list ``name``. @@ -2744,7 +2787,11 @@ def ltrim(self, name: str, start: int, end: int) -> Union[Awaitable[str], str]: """ return self.execute_command("LTRIM", name, start, end) - def rpop(self, name: str, count: Optional[int] = None) -> Union[str, List, None]: + def rpop( + self, + name: str, + count: Optional[int] = None, + ) -> Union[Awaitable[Union[str, List, None]], Union[str, List, None]]: """ Removes and returns the last elements of the list ``name``. @@ -3331,9 +3378,13 @@ def sinterstore( args = list_or_args(keys, args) return self.execute_command("SINTERSTORE", dest, *args) - def sismember(self, name: str, value: str) -> Union[Awaitable[bool], bool]: + def sismember( + self, name: str, value: str + ) -> Union[Awaitable[Union[Literal[0], Literal[1]]], Union[Literal[0], Literal[1]]]: """ - Return a boolean indicating if ``value`` is a member of set ``name`` + Return whether ``value`` is a member of set ``name``: + - 1 if the value is a member of the set. + - 0 if the value is not a member of the set or if key does not exist. 
For more information see https://redis.io/commands/sismember """ @@ -3349,10 +3400,15 @@ def smembers(self, name: str) -> Union[Awaitable[Set], Set]: def smismember( self, name: str, values: List, *args: List - ) -> Union[Awaitable[List[bool]], List[bool]]: + ) -> Union[ + Awaitable[List[Union[Literal[0], Literal[1]]]], + List[Union[Literal[0], Literal[1]]], + ]: """ Return whether each value in ``values`` is a member of the set ``name`` - as a list of ``bool`` in the order of ``values`` + as a list of ``int`` in the order of ``values``: + - 1 if the value is a member of the set. + - 0 if the value is not a member of the set or if key does not exist. For more information see https://redis.io/commands/smismember """ @@ -3472,8 +3528,8 @@ def xadd( raise DataError("Only one of ```maxlen``` or ```minid``` may be specified") if maxlen is not None: - if not isinstance(maxlen, int) or maxlen < 1: - raise DataError("XADD maxlen must be a positive integer") + if not isinstance(maxlen, int) or maxlen < 0: + raise DataError("XADD maxlen must be non-negative integer") pieces.append(b"MAXLEN") if approximate: pieces.append(b"~") @@ -3551,7 +3607,7 @@ def xclaim( groupname: GroupT, consumername: ConsumerT, min_idle_time: int, - message_ids: [List[StreamIdT], Tuple[StreamIdT]], + message_ids: Union[List[StreamIdT], Tuple[StreamIdT]], idle: Union[int, None] = None, time: Union[int, None] = None, retrycount: Union[int, None] = None, @@ -3560,27 +3616,37 @@ def xclaim( ) -> ResponseT: """ Changes the ownership of a pending message. + name: name of the stream. + groupname: name of the consumer group. + consumername: name of a consumer that claims the message. + min_idle_time: filter messages that were idle less than this amount of milliseconds + message_ids: non-empty list or tuple of message IDs to claim + idle: optional. Set the idle time (last time it was delivered) of the - message in ms + message in ms + time: optional integer. This is the same as idle but instead of a - relative amount of milliseconds, it sets the idle time to a specific - Unix time (in milliseconds). + relative amount of milliseconds, it sets the idle time to a specific + Unix time (in milliseconds). + retrycount: optional integer. set the retry counter to the specified - value. This counter is incremented every time a message is delivered - again. + value. This counter is incremented every time a message is delivered + again. + force: optional boolean, false by default. Creates the pending message - entry in the PEL even if certain specified IDs are not already in the - PEL assigned to a different client. + entry in the PEL even if certain specified IDs are not already in the + PEL assigned to a different client. + justid: optional boolean, false by default. Return just an array of IDs - of messages successfully claimed, without returning the actual message + of messages successfully claimed, without returning the actual message - For more information see https://redis.io/commands/xclaim + For more information see https://redis.io/commands/xclaim """ if not isinstance(min_idle_time, int) or min_idle_time < 0: raise DataError("XCLAIM min_idle_time must be a non negative integer") @@ -3832,11 +3898,15 @@ def xrange( ) -> ResponseT: """ Read stream values within an interval. + name: name of the stream. + start: first stream ID. defaults to '-', meaning the earliest available. + finish: last stream ID. defaults to '+', meaning the latest available. + count: if set, only return this many items, beginning with the earliest available. 
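Tying the stream hunks above together, a brief sketch (stream name and fields are arbitrary):

    r.xadd("mystream", {"sensor": "21.5"}, maxlen=1000, approximate=True)
    entries = r.xrange("mystream", min="-", max="+", count=10)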
@@ -3859,10 +3929,13 @@ def xread(
     ) -> ResponseT:
         """
         Block and monitor multiple streams for new data.
+
         streams: a dict of stream names to stream IDs, where
         IDs indicate the last ID already seen.
+
         count: if set, only return this many items, beginning with the
         earliest available.
+
         block: number of milliseconds to wait, if nothing already present.
 
         For more information see https://redis.io/commands/xread
@@ -3897,12 +3970,17 @@ def xreadgroup(
     ) -> ResponseT:
         """
         Read from a stream via a consumer group.
+
         groupname: name of the consumer group.
+
         consumername: name of the requesting consumer.
+
         streams: a dict of stream names to stream IDs, where
         IDs indicate the last ID already seen.
+
         count: if set, only return this many items, beginning with the
         earliest available.
+
         block: number of milliseconds to wait, if nothing already present.
         noack: do not add messages to the PEL
 
@@ -3937,11 +4015,15 @@ def xrevrange(
     ) -> ResponseT:
         """
         Read stream values within an interval, in reverse order.
+
         name: name of the stream
+
         start: first stream ID. defaults to '+', meaning the latest available.
+
         finish: last stream ID. defaults to '-', meaning the earliest available.
+
         count: if set, only return this many items, beginning with the
         latest available.
 
@@ -4630,13 +4712,22 @@ def zrevrangebyscore(
         options = {"withscores": withscores, "score_cast_func": score_cast_func}
         return self.execute_command(*pieces, **options)
 
-    def zrank(self, name: KeyT, value: EncodableT) -> ResponseT:
+    def zrank(
+        self,
+        name: KeyT,
+        value: EncodableT,
+        withscore: bool = False,
+    ) -> ResponseT:
         """
         Returns a 0-based value indicating the rank of ``value`` in sorted set
-        ``name``
+        ``name``.
+        The optional ``withscore`` argument supplements the command's
+        reply with the score of the element returned.
 
         For more information see https://redis.io/commands/zrank
         """
+        if withscore:
+            return self.execute_command("ZRANK", name, value, "WITHSCORE")
         return self.execute_command("ZRANK", name, value)
 
     def zrem(self, name: KeyT, *values: FieldT) -> ResponseT:
         """
@@ -4680,13 +4771,22 @@ def zremrangebyscore(
         """
         return self.execute_command("ZREMRANGEBYSCORE", name, min, max)
 
-    def zrevrank(self, name: KeyT, value: EncodableT) -> ResponseT:
+    def zrevrank(
+        self,
+        name: KeyT,
+        value: EncodableT,
+        withscore: bool = False,
+    ) -> ResponseT:
         """
         Returns a 0-based value indicating the descending rank of
-        ``value`` in sorted set ``name``
+        ``value`` in sorted set ``name``.
+        The optional ``withscore`` argument supplements the command's
+        reply with the score of the element returned.
 
         For more information see https://redis.io/commands/zrevrank
         """
+        if withscore:
+            return self.execute_command("ZREVRANK", name, value, "WITHSCORE")
         return self.execute_command("ZREVRANK", name, value)
 
     def zscore(self, name: KeyT, value: EncodableT) -> ResponseT:
@@ -5090,6 +5190,15 @@ def publish(self, channel: ChannelT, message: EncodableT, **kwargs) -> ResponseT:
         """
         return self.execute_command("PUBLISH", channel, message, **kwargs)
 
+    def spublish(self, shard_channel: ChannelT, message: EncodableT) -> ResponseT:
+        """
+        Posts a message to the given shard channel.
+        Returns the number of clients that received the message.
+
+        For more information see https://redis.io/commands/spublish
+        """
+        return self.execute_command("SPUBLISH", shard_channel, message)
+
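# --- illustrative usage sketch (not part of the patch) ----------------------
# The new withscore flag and sharded Pub/Sub publishing in action. SPUBLISH
# needs Redis 7.0+; assumes a local server.
import redis

r = redis.Redis(decode_responses=True)

# zrank/zrevrank with withscore=True return [rank, score] instead of rank.
r.zadd("zs", {"a": 1, "b": 2})
print(r.zrank("zs", "b"))                  # 1
print(r.zrank("zs", "b", withscore=True))  # [1, "2"]

# Publish to a shard channel; the reply is the number of receiving clients.
print(r.spublish("shard-channel", "hello"))
# -----------------------------------------------------------------------------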
     def pubsub_channels(self, pattern: PatternT = "*", **kwargs) -> ResponseT:
         """
         Return a list of channels that have at least one subscriber
 
@@ -5098,6 +5207,14 @@ def pubsub_channels(self, pattern: PatternT = "*", **kwargs) -> ResponseT:
         """
         return self.execute_command("PUBSUB CHANNELS", pattern, **kwargs)
 
+    def pubsub_shardchannels(self, pattern: PatternT = "*", **kwargs) -> ResponseT:
+        """
+        Return a list of shard_channels that have at least one subscriber
+
+        For more information see https://redis.io/commands/pubsub-shardchannels
+        """
+        return self.execute_command("PUBSUB SHARDCHANNELS", pattern, **kwargs)
+
     def pubsub_numpat(self, **kwargs) -> ResponseT:
         """
         Returns the number of subscriptions to patterns
 
@@ -5115,6 +5232,15 @@ def pubsub_numsub(self, *args: ChannelT, **kwargs) -> ResponseT:
         """
         return self.execute_command("PUBSUB NUMSUB", *args, **kwargs)
 
+    def pubsub_shardnumsub(self, *args: ChannelT, **kwargs) -> ResponseT:
+        """
+        Return a list of (shard_channel, number of subscribers) tuples
+        for each shard channel given in ``*args``
+
+        For more information see https://redis.io/commands/pubsub-shardnumsub
+        """
+        return self.execute_command("PUBSUB SHARDNUMSUB", *args, **kwargs)
+
 
 AsyncPubSubCommands = PubSubCommands
 
@@ -5214,8 +5340,10 @@ def script_flush(
         self, sync_type: Union[Literal["SYNC"], Literal["ASYNC"]] = None
     ) -> ResponseT:
         """Flush all scripts from the script cache.
+
         ``sync_type`` is by default SYNC (synchronous) but it can also be ASYNC.
+
         For more information see https://redis.io/commands/script-flush
         """
 
@@ -5528,11 +5656,14 @@ def geosearch(
         area specified by a given shape. This command extends the
         GEORADIUS command, so in addition to searching within circular
         areas, it supports searching within rectangular areas.
+
         This command should be used in place of the deprecated
         GEORADIUS and GEORADIUSBYMEMBER commands.
+
         ``member`` Use the position of the given existing
         member in the sorted set. Can't be given with ``longitude`` and ``latitude``.
+
         ``longitude`` and ``latitude`` Use the position given by these
         coordinates. Can't be given with ``member``
         ``radius`` Similar to GEORADIUS, search inside circular
@@ -5541,17 +5672,23 @@ def geosearch(
         ``height`` and ``width`` Search inside an axis-aligned
         rectangle, determined by the given height and width.
         Can't be given with ``radius``
+
         ``unit`` must be one of the following: m, km, mi, ft.
         `m` for meters (the default value), `km` for kilometers,
         `mi` for miles and `ft` for feet.
+
         ``sort`` indicates to return the places in a sorted way,
         ASC for nearest to furthest and DESC for furthest to nearest.
+
         ``count`` limit the results to the first count matching items.
+
         ``any`` if set to True, the command will return as soon as
         enough matches are found. Can't be provided without ``count``
+
         ``withdist`` indicates to return the distances of each place.
         ``withcoord`` indicates to return the latitude and longitude of
         each place.
+
         ``withhash`` indicates to return the geohash string of each place.
 
         For more information see https://redis.io/commands/geosearch
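# --- illustrative usage sketch (not part of the patch) ----------------------
# Shard Pub/Sub introspection plus geosearch queries. The coordinates are
# arbitrary sample values; geosearch needs Redis 6.2+.
import redis

r = redis.Redis(decode_responses=True)

# Shard channels with at least one subscriber, and per-channel counts.
print(r.pubsub_shardchannels())
print(r.pubsub_shardnumsub("shard-channel"))

# Circular search around an existing member, nearest first, with distances.
r.geoadd("points", (13.361389, 38.115556, "Palermo"))
r.geoadd("points", (15.087269, 37.502669, "Catania"))
print(
    r.geosearch(
        "points", member="Palermo", radius=200, unit="km",
        sort="ASC", count=5, withdist=True,
    )
)

# Rectangular search around explicit coordinates.
print(
    r.geosearch(
        "points", longitude=15.0, latitude=37.0, width=400, height=400, unit="km"
    )
)
# -----------------------------------------------------------------------------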
@@ -5975,6 +6112,131 @@ def function_stats(self) -> Union[Awaitable[List], List]:
 AsyncFunctionCommands = FunctionCommands
 
 
+class GearsCommands:
+    def tfunction_load(
+        self, lib_code: str, replace: bool = False, config: Union[str, None] = None
+    ) -> ResponseT:
+        """
+        Load a new library into RedisGears.
+
+        ``lib_code`` - the library code.
+        ``config`` - a string representation of a JSON object
+        that will be provided to the library at load time,
+        for more information refer to
+        https://github.com/RedisGears/RedisGears/blob/master/docs/function_advance_topics.md#library-configuration
+        ``replace`` - an optional argument, instructs RedisGears to replace the
+        function if it already exists
+
+        For more information see https://redis.io/commands/tfunction-load/
+        """
+        pieces = []
+        if replace:
+            pieces.append("REPLACE")
+        if config is not None:
+            pieces.extend(["CONFIG", config])
+        pieces.append(lib_code)
+        return self.execute_command("TFUNCTION LOAD", *pieces)
+
+    def tfunction_delete(self, lib_name: str) -> ResponseT:
+        """
+        Delete a library from RedisGears.
+
+        ``lib_name`` the library name to delete.
+
+        For more information see https://redis.io/commands/tfunction-delete/
+        """
+        return self.execute_command("TFUNCTION DELETE", lib_name)
+
+    def tfunction_list(
+        self,
+        with_code: bool = False,
+        verbose: int = 0,
+        lib_name: Union[str, None] = None,
+    ) -> ResponseT:
+        """
+        List the functions with additional information about each function.
+
+        ``with_code`` Show each library's code.
+        ``verbose`` output verbosity level; a higher number increases verbosity (1-3)
+        ``lib_name`` specifying a library name (can be used multiple times to show multiple libraries in a single command)  # noqa
+
+        For more information see https://redis.io/commands/tfunction-list/
+        """
+        pieces = []
+        if with_code:
+            pieces.append("WITHCODE")
+        if verbose >= 1 and verbose <= 3:
+            pieces.append("v" * verbose)
+        else:
+            raise DataError("verbose can be 1, 2 or 3")
+        if lib_name is not None:
+            pieces.append("LIBRARY")
+            pieces.append(lib_name)
+
+        return self.execute_command("TFUNCTION LIST", *pieces)
+
+    def _tfcall(
+        self,
+        lib_name: str,
+        func_name: str,
+        keys: KeysT = None,
+        _async: bool = False,
+        *args: List,
+    ) -> ResponseT:
+        pieces = [f"{lib_name}.{func_name}"]
+        if keys is not None:
+            pieces.append(len(keys))
+            pieces.extend(keys)
+        else:
+            pieces.append(0)
+        if args is not None:
+            pieces.extend(args)
+        if _async:
+            return self.execute_command("TFCALLASYNC", *pieces)
+        return self.execute_command("TFCALL", *pieces)
+
+    def tfcall(
+        self,
+        lib_name: str,
+        func_name: str,
+        keys: KeysT = None,
+        *args: List,
+    ) -> ResponseT:
+        """
+        Invoke a function.
+
+        ``lib_name`` - the name of the library that contains the function.
+        ``func_name`` - the function name to run.
+        ``keys`` - the keys that will be touched by the function.
+        ``args`` - additional arguments to pass to the function.
+
+        For more information see https://redis.io/commands/tfcall/
+        """
+        return self._tfcall(lib_name, func_name, keys, False, *args)
+
+    def tfcall_async(
+        self,
+        lib_name: str,
+        func_name: str,
+        keys: KeysT = None,
+        *args: List,
+    ) -> ResponseT:
+        """
+        Invoke an async function (coroutine).
+
+        ``lib_name`` - the name of the library that contains the function.
+        ``func_name`` - the function name to run.
+        ``keys`` - the keys that will be touched by the function.
+        ``args`` - additional arguments to pass to the function.
+
+        For more information see https://redis.io/commands/tfcallasync/
+        """
+        return self._tfcall(lib_name, func_name, keys, True, *args)
+
+
+AsyncGearsCommands = GearsCommands
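# --- illustrative usage sketch (not part of the patch) ----------------------
# Loading and invoking a triggers-and-functions (RedisGears 2.x) library via
# the new GearsCommands API. The JavaScript library below is a hypothetical
# minimal example; the server must have the RedisGears 2.x module loaded.
import redis

r = redis.Redis(decode_responses=True)

lib_code = """#!js api_version=1.0 name=mylib
redis.registerFunction('hello', () => 'Hello from RedisGears');
"""

r.tfunction_load(lib_code, replace=True)
print(r.tfunction_list(verbose=1))   # library metadata
print(r.tfcall("mylib", "hello"))    # invoke synchronously
r.tfunction_delete("mylib")
# -----------------------------------------------------------------------------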
+
+
 class DataAccessCommands(
     BasicKeyCommands,
     HyperlogCommands,
@@ -6018,6 +6280,7 @@ class CoreCommands(
     PubSubCommands,
     ScriptCommands,
     FunctionCommands,
+    GearsCommands,
 ):
     """
     A class containing all of the implemented redis commands. This class is
@@ -6034,6 +6297,7 @@ class AsyncCoreCommands(
     AsyncPubSubCommands,
     AsyncScriptCommands,
     AsyncFunctionCommands,
+    AsyncGearsCommands,
 ):
     """
     A class containing all of the implemented redis commands. This class is
diff --git a/redis/commands/helpers.py b/redis/commands/helpers.py
index b65cd1a933..324d981d66 100644
--- a/redis/commands/helpers.py
+++ b/redis/commands/helpers.py
@@ -3,6 +3,7 @@
 import string
 from typing import List, Tuple
 
+import redis
 from redis.typing import KeysT, KeyT
 
@@ -156,3 +157,10 @@ def stringify_param_value(value):
         return f'{{{",".join(f"{k}:{stringify_param_value(v)}" for k, v in value.items())}}}'  # noqa
     else:
         return str(value)
+
+
+def get_protocol_version(client):
+    if isinstance(client, redis.Redis) or isinstance(client, redis.asyncio.Redis):
+        return client.connection_pool.connection_kwargs.get("protocol")
+    elif isinstance(client, redis.cluster.AbstractRedisCluster):
+        return client.nodes_manager.connection_kwargs.get("protocol")
diff --git a/redis/commands/json/__init__.py b/redis/commands/json/__init__.py
index 7d55023e1e..01077e6b88 100644
--- a/redis/commands/json/__init__.py
+++ b/redis/commands/json/__init__.py
@@ -2,7 +2,7 @@
 
 import redis
 
-from ..helpers import nativestr
+from ..helpers import get_protocol_version, nativestr
 from .commands import JSONCommands
 from .decoders import bulk_of_jsons, decode_list
 
@@ -31,35 +31,49 @@ def __init__(
         :type json.JSONEncoder: An instance of json.JSONEncoder
         """
         # Set the module commands' callbacks
-        self.MODULE_CALLBACKS = {
-            "JSON.CLEAR": int,
-            "JSON.DEL": int,
-            "JSON.FORGET": int,
+        self._MODULE_CALLBACKS = {
+            "JSON.ARRPOP": self._decode,
+            "JSON.DEBUG": self._decode,
             "JSON.GET": self._decode,
+            "JSON.MERGE": lambda r: r and nativestr(r) == "OK",
             "JSON.MGET": bulk_of_jsons(self._decode),
+            "JSON.MSET": lambda r: r and nativestr(r) == "OK",
+            "JSON.RESP": self._decode,
             "JSON.SET": lambda r: r and nativestr(r) == "OK",
-            "JSON.NUMINCRBY": self._decode,
-            "JSON.NUMMULTBY": self._decode,
             "JSON.TOGGLE": self._decode,
-            "JSON.STRAPPEND": self._decode,
-            "JSON.STRLEN": self._decode,
+        }
+
+        _RESP2_MODULE_CALLBACKS = {
             "JSON.ARRAPPEND": self._decode,
             "JSON.ARRINDEX": self._decode,
             "JSON.ARRINSERT": self._decode,
             "JSON.ARRLEN": self._decode,
-            "JSON.ARRPOP": self._decode,
             "JSON.ARRTRIM": self._decode,
-            "JSON.OBJLEN": self._decode,
+            "JSON.CLEAR": int,
+            "JSON.DEL": int,
+            "JSON.FORGET": int,
+            "JSON.GET": self._decode,
+            "JSON.NUMINCRBY": self._decode,
+            "JSON.NUMMULTBY": self._decode,
             "JSON.OBJKEYS": self._decode,
-            "JSON.RESP": self._decode,
-            "JSON.DEBUG": self._decode,
+            "JSON.STRAPPEND": self._decode,
+            "JSON.OBJLEN": self._decode,
+            "JSON.STRLEN": self._decode,
+            "JSON.TOGGLE": self._decode,
         }
+        _RESP3_MODULE_CALLBACKS = {}
 
         self.client = client
         self.execute_command = client.execute_command
         self.MODULE_VERSION = version
 
-        for key, value in self.MODULE_CALLBACKS.items():
+        if get_protocol_version(self.client) in ["3", 3]:
+            self._MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS)
+        else:
+            self._MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS)
+
+        for key, value in self._MODULE_CALLBACKS.items():
             self.client.set_response_callback(key, value)
 
         self.__encoder__ = encoder
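# --- illustrative note (not part of the patch) -------------------------------
# The RESP2/RESP3 callback split above hinges on get_protocol_version(), which
# simply inspects the ``protocol`` connection kwarg; no command is sent.
# Assuming the client exposes the protocol keyword introduced in this release:
import redis
from redis.commands.helpers import get_protocol_version

r2 = redis.Redis()            # RESP2, the default
r3 = redis.Redis(protocol=3)  # negotiates RESP3 via HELLO on connect

print(get_protocol_version(r2))  # typically 2 (the default)
print(get_protocol_version(r3))  # 3
# ------------------------------------------------------------------------------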
@@ -115,7 +129,7 @@ def pipeline(self, transaction=True, shard_hint=None):
         else:
             p = Pipeline(
                 connection_pool=self.client.connection_pool,
-                response_callbacks=self.MODULE_CALLBACKS,
+                response_callbacks=self._MODULE_CALLBACKS,
                 transaction=transaction,
                 shard_hint=shard_hint,
             )
diff --git a/redis/commands/json/commands.py b/redis/commands/json/commands.py
index 7fd4039203..3abe155796 100644
--- a/redis/commands/json/commands.py
+++ b/redis/commands/json/commands.py
@@ -1,6 +1,6 @@
 import os
 from json import JSONDecodeError, loads
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 from redis.exceptions import DataError
 from redis.utils import deprecated_function
@@ -31,8 +31,8 @@ def arrindex(
         name: str,
         path: str,
         scalar: int,
-        start: Optional[int] = 0,
-        stop: Optional[int] = -1,
+        start: Optional[int] = None,
+        stop: Optional[int] = None,
     ) -> List[Union[int, None]]:
         """
         Return the index of ``scalar`` in the JSON array under ``path`` at key
@@ -43,9 +43,13 @@ def arrindex(
 
         For more information see `JSON.ARRINDEX <https://redis.io/commands/json.arrindex>`_.
         """  # noqa
-        return self.execute_command(
-            "JSON.ARRINDEX", name, str(path), self._encode(scalar), start, stop
-        )
+        pieces = [name, str(path), self._encode(scalar)]
+        if start is not None:
+            pieces.append(start)
+        if stop is not None:
+            pieces.append(stop)
+
+        return self.execute_command("JSON.ARRINDEX", *pieces)
 
     def arrinsert(
         self, name: str, path: str, index: int, *args: List[JsonType]
@@ -249,6 +253,46 @@ def set(
             pieces.append("XX")
         return self.execute_command("JSON.SET", *pieces)
 
+    def mset(self, triplets: List[Tuple[str, str, JsonType]]) -> Optional[str]:
+        """
+        Set the JSON value of one or more keys: for each (key, path, value)
+        triplet, set the JSON value at ``key`` under ``path`` to ``value``.
+
+        ``triplets`` is a list of one or more (key, path, value) triplets.
+
+        For the purpose of using this within a pipeline, this command is also
+        aliased to JSON.MSET.
+
+        For more information see `JSON.MSET <https://redis.io/commands/json.mset>`_.
+        """
+        pieces = []
+        for triplet in triplets:
+            pieces.extend([triplet[0], str(triplet[1]), self._encode(triplet[2])])
+        return self.execute_command("JSON.MSET", *pieces)
+
+    def merge(
+        self,
+        name: str,
+        path: str,
+        obj: JsonType,
+        decode_keys: Optional[bool] = False,
+    ) -> Optional[str]:
+        """
+        Merges a given JSON value into matching paths. Consequently, JSON values
+        at matching paths are updated, deleted, or expanded with new children.
+
+        ``decode_keys`` If set to True, the keys of ``obj`` will be decoded
+        with utf-8.
+
+        For more information see `JSON.MERGE <https://redis.io/commands/json.merge>`_.
+ """ + if decode_keys: + obj = decode_dict_keys(obj) + + pieces = [name, str(path), self._encode(obj)] + + return self.execute_command("JSON.MERGE", *pieces) + def set_file( self, name: str, diff --git a/redis/commands/search/__init__.py b/redis/commands/search/__init__.py index 70e9c279e5..e635f91e99 100644 --- a/redis/commands/search/__init__.py +++ b/redis/commands/search/__init__.py @@ -1,7 +1,17 @@ import redis from ...asyncio.client import Pipeline as AsyncioPipeline -from .commands import AsyncSearchCommands, SearchCommands +from .commands import ( + AGGREGATE_CMD, + CONFIG_CMD, + INFO_CMD, + PROFILE_CMD, + SEARCH_CMD, + SPELLCHECK_CMD, + SYNDUMP_CMD, + AsyncSearchCommands, + SearchCommands, +) class Search(SearchCommands): @@ -85,11 +95,20 @@ def __init__(self, client, index_name="idx"): If conn is not None, we employ an already existing redis connection """ - self.MODULE_CALLBACKS = {} + self._MODULE_CALLBACKS = {} self.client = client self.index_name = index_name self.execute_command = client.execute_command self._pipeline = client.pipeline + self._RESP2_MODULE_CALLBACKS = { + INFO_CMD: self._parse_info, + SEARCH_CMD: self._parse_search, + AGGREGATE_CMD: self._parse_aggregate, + PROFILE_CMD: self._parse_profile, + SPELLCHECK_CMD: self._parse_spellcheck, + CONFIG_CMD: self._parse_config_get, + SYNDUMP_CMD: self._parse_syndump, + } def pipeline(self, transaction=True, shard_hint=None): """Creates a pipeline for the SEARCH module, that can be used for executing @@ -97,7 +116,7 @@ def pipeline(self, transaction=True, shard_hint=None): """ p = Pipeline( connection_pool=self.client.connection_pool, - response_callbacks=self.MODULE_CALLBACKS, + response_callbacks=self._MODULE_CALLBACKS, transaction=transaction, shard_hint=shard_hint, ) @@ -155,7 +174,7 @@ def pipeline(self, transaction=True, shard_hint=None): """ p = AsyncPipeline( connection_pool=self.client.connection_pool, - response_callbacks=self.MODULE_CALLBACKS, + response_callbacks=self._MODULE_CALLBACKS, transaction=transaction, shard_hint=shard_hint, ) diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index 3bd7d47aa8..87b572195c 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -63,6 +63,86 @@ class SearchCommands: """Search commands.""" + def _parse_results(self, cmd, res, **kwargs): + if self.client.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: + return res + else: + return self._RESP2_MODULE_CALLBACKS[cmd](res, **kwargs) + + def _parse_info(self, res, **kwargs): + it = map(to_string, res) + return dict(zip(it, it)) + + def _parse_search(self, res, **kwargs): + return Result( + res, + not kwargs["query"]._no_content, + duration=kwargs["duration"], + has_payload=kwargs["query"]._with_payloads, + with_scores=kwargs["query"]._with_scores, + ) + + def _parse_aggregate(self, res, **kwargs): + return self._get_aggregate_result(res, kwargs["query"], kwargs["has_cursor"]) + + def _parse_profile(self, res, **kwargs): + query = kwargs["query"] + if isinstance(query, AggregateRequest): + result = self._get_aggregate_result(res[0], query, query._cursor) + else: + result = Result( + res[0], + not query._no_content, + duration=kwargs["duration"], + has_payload=query._with_payloads, + with_scores=query._with_scores, + ) + + return result, parse_to_dict(res[1]) + + def _parse_spellcheck(self, res, **kwargs): + corrections = {} + if res == 0: + return corrections + + for _correction in res: + if isinstance(_correction, int) and _correction == 0: + 
continue + + if len(_correction) != 3: + continue + if not _correction[2]: + continue + if not _correction[2][0]: + continue + + # For spellcheck output + # 1) 1) "TERM" + # 2) "{term1}" + # 3) 1) 1) "{score1}" + # 2) "{suggestion1}" + # 2) 1) "{score2}" + # 2) "{suggestion2}" + # + # Following dictionary will be made + # corrections = { + # '{term1}': [ + # {'score': '{score1}', 'suggestion': '{suggestion1}'}, + # {'score': '{score2}', 'suggestion': '{suggestion2}'} + # ] + # } + corrections[_correction[1]] = [ + {"score": _item[0], "suggestion": _item[1]} for _item in _correction[2] + ] + + return corrections + + def _parse_config_get(self, res, **kwargs): + return {kvs[0]: kvs[1] for kvs in res} if res else {} + + def _parse_syndump(self, res, **kwargs): + return {res[i]: res[i + 1] for i in range(0, len(res), 2)} + def batch_indexer(self, chunk_size=100): """ Create a new batch indexer from the client with a given chunk size @@ -368,11 +448,10 @@ def info(self): """ res = self.execute_command(INFO_CMD, self.index_name) - it = map(to_string, res) - return dict(zip(it, it)) + return self._parse_results(INFO_CMD, res) def get_params_args( - self, query_params: Union[Dict[str, Union[str, int, float]], None] + self, query_params: Union[Dict[str, Union[str, int, float, bytes]], None] ): if query_params is None: return [] @@ -385,7 +464,9 @@ def get_params_args( args.append(value) return args - def _mk_query_args(self, query, query_params: Dict[str, Union[str, int, float]]): + def _mk_query_args( + self, query, query_params: Union[Dict[str, Union[str, int, float, bytes]], None] + ): args = [self.index_name] if isinstance(query, str): @@ -402,7 +483,7 @@ def _mk_query_args(self, query, query_params: Dict[str, Union[str, int, float]]) def search( self, query: Union[str, Query], - query_params: Dict[str, Union[str, int, float]] = None, + query_params: Union[Dict[str, Union[str, int, float, bytes]], None] = None, ): """ Search the index for a given query, and return a result of documents @@ -422,12 +503,8 @@ def search( if isinstance(res, Pipeline): return res - return Result( - res, - not query._no_content, - duration=(time.time() - st) * 1000.0, - has_payload=query._with_payloads, - with_scores=query._with_scores, + return self._parse_results( + SEARCH_CMD, res, query=query, duration=(time.time() - st) * 1000.0 ) def explain( @@ -473,7 +550,9 @@ def aggregate( cmd += self.get_params_args(query_params) raw = self.execute_command(*cmd) - return self._get_aggregate_result(raw, query, has_cursor) + return self._parse_results( + AGGREGATE_CMD, raw, query=query, has_cursor=has_cursor + ) def _get_aggregate_result(self, raw, query, has_cursor): if has_cursor: @@ -531,18 +610,9 @@ def profile( res = self.execute_command(*cmd) - if isinstance(query, AggregateRequest): - result = self._get_aggregate_result(res[0], query, query._cursor) - else: - result = Result( - res[0], - not query._no_content, - duration=(time.time() - st) * 1000.0, - has_payload=query._with_payloads, - with_scores=query._with_scores, - ) - - return result, parse_to_dict(res[1]) + return self._parse_results( + PROFILE_CMD, res, query=query, duration=(time.time() - st) * 1000.0 + ) def spellcheck(self, query, distance=None, include=None, exclude=None): """ @@ -568,43 +638,9 @@ def spellcheck(self, query, distance=None, include=None, exclude=None): if exclude: cmd.extend(["TERMS", "EXCLUDE", exclude]) - raw = self.execute_command(*cmd) - - corrections = {} - if raw == 0: - return corrections - - for _correction in raw: - if 
isinstance(_correction, int) and _correction == 0: - continue - - if len(_correction) != 3: - continue - if not _correction[2]: - continue - if not _correction[2][0]: - continue - - # For spellcheck output - # 1) 1) "TERM" - # 2) "{term1}" - # 3) 1) 1) "{score1}" - # 2) "{suggestion1}" - # 2) 1) "{score2}" - # 2) "{suggestion2}" - # - # Following dictionary will be made - # corrections = { - # '{term1}': [ - # {'score': '{score1}', 'suggestion': '{suggestion1}'}, - # {'score': '{score2}', 'suggestion': '{suggestion2}'} - # ] - # } - corrections[_correction[1]] = [ - {"score": _item[0], "suggestion": _item[1]} for _item in _correction[2] - ] + res = self.execute_command(*cmd) - return corrections + return self._parse_results(SPELLCHECK_CMD, res) def dict_add(self, name, *terms): """Adds terms to a dictionary. @@ -670,12 +706,8 @@ def config_get(self, option): For more information see `FT.CONFIG GET `_. """ # noqa cmd = [CONFIG_CMD, "GET", option] - res = {} - raw = self.execute_command(*cmd) - if raw: - for kvs in raw: - res[kvs[0]] = kvs[1] - return res + res = self.execute_command(*cmd) + return self._parse_results(CONFIG_CMD, res) def tagvals(self, tagfield): """ @@ -810,12 +842,12 @@ def sugget( if with_payloads: args.append(WITHPAYLOADS) - ret = self.execute_command(*args) + res = self.execute_command(*args) results = [] - if not ret: + if not res: return results - parser = SuggestionParser(with_scores, with_payloads, ret) + parser = SuggestionParser(with_scores, with_payloads, res) return [s for s in parser] def synupdate(self, groupid, skipinitial=False, *terms): @@ -851,8 +883,8 @@ def syndump(self): For more information see `FT.SYNDUMP `_. """ # noqa - raw = self.execute_command(SYNDUMP_CMD, self.index_name) - return {raw[i]: raw[i + 1] for i in range(0, len(raw), 2)} + res = self.execute_command(SYNDUMP_CMD, self.index_name) + return self._parse_results(SYNDUMP_CMD, res) class AsyncSearchCommands(SearchCommands): @@ -865,8 +897,7 @@ async def info(self): """ res = await self.execute_command(INFO_CMD, self.index_name) - it = map(to_string, res) - return dict(zip(it, it)) + return self._parse_results(INFO_CMD, res) async def search( self, @@ -891,12 +922,8 @@ async def search( if isinstance(res, Pipeline): return res - return Result( - res, - not query._no_content, - duration=(time.time() - st) * 1000.0, - has_payload=query._with_payloads, - with_scores=query._with_scores, + return self._parse_results( + SEARCH_CMD, res, query=query, duration=(time.time() - st) * 1000.0 ) async def aggregate( @@ -927,7 +954,9 @@ async def aggregate( cmd += self.get_params_args(query_params) raw = await self.execute_command(*cmd) - return self._get_aggregate_result(raw, query, has_cursor) + return self._parse_results( + AGGREGATE_CMD, raw, query=query, has_cursor=has_cursor + ) async def spellcheck(self, query, distance=None, include=None, exclude=None): """ @@ -953,28 +982,9 @@ async def spellcheck(self, query, distance=None, include=None, exclude=None): if exclude: cmd.extend(["TERMS", "EXCLUDE", exclude]) - raw = await self.execute_command(*cmd) + res = await self.execute_command(*cmd) - corrections = {} - if raw == 0: - return corrections - - for _correction in raw: - if isinstance(_correction, int) and _correction == 0: - continue - - if len(_correction) != 3: - continue - if not _correction[2]: - continue - if not _correction[2][0]: - continue - - corrections[_correction[1]] = [ - {"score": _item[0], "suggestion": _item[1]} for _item in _correction[2] - ] - - return corrections + return 
self._parse_results(SPELLCHECK_CMD, res) async def config_set(self, option, value): """Set runtime configuration option. @@ -1001,11 +1011,8 @@ async def config_get(self, option): """ # noqa cmd = [CONFIG_CMD, "GET", option] res = {} - raw = await self.execute_command(*cmd) - if raw: - for kvs in raw: - res[kvs[0]] = kvs[1] - return res + res = await self.execute_command(*cmd) + return self._parse_results(CONFIG_CMD, res) async def load_document(self, id): """ diff --git a/redis/commands/search/document.py b/redis/commands/search/document.py index 5b3050545a..47534ec248 100644 --- a/redis/commands/search/document.py +++ b/redis/commands/search/document.py @@ -11,3 +11,7 @@ def __init__(self, id, payload=None, **fields): def __repr__(self): return f"Document {self.__dict__}" + + def __getitem__(self, item): + value = getattr(self, item) + return value diff --git a/redis/commands/timeseries/__init__.py b/redis/commands/timeseries/__init__.py index 4a6886f237..4188b93d70 100644 --- a/redis/commands/timeseries/__init__.py +++ b/redis/commands/timeseries/__init__.py @@ -1,6 +1,7 @@ import redis +from redis._parsers.helpers import bool_ok -from ..helpers import parse_to_list +from ..helpers import get_protocol_version, parse_to_list from .commands import ( ALTER_CMD, CREATE_CMD, @@ -32,27 +33,36 @@ class TimeSeries(TimeSeriesCommands): def __init__(self, client=None, **kwargs): """Create a new RedisTimeSeries client.""" # Set the module commands' callbacks - self.MODULE_CALLBACKS = { - CREATE_CMD: redis.client.bool_ok, - ALTER_CMD: redis.client.bool_ok, - CREATERULE_CMD: redis.client.bool_ok, + self._MODULE_CALLBACKS = { + ALTER_CMD: bool_ok, + CREATE_CMD: bool_ok, + CREATERULE_CMD: bool_ok, + DELETERULE_CMD: bool_ok, + } + + _RESP2_MODULE_CALLBACKS = { DEL_CMD: int, - DELETERULE_CMD: redis.client.bool_ok, - RANGE_CMD: parse_range, - REVRANGE_CMD: parse_range, - MRANGE_CMD: parse_m_range, - MREVRANGE_CMD: parse_m_range, GET_CMD: parse_get, - MGET_CMD: parse_m_get, INFO_CMD: TSInfo, + MGET_CMD: parse_m_get, + MRANGE_CMD: parse_m_range, + MREVRANGE_CMD: parse_m_range, + RANGE_CMD: parse_range, + REVRANGE_CMD: parse_range, QUERYINDEX_CMD: parse_to_list, } + _RESP3_MODULE_CALLBACKS = {} self.client = client self.execute_command = client.execute_command - for key, value in self.MODULE_CALLBACKS.items(): - self.client.set_response_callback(key, value) + if get_protocol_version(self.client) in ["3", 3]: + self._MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS) + else: + self._MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS) + + for k, v in self._MODULE_CALLBACKS.items(): + self.client.set_response_callback(k, v) def pipeline(self, transaction=True, shard_hint=None): """Creates a pipeline for the TimeSeries module, that can be used @@ -83,7 +93,7 @@ def pipeline(self, transaction=True, shard_hint=None): else: p = Pipeline( connection_pool=self.client.connection_pool, - response_callbacks=self.MODULE_CALLBACKS, + response_callbacks=self._MODULE_CALLBACKS, transaction=transaction, shard_hint=shard_hint, ) diff --git a/redis/commands/timeseries/info.py b/redis/commands/timeseries/info.py index 65f3baacd0..3a384dc049 100644 --- a/redis/commands/timeseries/info.py +++ b/redis/commands/timeseries/info.py @@ -80,3 +80,12 @@ def __init__(self, args): self.duplicate_policy = response["duplicatePolicy"] if type(self.duplicate_policy) == bytes: self.duplicate_policy = self.duplicate_policy.decode() + + def get(self, item): + try: + return self.__getitem__(item) + except AttributeError: + return None + + def 
__getitem__(self, item): + return getattr(self, item) diff --git a/redis/compat.py b/redis/compat.py index 738687f645..e478493467 100644 --- a/redis/compat.py +++ b/redis/compat.py @@ -2,8 +2,5 @@ try: from typing import Literal, Protocol, TypedDict # lgtm [py/unused-import] except ImportError: - from typing_extensions import ( # lgtm [py/unused-import] - Literal, - Protocol, - TypedDict, - ) + from typing_extensions import Literal # lgtm [py/unused-import] + from typing_extensions import Protocol, TypedDict diff --git a/redis/connection.py b/redis/connection.py old mode 100755 new mode 100644 index 126ea5db32..00d293a238 --- a/redis/connection.py +++ b/redis/connection.py @@ -1,55 +1,39 @@ import copy -import errno -import io import os import socket +import ssl +import sys import threading import weakref +from abc import abstractmethod from itertools import chain from queue import Empty, Full, LifoQueue from time import time -from typing import Optional +from typing import Optional, Type, Union from urllib.parse import parse_qs, unquote, urlparse -from redis.backoff import NoBackoff -from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider -from redis.exceptions import ( +from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser +from .backoff import NoBackoff +from .credentials import CredentialProvider, UsernamePasswordCredentialProvider +from .exceptions import ( AuthenticationError, AuthenticationWrongNumberOfArgsError, - BusyLoadingError, ChildDeadlockedError, ConnectionError, DataError, - ExecAbortError, - InvalidResponse, - ModuleError, - NoPermissionError, - NoScriptError, - ReadOnlyError, RedisError, ResponseError, TimeoutError, ) -from redis.retry import Retry -from redis.utils import CRYPTOGRAPHY_AVAILABLE, HIREDIS_AVAILABLE, str_if_bytes - -try: - import ssl - - ssl_available = True -except ImportError: - ssl_available = False - -NONBLOCKING_EXCEPTION_ERROR_NUMBERS = {BlockingIOError: errno.EWOULDBLOCK} - -if ssl_available: - if hasattr(ssl, "SSLWantReadError"): - NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantReadError] = 2 - NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantWriteError] = 2 - else: - NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLError] = 2 - -NONBLOCKING_EXCEPTIONS = tuple(NONBLOCKING_EXCEPTION_ERROR_NUMBERS.keys()) +from .retry import Retry +from .utils import ( + CRYPTOGRAPHY_AVAILABLE, + HIREDIS_AVAILABLE, + HIREDIS_PACK_AVAILABLE, + SSL_AVAILABLE, + get_lib_version, + str_if_bytes, +) if HIREDIS_AVAILABLE: import hiredis @@ -59,465 +43,95 @@ SYM_CRLF = b"\r\n" SYM_EMPTY = b"" -SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server." +DEFAULT_RESP_VERSION = 2 SENTINEL = object() -MODULE_LOAD_ERROR = "Error loading the extension. Please check the server logs." -NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name" -MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not possible." -MODULE_EXPORTS_DATA_TYPES_ERROR = ( - "Error unloading module: the module " - "exports one or more module-side data " - "types, can't unload" -) -# user send an AUTH cmd to a server without authorization configured -NO_AUTH_SET_ERROR = { - # Redis >= 6.0 - "AUTH called without any password " - "configured for the default user. 
Are you sure " - "your configuration is correct?": AuthenticationError, - # Redis < 6.0 - "Client sent AUTH, but no password is set": AuthenticationError, -} - - -class Encoder: - "Encode strings to bytes-like and decode bytes-like to strings" - - def __init__(self, encoding, encoding_errors, decode_responses): - self.encoding = encoding - self.encoding_errors = encoding_errors - self.decode_responses = decode_responses - - def encode(self, value): - "Return a bytestring or bytes-like representation of the value" - if isinstance(value, (bytes, memoryview)): - return value - elif isinstance(value, bool): - # special case bool since it is a subclass of int - raise DataError( - "Invalid input of type: 'bool'. Convert to a " - "bytes, string, int or float first." - ) - elif isinstance(value, (int, float)): - value = repr(value).encode() - elif not isinstance(value, str): - # a value we don't know how to deal with. throw an error - typename = type(value).__name__ - raise DataError( - f"Invalid input of type: '{typename}'. " - f"Convert to a bytes, string, int or float first." - ) - if isinstance(value, str): - value = value.encode(self.encoding, self.encoding_errors) - return value - - def decode(self, value, force=False): - "Return a unicode string from the bytes-like representation" - if self.decode_responses or force: - if isinstance(value, memoryview): - value = value.tobytes() - if isinstance(value, bytes): - value = value.decode(self.encoding, self.encoding_errors) - return value - - -class BaseParser: - EXCEPTION_CLASSES = { - "ERR": { - "max number of clients reached": ConnectionError, - "invalid password": AuthenticationError, - # some Redis server versions report invalid command syntax - # in lowercase - "wrong number of arguments " - "for 'auth' command": AuthenticationWrongNumberOfArgsError, - # some Redis server versions report invalid command syntax - # in uppercase - "wrong number of arguments " - "for 'AUTH' command": AuthenticationWrongNumberOfArgsError, - MODULE_LOAD_ERROR: ModuleError, - MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError, - NO_SUCH_MODULE_ERROR: ModuleError, - MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError, - **NO_AUTH_SET_ERROR, - }, - "WRONGPASS": AuthenticationError, - "EXECABORT": ExecAbortError, - "LOADING": BusyLoadingError, - "NOSCRIPT": NoScriptError, - "READONLY": ReadOnlyError, - "NOAUTH": AuthenticationError, - "NOPERM": NoPermissionError, - } - - def parse_error(self, response): - "Parse an error response" - error_code = response.split(" ")[0] - if error_code in self.EXCEPTION_CLASSES: - response = response[len(error_code) + 1 :] - exception_class = self.EXCEPTION_CLASSES[error_code] - if isinstance(exception_class, dict): - exception_class = exception_class.get(response, ResponseError) - return exception_class(response) - return ResponseError(response) - - -class SocketBuffer: - def __init__(self, socket, socket_read_size, socket_timeout): - self._sock = socket - self.socket_read_size = socket_read_size - self.socket_timeout = socket_timeout - self._buffer = io.BytesIO() - # number of bytes written to the buffer from the socket - self.bytes_written = 0 - # number of bytes read from the buffer - self.bytes_read = 0 - - @property - def length(self): - return self.bytes_written - self.bytes_read - - def _read_from_socket(self, length=None, timeout=SENTINEL, raise_on_timeout=True): - sock = self._sock - socket_read_size = self.socket_read_size - buf = self._buffer - buf.seek(self.bytes_written) - marker = 0 - custom_timeout = timeout is not SENTINEL - - 
try: - if custom_timeout: - sock.settimeout(timeout) - while True: - data = self._sock.recv(socket_read_size) - # an empty string indicates the server shutdown the socket - if isinstance(data, bytes) and len(data) == 0: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - buf.write(data) - data_length = len(data) - self.bytes_written += data_length - marker += data_length - - if length is not None and length > marker: - continue - return True - except socket.timeout: - if raise_on_timeout: - raise TimeoutError("Timeout reading from socket") - return False - except NONBLOCKING_EXCEPTIONS as ex: - # if we're in nonblocking mode and the recv raises a - # blocking error, simply return False indicating that - # there's no data to be read. otherwise raise the - # original exception. - allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1) - if not raise_on_timeout and ex.errno == allowed: - return False - raise ConnectionError(f"Error while reading from socket: {ex.args}") - finally: - if custom_timeout: - sock.settimeout(self.socket_timeout) - - def can_read(self, timeout): - return bool(self.length) or self._read_from_socket( - timeout=timeout, raise_on_timeout=False - ) - - def read(self, length): - length = length + 2 # make sure to read the \r\n terminator - # make sure we've read enough data from the socket - if length > self.length: - self._read_from_socket(length - self.length) - - self._buffer.seek(self.bytes_read) - data = self._buffer.read(length) - self.bytes_read += len(data) - return data[:-2] - - def readline(self): - buf = self._buffer - buf.seek(self.bytes_read) - data = buf.readline() - while not data.endswith(SYM_CRLF): - # there's more data in the socket that we need - self._read_from_socket() - buf.seek(self.bytes_read) - data = buf.readline() - - self.bytes_read += len(data) - return data[:-2] - - def get_pos(self): - """ - Get current read position - """ - return self.bytes_read - - def rewind(self, pos): - """ - Rewind the buffer to a specific position, to re-start reading - """ - self.bytes_read = pos - - def purge(self): - """ - After a successful read, purge the read part of buffer - """ - unread = self.bytes_written - self.bytes_read - - # Only if we have read all of the buffer do we truncate, to - # reduce the amount of memory thrashing. This heuristic - # can be changed or removed later. - if unread > 0: - return - - if unread > 0: - # move unread data to the front - view = self._buffer.getbuffer() - view[:unread] = view[-unread:] - self._buffer.truncate(unread) - self.bytes_written = unread - self.bytes_read = 0 - self._buffer.seek(0) - - def close(self): - try: - self.bytes_written = self.bytes_read = 0 - self._buffer.close() - except Exception: - # issue #633 suggests the purge/close somehow raised a - # BadFileDescriptor error. Perhaps the client ran out of - # memory or something else? It's probably OK to ignore - # any error being raised from purge/close since we're - # removing the reference to the instance below. 
- pass - self._buffer = None - self._sock = None - - -class PythonParser(BaseParser): - "Plain Python parsing class" - - def __init__(self, socket_read_size): - self.socket_read_size = socket_read_size - self.encoder = None - self._sock = None - self._buffer = None - - def __del__(self): - try: - self.on_disconnect() - except Exception: - pass - - def on_connect(self, connection): - "Called when the socket connects" - self._sock = connection._sock - self._buffer = SocketBuffer( - self._sock, self.socket_read_size, connection.socket_timeout - ) - self.encoder = connection.encoder - - def on_disconnect(self): - "Called when the socket disconnects" - self._sock = None - if self._buffer is not None: - self._buffer.close() - self._buffer = None - self.encoder = None - - def can_read(self, timeout): - return self._buffer and self._buffer.can_read(timeout) - - def read_response(self, disable_decoding=False): - pos = self._buffer.get_pos() - try: - result = self._read_response(disable_decoding=disable_decoding) - except BaseException: - self._buffer.rewind(pos) - raise - else: - self._buffer.purge() - return result - - def _read_response(self, disable_decoding=False): - raw = self._buffer.readline() - if not raw: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - - byte, response = raw[:1], raw[1:] - - if byte not in (b"-", b"+", b":", b"$", b"*"): - raise InvalidResponse(f"Protocol Error: {raw!r}") - - # server returned an error - if byte == b"-": - response = response.decode("utf-8", errors="replace") - error = self.parse_error(response) - # if the error is a ConnectionError, raise immediately so the user - # is notified - if isinstance(error, ConnectionError): - raise error - # otherwise, we're dealing with a ResponseError that might belong - # inside a pipeline response. the connection's read_response() - # and/or the pipeline's execute() will raise this error if - # necessary, so just return the exception instance here. 
- return error - # single value - elif byte == b"+": - pass - # int value - elif byte == b":": - response = int(response) - # bulk response - elif byte == b"$": - length = int(response) - if length == -1: - return None - response = self._buffer.read(length) - # multi-bulk response - elif byte == b"*": - length = int(response) - if length == -1: - return None - response = [ - self._read_response(disable_decoding=disable_decoding) - for i in range(length) - ] - if isinstance(response, bytes) and disable_decoding is False: - response = self.encoder.decode(response) - return response +DefaultParser: Type[Union[_RESP2Parser, _RESP3Parser, _HiredisParser]] +if HIREDIS_AVAILABLE: + DefaultParser = _HiredisParser +else: + DefaultParser = _RESP2Parser -class HiredisParser(BaseParser): - "Parser class for connections using Hiredis" - def __init__(self, socket_read_size): - if not HIREDIS_AVAILABLE: - raise RedisError("Hiredis is not installed") - self.socket_read_size = socket_read_size - self._buffer = bytearray(socket_read_size) +class HiredisRespSerializer: + def pack(self, *args): + """Pack a series of arguments into the Redis protocol""" + output = [] - def __del__(self): + if isinstance(args[0], str): + args = tuple(args[0].encode().split()) + args[1:] + elif b" " in args[0]: + args = tuple(args[0].split()) + args[1:] try: - self.on_disconnect() - except Exception: - pass + output.append(hiredis.pack_command(args)) + except TypeError: + _, value, traceback = sys.exc_info() + raise DataError(value).with_traceback(traceback) - def on_connect(self, connection, **kwargs): - self._sock = connection._sock - self._socket_timeout = connection.socket_timeout - kwargs = { - "protocolError": InvalidResponse, - "replyError": self.parse_error, - "errors": connection.encoder.encoding_errors, - } - - if connection.encoder.decode_responses: - kwargs["encoding"] = connection.encoder.encoding - self._reader = hiredis.Reader(**kwargs) - self._next_response = False - - def on_disconnect(self): - self._sock = None - self._reader = None - self._next_response = False - - def can_read(self, timeout): - if not self._reader: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + return output - if self._next_response is False: - self._next_response = self._reader.gets() - if self._next_response is False: - return self.read_from_socket(timeout=timeout, raise_on_timeout=False) - return True - def read_from_socket(self, timeout=SENTINEL, raise_on_timeout=True): - sock = self._sock - custom_timeout = timeout is not SENTINEL - try: - if custom_timeout: - sock.settimeout(timeout) - bufflen = self._sock.recv_into(self._buffer) - if bufflen == 0: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - self._reader.feed(self._buffer, 0, bufflen) - # data was read from the socket and added to the buffer. - # return True to indicate that data was read. - return True - except socket.timeout: - if raise_on_timeout: - raise TimeoutError("Timeout reading from socket") - return False - except NONBLOCKING_EXCEPTIONS as ex: - # if we're in nonblocking mode and the recv raises a - # blocking error, simply return False indicating that - # there's no data to be read. otherwise raise the - # original exception. 
- allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1) - if not raise_on_timeout and ex.errno == allowed: - return False - raise ConnectionError(f"Error while reading from socket: {ex.args}") - finally: - if custom_timeout: - sock.settimeout(self._socket_timeout) +class PythonRespSerializer: + def __init__(self, buffer_cutoff, encode) -> None: + self._buffer_cutoff = buffer_cutoff + self.encode = encode - def read_response(self, disable_decoding=False): - if not self._reader: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - - # _next_response might be cached from a can_read() call - if self._next_response is not False: - response = self._next_response - self._next_response = False - return response + def pack(self, *args): + """Pack a series of arguments into the Redis protocol""" + output = [] + # the client might have included 1 or more literal arguments in + # the command name, e.g., 'CONFIG GET'. The Redis server expects these + # arguments to be sent separately, so split the first argument + # manually. These arguments should be bytestrings so that they are + # not encoded. + if isinstance(args[0], str): + args = tuple(args[0].encode().split()) + args[1:] + elif b" " in args[0]: + args = tuple(args[0].split()) + args[1:] - if disable_decoding: - response = self._reader.gets(False) - else: - response = self._reader.gets() + buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF)) - while response is False: - self.read_from_socket() - if disable_decoding: - response = self._reader.gets(False) + buffer_cutoff = self._buffer_cutoff + for arg in map(self.encode, args): + # to avoid large string mallocs, chunk the command into the + # output list if we're sending large values or memoryviews + arg_length = len(arg) + if ( + len(buff) > buffer_cutoff + or arg_length > buffer_cutoff + or isinstance(arg, memoryview) + ): + buff = SYM_EMPTY.join( + (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF) + ) + output.append(buff) + output.append(arg) + buff = SYM_CRLF else: - response = self._reader.gets() - # if the response is a ConnectionError or the response is a list and - # the first item is a ConnectionError, raise it as something bad - # happened - if isinstance(response, ConnectionError): - raise response - elif ( - isinstance(response, list) - and response - and isinstance(response[0], ConnectionError) - ): - raise response[0] - return response - - -if HIREDIS_AVAILABLE: - DefaultParser = HiredisParser -else: - DefaultParser = PythonParser + buff = SYM_EMPTY.join( + ( + buff, + SYM_DOLLAR, + str(arg_length).encode(), + SYM_CRLF, + arg, + SYM_CRLF, + ) + ) + output.append(buff) + return output -class Connection: - "Manages TCP communication to and from a Redis server" +class AbstractConnection: + "Manages communication to and from a Redis server" def __init__( self, - host="localhost", - port=6379, db=0, password=None, socket_timeout=None, socket_connect_timeout=None, - socket_keepalive=False, - socket_keepalive_options=None, - socket_type=0, retry_on_timeout=False, retry_on_error=SENTINEL, encoding="utf-8", @@ -527,10 +141,14 @@ def __init__( socket_read_size=65536, health_check_interval=0, client_name=None, + lib_name="redis-py", + lib_version=get_lib_version(), username=None, retry=None, redis_connect_func=None, credential_provider: Optional[CredentialProvider] = None, + protocol: Optional[int] = 2, + command_packer=None, ): """ Initialize a new Connection. @@ -547,18 +165,17 @@ def __init__( "2. 
'credential_provider'" ) self.pid = os.getpid() - self.host = host - self.port = int(port) self.db = db self.client_name = client_name + self.lib_name = lib_name + self.lib_version = lib_version self.credential_provider = credential_provider self.password = password self.username = username self.socket_timeout = socket_timeout - self.socket_connect_timeout = socket_connect_timeout or socket_timeout - self.socket_keepalive = socket_keepalive - self.socket_keepalive_options = socket_keepalive_options or {} - self.socket_type = socket_type + if socket_connect_timeout is None: + socket_connect_timeout = socket_timeout + self.socket_connect_timeout = socket_connect_timeout self.retry_on_timeout = retry_on_timeout if retry_on_error is SENTINEL: retry_on_error = [] @@ -585,16 +202,26 @@ def __init__( self.set_parser(parser_class) self._connect_callbacks = [] self._buffer_cutoff = 6000 + try: + p = int(protocol) + except TypeError: + p = DEFAULT_RESP_VERSION + except ValueError: + raise ConnectionError("protocol must be an integer") + finally: + if p < 2 or p > 3: + raise ConnectionError("protocol must be either 2 or 3") + # p = DEFAULT_RESP_VERSION + self.protocol = p + self._command_packer = self._construct_command_packer(command_packer) def __repr__(self): repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()]) return f"{self.__class__.__name__}<{repr_args}>" + @abstractmethod def repr_pieces(self): - pieces = [("host", self.host), ("port", self.port), ("db", self.db)] - if self.client_name: - pieces.append(("client_name", self.client_name)) - return pieces + pass def __del__(self): try: @@ -602,6 +229,14 @@ def __del__(self): except Exception: pass + def _construct_command_packer(self, packer): + if packer is not None: + return packer + elif HIREDIS_PACK_AVAILABLE: + return HiredisRespSerializer() + else: + return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode) + def register_connect_callback(self, callback): self._connect_callbacks.append(weakref.WeakMethod(callback)) @@ -649,80 +284,24 @@ def connect(self): if callback: callback(self) + @abstractmethod def _connect(self): - "Create a TCP socket connection" - # we want to mimic what socket.create_connection does to support - # ipv4/ipv6, but we want to set options prior to calling - # socket.connect() - err = None - for res in socket.getaddrinfo( - self.host, self.port, self.socket_type, socket.SOCK_STREAM - ): - family, socktype, proto, canonname, socket_address = res - sock = None - try: - sock = socket.socket(family, socktype, proto) - # TCP_NODELAY - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - - # TCP_KEEPALIVE - if self.socket_keepalive: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - for k, v in self.socket_keepalive_options.items(): - sock.setsockopt(socket.IPPROTO_TCP, k, v) - - # set the socket_connect_timeout before we connect - sock.settimeout(self.socket_connect_timeout) - - # connect - sock.connect(socket_address) - - # set the socket_timeout now that we're connected - sock.settimeout(self.socket_timeout) - return sock - - except OSError as _: - err = _ - if sock is not None: - sock.close() - - if err is not None: - raise err - raise OSError("socket.getaddrinfo returned an empty list") + pass + @abstractmethod def _host_error(self): - try: - host_error = f"{self.host}:{self.port}" - except AttributeError: - host_error = "connection" - - return host_error + pass + @abstractmethod def _error_message(self, exception): - # args for socket.error can either be (errno, "message") 
- # or just "message" - - host_error = self._host_error() - - if len(exception.args) == 1: - try: - return f"Error connecting to {host_error}. \ - {exception.args[0]}." - except AttributeError: - return f"Connection Error: {exception.args[0]}" - else: - try: - return ( - f"Error {exception.args[0]} connecting to " - f"{host_error}. {exception.args[1]}." - ) - except AttributeError: - return f"Connection Error: {exception.args[0]}" + pass def on_connect(self): "Initialize the connection, authenticate and select a database" self._parser.on_connect(self) + parser = self._parser + auth_args = None # if credential provider or username and/or password are set, authenticate if self.credential_provider or (self.username or self.password): cred_provider = ( @@ -730,6 +309,24 @@ def on_connect(self): or UsernamePasswordCredentialProvider(self.username, self.password) ) auth_args = cred_provider.get_credentials() + + # if resp version is specified and we have auth args, + # we need to send them via HELLO + if auth_args and self.protocol not in [2, "2"]: + if isinstance(self._parser, _RESP2Parser): + self.set_parser(_RESP3Parser) + # update cluster exception classes + self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES + self._parser.on_connect(self) + if len(auth_args) == 1: + auth_args = ["default", auth_args[0]] + self.send_command("HELLO", self.protocol, "AUTH", *auth_args) + response = self.read_response() + # if response.get(b"proto") != self.protocol and response.get( + # "proto" + # ) != self.protocol: + # raise ConnectionError("Invalid RESP version") + elif auth_args: # avoid checking health here -- PING will fail if we try # to check the health prior to the AUTH self.send_command("AUTH", *auth_args, check_health=False) @@ -747,12 +344,42 @@ def on_connect(self): if str_if_bytes(auth_response) != "OK": raise AuthenticationError("Invalid Username or Password") + # if resp version is specified, switch to it + elif self.protocol not in [2, "2"]: + if isinstance(self._parser, _RESP2Parser): + self.set_parser(_RESP3Parser) + # update cluster exception classes + self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES + self._parser.on_connect(self) + self.send_command("HELLO", self.protocol) + response = self.read_response() + if ( + response.get(b"proto") != self.protocol + and response.get("proto") != self.protocol + ): + raise ConnectionError("Invalid RESP version") + # if a client_name is given, set it if self.client_name: self.send_command("CLIENT", "SETNAME", self.client_name) if str_if_bytes(self.read_response()) != "OK": raise ConnectionError("Error setting client name") + try: + # set the library name and version + if self.lib_name: + self.send_command("CLIENT", "SETINFO", "LIB-NAME", self.lib_name) + self.read_response() + except ResponseError: + pass + + try: + if self.lib_version: + self.send_command("CLIENT", "SETINFO", "LIB-VER", self.lib_version) + self.read_response() + except ResponseError: + pass + # if a database is specified, switch to it if self.db: self.send_command("SELECT", self.db) @@ -762,20 +389,22 @@ def on_connect(self): def disconnect(self, *args): "Disconnects from the Redis server" self._parser.on_disconnect() - if self._sock is None: + + conn_sock = self._sock + self._sock = None + if conn_sock is None: return if os.getpid() == self.pid: try: - self._sock.shutdown(socket.SHUT_RDWR) + conn_sock.shutdown(socket.SHUT_RDWR) except OSError: pass try: - self._sock.close() + conn_sock.close() except OSError: pass - self._sock = None def _send_ping(self): """Send 
PING, expect PONG in return""" @@ -815,14 +444,19 @@ def send_packed_command(self, command, check_health=True): errno = e.args[0] errmsg = e.args[1] raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.") - except Exception: + except BaseException: + # BaseExceptions can be raised when a socket send operation is not + # finished, e.g. due to a timeout. Ideally, a caller could then re-try + # to send un-sent data. However, the send_packed_command() API + # does not support it so there is no point in keeping the connection open. self.disconnect() raise def send_command(self, *args, **kwargs): """Pack and send a command to the Redis server""" self.send_packed_command( - self.pack_command(*args), check_health=kwargs.get("check_health", True) + self._command_packer.pack(*args), + check_health=kwargs.get("check_health", True), ) def can_read(self, timeout=0): @@ -839,23 +473,40 @@ def can_read(self, timeout=0): self.disconnect() raise ConnectionError(f"Error while reading from {host_error}: {e.args}") - def read_response(self, disable_decoding=False): + def read_response( + self, + disable_decoding=False, + *, + disconnect_on_error=True, + push_request=False, + ): """Read the response from a previously sent command""" host_error = self._host_error() try: - response = self._parser.read_response(disable_decoding=disable_decoding) + if self.protocol in ["3", 3] and not HIREDIS_AVAILABLE: + response = self._parser.read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + else: + response = self._parser.read_response(disable_decoding=disable_decoding) except socket.timeout: - self.disconnect() + if disconnect_on_error: + self.disconnect() raise TimeoutError(f"Timeout reading from {host_error}") except OSError as e: - self.disconnect() + if disconnect_on_error: + self.disconnect() raise ConnectionError( f"Error while reading from {host_error}" f" : {e.args}" ) - except Exception: - self.disconnect() + except BaseException: + # Also by default close in case of BaseException. A lot of code + # relies on this behaviour when doing Command/Response pairs. + # See #1128. + if disconnect_on_error: + self.disconnect() raise if self.health_check_interval: @@ -867,48 +518,7 @@ def read_response(self, disable_decoding=False): def pack_command(self, *args): """Pack a series of arguments into the Redis protocol""" - output = [] - # the client might have included 1 or more literal arguments in - # the command name, e.g., 'CONFIG GET'. The Redis server expects these - # arguments to be sent separately, so split the first argument - # manually. These arguments should be bytestrings so that they are - # not encoded. 
- if isinstance(args[0], str): - args = tuple(args[0].encode().split()) + args[1:] - elif b" " in args[0]: - args = tuple(args[0].split()) + args[1:] - - buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF)) - - buffer_cutoff = self._buffer_cutoff - for arg in map(self.encoder.encode, args): - # to avoid large string mallocs, chunk the command into the - # output list if we're sending large values or memoryviews - arg_length = len(arg) - if ( - len(buff) > buffer_cutoff - or arg_length > buffer_cutoff - or isinstance(arg, memoryview) - ): - buff = SYM_EMPTY.join( - (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF) - ) - output.append(buff) - output.append(arg) - buff = SYM_CRLF - else: - buff = SYM_EMPTY.join( - ( - buff, - SYM_DOLLAR, - str(arg_length).encode(), - SYM_CRLF, - arg, - SYM_CRLF, - ) - ) - output.append(buff) - return output + return self._command_packer.pack(*args) def pack_commands(self, commands): """Pack multiple commands into the Redis protocol""" @@ -918,14 +528,15 @@ def pack_commands(self, commands): buffer_cutoff = self._buffer_cutoff for cmd in commands: - for chunk in self.pack_command(*cmd): + for chunk in self._command_packer.pack(*cmd): chunklen = len(chunk) if ( buffer_length > buffer_cutoff or chunklen > buffer_cutoff or isinstance(chunk, memoryview) ): - output.append(SYM_EMPTY.join(pieces)) + if pieces: + output.append(SYM_EMPTY.join(pieces)) buffer_length = 0 pieces = [] @@ -940,6 +551,97 @@ def pack_commands(self, commands): return output +class Connection(AbstractConnection): + "Manages TCP communication to and from a Redis server" + + def __init__( + self, + host="localhost", + port=6379, + socket_keepalive=False, + socket_keepalive_options=None, + socket_type=0, + **kwargs, + ): + self.host = host + self.port = int(port) + self.socket_keepalive = socket_keepalive + self.socket_keepalive_options = socket_keepalive_options or {} + self.socket_type = socket_type + super().__init__(**kwargs) + + def repr_pieces(self): + pieces = [("host", self.host), ("port", self.port), ("db", self.db)] + if self.client_name: + pieces.append(("client_name", self.client_name)) + return pieces + + def _connect(self): + "Create a TCP socket connection" + # we want to mimic what socket.create_connection does to support + # ipv4/ipv6, but we want to set options prior to calling + # socket.connect() + err = None + for res in socket.getaddrinfo( + self.host, self.port, self.socket_type, socket.SOCK_STREAM + ): + family, socktype, proto, canonname, socket_address = res + sock = None + try: + sock = socket.socket(family, socktype, proto) + # TCP_NODELAY + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + # TCP_KEEPALIVE + if self.socket_keepalive: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + for k, v in self.socket_keepalive_options.items(): + sock.setsockopt(socket.IPPROTO_TCP, k, v) + + # set the socket_connect_timeout before we connect + sock.settimeout(self.socket_connect_timeout) + + # connect + sock.connect(socket_address) + + # set the socket_timeout now that we're connected + sock.settimeout(self.socket_timeout) + return sock + + except OSError as _: + err = _ + if sock is not None: + sock.close() + + if err is not None: + raise err + raise OSError("socket.getaddrinfo returned an empty list") + + def _host_error(self): + return f"{self.host}:{self.port}" + + def _error_message(self, exception): + # args for socket.error can either be (errno, "message") + # or just "message" + + host_error = self._host_error() + + if 
len(exception.args) == 1: + try: + return f"Error connecting to {host_error}. \ + {exception.args[0]}." + except AttributeError: + return f"Connection Error: {exception.args[0]}" + else: + try: + return ( + f"Error {exception.args[0]} connecting to " + f"{host_error}. {exception.args[1]}." + ) + except AttributeError: + return f"Connection Error: {exception.args[0]}" + + class SSLConnection(Connection): """Manages SSL connections to and from the Redis server(s). This class extends the Connection class, adding SSL functionality, and making @@ -982,11 +684,9 @@ def __init__( Raises: RedisError """ # noqa - if not ssl_available: + if not SSL_AVAILABLE: raise RedisError("Python wasn't built with SSL support") - super().__init__(**kwargs) - self.keyfile = ssl_keyfile self.certfile = ssl_certfile if ssl_cert_reqs is None: @@ -1012,6 +712,7 @@ def __init__( self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled self.ssl_ocsp_context = ssl_ocsp_context self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert + super().__init__(**kwargs) def _connect(self): "Wrap the socket with SSL support" @@ -1081,75 +782,13 @@ def _connect(self): return sslsock -class UnixDomainSocketConnection(Connection): - def __init__( - self, - path="", - db=0, - username=None, - password=None, - socket_timeout=None, - encoding="utf-8", - encoding_errors="strict", - decode_responses=False, - retry_on_timeout=False, - retry_on_error=SENTINEL, - parser_class=DefaultParser, - socket_read_size=65536, - health_check_interval=0, - client_name=None, - retry=None, - redis_connect_func=None, - credential_provider: Optional[CredentialProvider] = None, - ): - """ - Initialize a new UnixDomainSocketConnection. - To specify a retry policy for specific errors, first set - `retry_on_error` to a list of the error/s to retry on, then set - `retry` to a valid `Retry` object. - To retry on TimeoutError, `retry_on_timeout` can also be set to `True`. - """ - if (username or password) and credential_provider is not None: - raise DataError( - "'username' and 'password' cannot be passed along with 'credential_" - "provider'. Please provide only one of the following arguments: \n" - "1. 'password' and (optional) 'username'\n" - "2. 
'credential_provider'" - ) - self.pid = os.getpid() +class UnixDomainSocketConnection(AbstractConnection): + "Manages UDS communication to and from a Redis server" + + def __init__(self, path="", socket_timeout=None, **kwargs): self.path = path - self.db = db - self.client_name = client_name - self.credential_provider = credential_provider - self.password = password - self.username = username self.socket_timeout = socket_timeout - self.retry_on_timeout = retry_on_timeout - if retry_on_error is SENTINEL: - retry_on_error = [] - if retry_on_timeout: - # Add TimeoutError to the errors list to retry on - retry_on_error.append(TimeoutError) - self.retry_on_error = retry_on_error - if self.retry_on_error: - if retry is None: - self.retry = Retry(NoBackoff(), 1) - else: - # deep-copy the Retry object as it is mutable - self.retry = copy.deepcopy(retry) - # Update the retry's supported errors with the specified errors - self.retry.update_supported_errors(retry_on_error) - else: - self.retry = Retry(NoBackoff(), 0) - self.health_check_interval = health_check_interval - self.next_health_check = 0 - self.redis_connect_func = redis_connect_func - self.encoder = Encoder(encoding, encoding_errors, decode_responses) - self._sock = None - self._socket_read_size = socket_read_size - self.set_parser(parser_class) - self._connect_callbacks = [] - self._buffer_cutoff = 6000 + super().__init__(**kwargs) def repr_pieces(self): pieces = [("path", self.path), ("db", self.db)] @@ -1160,19 +799,26 @@ def repr_pieces(self): def _connect(self): "Create a Unix domain socket connection" sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - sock.settimeout(self.socket_timeout) + sock.settimeout(self.socket_connect_timeout) sock.connect(self.path) + sock.settimeout(self.socket_timeout) return sock + def _host_error(self): + return self.path + def _error_message(self, exception): # args for socket.error can either be (errno, "message") # or just "message" + host_error = self._host_error() if len(exception.args) == 1: - return f"Error connecting to unix socket: {self.path}. {exception.args[0]}." + return ( + f"Error connecting to unix socket: {host_error}. {exception.args[0]}." + ) else: return ( f"Error {exception.args[0]} connecting to unix socket: " - f"{self.path}. {exception.args[1]}." + f"{host_error}. {exception.args[1]}." ) diff --git a/redis/exceptions.py b/redis/exceptions.py index 8a8bf423eb..7cf15a7d07 100644 --- a/redis/exceptions.py +++ b/redis/exceptions.py @@ -49,6 +49,18 @@ class NoScriptError(ResponseError): pass +class OutOfMemoryError(ResponseError): + """ + Indicates the database is full. Can only occur when either: + * Redis maxmemory-policy=noeviction + * Redis maxmemory-policy=volatile* and there are no evictable keys + + For more information see `Memory optimization in Redis `_. # noqa + """ + + pass + + class ExecAbortError(ResponseError): pass @@ -131,6 +143,7 @@ class AskError(ResponseError): pertain to this hash slot, but only if the key in question exists, otherwise the query is forwarded using a -ASK redirection to the node that is target of the migration. 
+ src node: MIGRATING to dst node get > ASK error ask dst node > ASKING command diff --git a/redis/ocsp.py b/redis/ocsp.py index ab8a35a33d..b0420b4711 100644 --- a/redis/ocsp.py +++ b/redis/ocsp.py @@ -15,7 +15,6 @@ from cryptography.hazmat.primitives.hashes import SHA1, Hash from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat from cryptography.x509 import ocsp - from redis.exceptions import AuthorizationError, ConnectionError diff --git a/redis/py.typed b/redis/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/redis/sentinel.py b/redis/sentinel.py index d70b7142b5..0ba179b9ca 100644 --- a/redis/sentinel.py +++ b/redis/sentinel.py @@ -1,5 +1,6 @@ import random import weakref +from typing import Optional from redis.client import Redis from redis.commands import SentinelCommands @@ -53,9 +54,14 @@ def _connect_retry(self): def connect(self): return self.retry.call_with_retry(self._connect_retry, lambda error: None) - def read_response(self, disable_decoding=False): + def read_response( + self, disable_decoding=False, *, disconnect_on_error: Optional[bool] = False + ): try: - return super().read_response(disable_decoding=disable_decoding) + return super().read_response( + disable_decoding=disable_decoding, + disconnect_on_error=disconnect_on_error, + ) except ReadOnlyError: if self.connection_pool.is_master: # When talking to a master, a ReadOnlyError when likely @@ -72,6 +78,54 @@ class SentinelManagedSSLConnection(SentinelManagedConnection, SSLConnection): pass +class SentinelConnectionPoolProxy: + def __init__( + self, + connection_pool, + is_master, + check_connection, + service_name, + sentinel_manager, + ): + self.connection_pool_ref = weakref.ref(connection_pool) + self.is_master = is_master + self.check_connection = check_connection + self.service_name = service_name + self.sentinel_manager = sentinel_manager + self.reset() + + def reset(self): + self.master_address = None + self.slave_rr_counter = None + + def get_master_address(self): + master_address = self.sentinel_manager.discover_master(self.service_name) + if self.is_master and self.master_address != master_address: + self.master_address = master_address + # disconnect any idle connections so that they reconnect + # to the new master the next time that they are used. + connection_pool = self.connection_pool_ref() + if connection_pool is not None: + connection_pool.disconnect(inuse_connections=False) + return master_address + + def rotate_slaves(self): + slaves = self.sentinel_manager.discover_slaves(self.service_name) + if slaves: + if self.slave_rr_counter is None: + self.slave_rr_counter = random.randint(0, len(slaves) - 1) + for _ in range(len(slaves)): + self.slave_rr_counter = (self.slave_rr_counter + 1) % len(slaves) + slave = slaves[self.slave_rr_counter] + yield slave + # Fallback to the master connection + try: + yield self.get_master_address() + except MasterNotFoundError: + pass + raise SlaveNotFoundError(f"No slave found for {self.service_name!r}") + + class SentinelConnectionPool(ConnectionPool): """ Sentinel backed connection pool. 
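Why the proxy indirection introduced above: previously each connection held a bare weakref.proxy of its pool, and any attribute access after the pool had been garbage-collected raised ReferenceError during connection teardown. SentinelConnectionPoolProxy is strongly referenced by the connections while itself keeping only a weakref.ref to the pool, so a dead pool can be detected instead of raising. A standalone sketch of the failure mode and of the fix's shape (assumes CPython's immediate refcount collection; Pool is a stand-in class, not redis-py code):

import weakref

class Pool:
    def disconnect(self):
        print("disconnecting idle connections")

pool = Pool()
handle = weakref.proxy(pool)
del pool  # the last strong reference is gone
try:
    handle.disconnect()  # the proxy now points at a dead object
except ReferenceError:
    print("weakly-referenced object no longer exists")

# The fix's shape: hold a weakref.ref and check liveness explicitly.
ref = weakref.ref(Pool())
live = ref()  # None once the referent has been collected
if live is not None:
    live.disconnect()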
@@ -89,8 +143,15 @@ def __init__(self, service_name, sentinel_manager, **kwargs): ) self.is_master = kwargs.pop("is_master", True) self.check_connection = kwargs.pop("check_connection", False) + self.proxy = SentinelConnectionPoolProxy( + connection_pool=self, + is_master=self.is_master, + check_connection=self.check_connection, + service_name=service_name, + sentinel_manager=sentinel_manager, + ) super().__init__(**kwargs) - self.connection_kwargs["connection_pool"] = weakref.proxy(self) + self.connection_kwargs["connection_pool"] = self.proxy self.service_name = service_name self.sentinel_manager = sentinel_manager @@ -100,8 +161,11 @@ def __repr__(self): def reset(self): super().reset() - self.master_address = None - self.slave_rr_counter = None + self.proxy.reset() + + @property + def master_address(self): + return self.proxy.master_address def owns_connection(self, connection): check = not self.is_master or ( @@ -111,31 +175,11 @@ def owns_connection(self, connection): return check and parent.owns_connection(connection) def get_master_address(self): - master_address = self.sentinel_manager.discover_master(self.service_name) - if self.is_master: - if self.master_address != master_address: - self.master_address = master_address - # disconnect any idle connections so that they reconnect - # to the new master the next time that they are used. - self.disconnect(inuse_connections=False) - return master_address + return self.proxy.get_master_address() def rotate_slaves(self): "Round-robin slave balancer" - slaves = self.sentinel_manager.discover_slaves(self.service_name) - if slaves: - if self.slave_rr_counter is None: - self.slave_rr_counter = random.randint(0, len(slaves) - 1) - for _ in range(len(slaves)): - self.slave_rr_counter = (self.slave_rr_counter + 1) % len(slaves) - slave = slaves[self.slave_rr_counter] - yield slave - # Fallback to the master connection - try: - yield self.get_master_address() - except MasterNotFoundError: - pass - raise SlaveNotFoundError(f"No slave found for {self.service_name!r}") + return self.proxy.rotate_slaves() class Sentinel(SentinelCommands): @@ -230,10 +274,12 @@ def discover_master(self, service_name): Returns a pair (address, port) or raises MasterNotFoundError if no master is found. 
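Connection or timeout errors raised while polling individual sentinels are collected and, when no master can be discovered at all, appended to the MasterNotFoundError message (see the implementation below), so an unreachable sentinel can be told apart from a genuinely missing master.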
""" + collected_errors = list() for sentinel_no, sentinel in enumerate(self.sentinels): try: masters = sentinel.sentinel_masters() - except (ConnectionError, TimeoutError): + except (ConnectionError, TimeoutError) as e: + collected_errors.append(f"{sentinel} - {e!r}") continue state = masters.get(service_name) if state and self.check_master_state(state, service_name): @@ -243,7 +289,11 @@ def discover_master(self, service_name): self.sentinels[0], ) return state["ip"], state["port"] - raise MasterNotFoundError(f"No master found for {service_name!r}") + + error_info = "" + if len(collected_errors) > 0: + error_info = f" : {', '.join(collected_errors)}" + raise MasterNotFoundError(f"No master found for {service_name!r}{error_info}") def filter_slaves(self, slaves): "Remove slaves that are in an ODOWN or SDOWN state" diff --git a/redis/typing.py b/redis/typing.py index 8504c7de0c..56a1e99ba7 100644 --- a/redis/typing.py +++ b/redis/typing.py @@ -1,14 +1,23 @@ # from __future__ import annotations from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Any, Awaitable, Iterable, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Iterable, + Mapping, + Type, + TypeVar, + Union, +) from redis.compat import Protocol if TYPE_CHECKING: + from redis._parsers import Encoder from redis.asyncio.connection import ConnectionPool as AsyncConnectionPool - from redis.asyncio.connection import Encoder as AsyncEncoder - from redis.connection import ConnectionPool, Encoder + from redis.connection import ConnectionPool Number = Union[int, float] @@ -39,6 +48,8 @@ AnyFieldT = TypeVar("AnyFieldT", bytes, str, memoryview) AnyChannelT = TypeVar("AnyChannelT", bytes, str, memoryview) +ExceptionMappingT = Mapping[str, Union[Type[Exception], Mapping[str, Type[Exception]]]] + class CommandsProtocol(Protocol): connection_pool: Union["AsyncConnectionPool", "ConnectionPool"] @@ -47,8 +58,8 @@ def execute_command(self, *args, **options): ... -class ClusterCommandsProtocol(CommandsProtocol): - encoder: Union["AsyncEncoder", "Encoder"] +class ClusterCommandsProtocol(CommandsProtocol, Protocol): + encoder: "Encoder" def execute_command(self, *args, **options) -> Union[Any, Awaitable]: ... diff --git a/redis/utils.py b/redis/utils.py index 693d4e64b5..01fdfed7a2 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -1,3 +1,5 @@ +import logging +import sys from contextlib import contextmanager from functools import wraps from typing import Any, Dict, Mapping, Union @@ -7,8 +9,17 @@ # Only support Hiredis >= 1.0: HIREDIS_AVAILABLE = not hiredis.__version__.startswith("0.") + HIREDIS_PACK_AVAILABLE = hasattr(hiredis, "pack_command") except ImportError: HIREDIS_AVAILABLE = False + HIREDIS_PACK_AVAILABLE = False + +try: + import ssl # noqa + + SSL_AVAILABLE = True +except ImportError: + SSL_AVAILABLE = False try: import cryptography # noqa @@ -17,6 +28,11 @@ except ImportError: CRYPTOGRAPHY_AVAILABLE = False +if sys.version_info >= (3, 8): + from importlib import metadata +else: + import importlib_metadata as metadata + def from_url(url, **kwargs): """ @@ -108,3 +124,24 @@ def wrapper(*args, **kwargs): return wrapper return decorator + + +def _set_info_logger(): + """ + Set up a logger that log info logs to stdout. 
+ (This is used by the default push response handler) + """ + if "push_response" not in logging.root.manager.loggerDict.keys(): + logger = logging.getLogger("push_response") + logger.setLevel(logging.INFO) + handler = logging.StreamHandler() + handler.setLevel(logging.INFO) + logger.addHandler(handler) + + +def get_lib_version(): + try: + libver = metadata.version("redis") + except metadata.PackageNotFoundError: + libver = "99.99.99" + return libver diff --git a/setup.py b/setup.py index 16a9156641..475e3565fa 100644 --- a/setup.py +++ b/setup.py @@ -8,10 +8,11 @@ long_description_content_type="text/markdown", keywords=["Redis", "key-value store", "database"], license="MIT", - version="4.4.2", + version="5.0.0", packages=find_packages( include=[ "redis", + "redis._parsers", "redis.asyncio", "redis.commands", "redis.commands.bf", @@ -19,8 +20,11 @@ "redis.commands.search", "redis.commands.timeseries", "redis.commands.graph", + "redis.parsers", ] ), + package_data={"redis": ["py.typed"]}, + include_package_data=True, url="https://github.com/redis/redis-py", project_urls={ "Documentation": "https://redis.readthedocs.io/en/latest/", @@ -34,7 +38,7 @@ install_requires=[ 'importlib-metadata >= 1.0; python_version < "3.8"', 'typing-extensions; python_version<"3.8"', - "async-timeout>=4.0.2", + 'async-timeout>=4.0.2; python_full_version<="3.11.2"', ], classifiers=[ "Development Status :: 5 - Production/Stable", diff --git a/tasks.py b/tasks.py index 64b3aef80f..c60fa2791e 100644 --- a/tasks.py +++ b/tasks.py @@ -1,69 +1,81 @@ +# https://github.com/pyinvoke/invoke/issues/833 +import inspect import os import shutil from invoke import run, task -with open("tox.ini") as fp: - lines = fp.read().split("\n") - dockers = [line.split("=")[1].strip() for line in lines if line.find("name") != -1] +if not hasattr(inspect, "getargspec"): + inspect.getargspec = inspect.getfullargspec @task def devenv(c): - """Builds a development environment: downloads, and starts all dockers specified in the tox.ini file. - """ + """Brings up the test environment by wrapping docker compose.""" clean(c) - cmd = "tox -e devenv" - for d in dockers: - cmd += f" --docker-dont-stop={d}" + cmd = "docker-compose --profile all up -d" run(cmd) @task def build_docs(c): """Generates the sphinx documentation.""" - run("tox -e docs") + run("pip install -r docs/requirements.txt") + run("make -C docs html") @task def linters(c): """Run code linters""" - run("tox -e linters") + run("flake8 tests redis") + run("black --target-version py37 --check --diff tests redis") + run("isort --check-only --diff tests redis") + run("vulture redis whitelist.py --min-confidence 80") + run("flynt --fail-on-change --dry-run tests redis") @task def all_tests(c): - """Run all linters, and tests in redis-py. This assumes you have all the python versions specified in the tox.ini file. - """ + """Run all linters, and tests in redis-py.""" linters(c) tests(c) @task -def tests(c): +def tests(c, uvloop=False, protocol=2): """Run the redis-py test suite against the current python, with and without hiredis.
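The uvloop flag reruns the suite on the uvloop event loop and protocol selects the RESP version passed to pytest's --protocol option (defaults: uvloop=False, protocol=2; pass protocol=3 to exercise the RESP3 code paths).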
""" print("Starting Redis tests") - run("tox -e '{standalone,cluster}'-'{plain,hiredis}'") + standalone_tests(c, uvloop=uvloop, protocol=protocol) + cluster_tests(c, uvloop=uvloop, protocol=protocol) @task -def standalone_tests(c): - """Run all Redis tests against the current python, - with and without hiredis.""" - print("Starting Redis tests") - run("tox -e standalone-'{plain,hiredis,ocsp}'") +def standalone_tests(c, uvloop=False, protocol=2): + """Run tests against a standalone redis instance""" + if uvloop: + run( + f"pytest --protocol={protocol} --cov=./ --cov-report=xml:coverage_redis.xml -W always -m 'not onlycluster' --uvloop --junit-xml=standalone-uvloop-results.xml" + ) + else: + run( + f"pytest --protocol={protocol} --cov=./ --cov-report=xml:coverage_redis.xml -W always -m 'not onlycluster' --junit-xml=standalone-results.xml" + ) @task -def cluster_tests(c): - """Run all Redis Cluster tests against the current python, - with and without hiredis.""" - print("Starting RedisCluster tests") - run("tox -e cluster-'{plain,hiredis}'") +def cluster_tests(c, uvloop=False, protocol=2): + """Run tests against a redis cluster""" + cluster_url = "redis://localhost:16379/0" + if uvloop: + run( + f"pytest --protocol={protocol} --cov=./ --cov-report=xml:coverage_cluster.xml -W always -m 'not onlynoncluster and not redismod' --redis-url={cluster_url} --junit-xml=cluster-uvloop-results.xml --uvloop" + ) + else: + run( + f"pytest --protocol={protocol} --cov=./ --cov-report=xml:coverage_clusteclient.xml -W always -m 'not onlynoncluster and not redismod' --redis-url={cluster_url} --junit-xml=cluster-results.xml" + ) @task @@ -73,7 +85,7 @@ def clean(c): shutil.rmtree("build") if os.path.isdir("dist"): shutil.rmtree("dist") - run(f"docker rm -f {' '.join(dockers)}") + run("docker-compose --profile all rm -s -f") @task diff --git a/tests/conftest.py b/tests/conftest.py index 27dcc741a7..16f3fbb9db 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,18 +6,17 @@ from urllib.parse import urlparse import pytest -from packaging.version import Version - import redis +from packaging.version import Version from redis.backoff import NoBackoff from redis.connection import parse_url from redis.exceptions import RedisClusterException from redis.retry import Retry REDIS_INFO = {} -default_redis_url = "redis://localhost:6379/9" -default_redismod_url = "redis://localhost:36379" -default_redis_unstable_url = "redis://localhost:6378" +default_redis_url = "redis://localhost:6379/0" +default_protocol = "2" +default_redismod_url = "redis://localhost:6379" # default ssl client ignores verification for the purpose of testing default_redis_ssl_url = "rediss://localhost:6666" @@ -73,6 +72,7 @@ def format_usage(self): def pytest_addoption(parser): + parser.addoption( "--redis-url", default=default_redis_url, @@ -81,14 +81,11 @@ def pytest_addoption(parser): ) parser.addoption( - "--redismod-url", - default=default_redismod_url, + "--protocol", + default=default_protocol, action="store", - help="Connection string to redis server" - " with loaded modules," - " defaults to `%(default)s`", + help="Protocol version, defaults to `%(default)s`", ) - parser.addoption( "--redis-ssl-url", default=default_redis_ssl_url, @@ -105,13 +102,6 @@ def pytest_addoption(parser): " defaults to `%(default)s`", ) - parser.addoption( - "--redis-unstable-url", - default=default_redis_unstable_url, - action="store", - help="Redis unstable (latest version) connection string " - "defaults to %(default)s`", - ) parser.addoption( "--uvloop", 
action=BooleanOptionalAction, help="Run tests with uvloop" ) @@ -141,6 +131,7 @@ def pytest_sessionstart(session): enterprise = info["enterprise"] except redis.ConnectionError: # provide optimistic defaults + info = {} version = "10.0.0" arch_bits = 64 cluster_enabled = False @@ -152,14 +143,10 @@ def pytest_sessionstart(session): # store REDIS_INFO in config so that it is available from "condition strings" session.config.REDIS_INFO = REDIS_INFO - # module info, if the second redis is running + # module info try: - redismod_url = session.config.getoption("--redismod-url") - info = _get_info(redismod_url) REDIS_INFO["modules"] = info["modules"] - except redis.exceptions.ConnectionError: - pass - except KeyError: + except (KeyError, redis.exceptions.ConnectionError): pass if cluster_enabled: @@ -289,6 +276,9 @@ def _get_client( redis_url = request.config.getoption("--redis-url") else: redis_url = from_url + if "protocol" not in redis_url: + kwargs["protocol"] = request.config.getoption("--protocol") + cluster_mode = REDIS_INFO["cluster_enabled"] if not cluster_mode: url_options = parse_url(redis_url) @@ -332,20 +322,15 @@ def cluster_teardown(client, flushdb): client.disconnect_connection_pools() -# specifically set to the zero database, because creating -# an index on db != 0 raises a ResponseError in redis @pytest.fixture() -def modclient(request, **kwargs): - rmurl = request.config.getoption("--redismod-url") - with _get_client( - redis.Redis, request, from_url=rmurl, decode_responses=True, **kwargs - ) as client: +def r(request): + with _get_client(redis.Redis, request) as client: yield client @pytest.fixture() -def r(request): - with _get_client(redis.Redis, request) as client: +def decoded_r(request): + with _get_client(redis.Redis, request, decode_responses=True) as client: yield client @@ -385,7 +370,7 @@ def mock_cluster_resp_ok(request, **kwargs): @pytest.fixture() def mock_cluster_resp_int(request, **kwargs): r = _get_client(redis.Redis, request, **kwargs) - return _gen_cluster_mock_resp(r, "2") + return _gen_cluster_mock_resp(r, 2) @pytest.fixture() @@ -441,16 +426,7 @@ def mock_cluster_resp_slaves(request, **kwargs): def master_host(request): url = request.config.getoption("--redis-url") parts = urlparse(url) - yield parts.hostname, parts.port - - -@pytest.fixture() -def unstable_r(request): - url = request.config.getoption("--redis-unstable-url") - with _get_client( - redis.Redis, request, from_url=url, decode_responses=True - ) as client: - yield client + return parts.hostname, (parts.port or 6379) def wait_for_command(client, monitor, command, key=None): @@ -472,3 +448,34 @@ def wait_for_command(client, monitor, command, key=None): return monitor_response if key in monitor_response["command"]: return None + + +def is_resp2_connection(r): + if isinstance(r, redis.Redis) or isinstance(r, redis.asyncio.Redis): + protocol = r.connection_pool.connection_kwargs.get("protocol") + elif isinstance(r, redis.cluster.AbstractRedisCluster): + protocol = r.nodes_manager.connection_kwargs.get("protocol") + return protocol in ["2", 2, None] + + +def get_protocol_version(r): + if isinstance(r, redis.Redis) or isinstance(r, redis.asyncio.Redis): + return r.connection_pool.connection_kwargs.get("protocol") + elif isinstance(r, redis.cluster.AbstractRedisCluster): + return r.nodes_manager.connection_kwargs.get("protocol") + + +def assert_resp_response(r, response, resp2_expected, resp3_expected): + protocol = get_protocol_version(r) + if protocol in [2, "2", None]: + assert response == 
resp2_expected + else: + assert response == resp3_expected + + +def assert_resp_response_in(r, response, resp2_expected, resp3_expected): + protocol = get_protocol_version(r) + if protocol in [2, "2", None]: + assert response in resp2_expected + else: + assert response in resp3_expected diff --git a/tests/ssl_utils.py b/tests/ssl_utils.py new file mode 100644 index 0000000000..ab9c2e8944 --- /dev/null +++ b/tests/ssl_utils.py @@ -0,0 +1,14 @@ +import os + + +def get_ssl_filename(name): + root = os.path.join(os.path.dirname(__file__), "..") + cert_dir = os.path.abspath(os.path.join(root, "dockers", "stunnel", "keys")) + if not os.path.isdir(cert_dir): # github actions package validation case + cert_dir = os.path.abspath( + os.path.join(root, "..", "dockers", "stunnel", "keys") + ) + if not os.path.isdir(cert_dir): + raise IOError(f"No SSL certificates found. They should be in {cert_dir}") + + return os.path.join(cert_dir, name) diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index 6982cc840a..c837f284f7 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -1,22 +1,17 @@ import random from contextlib import asynccontextmanager as _asynccontextmanager from typing import Union -from urllib.parse import urlparse import pytest import pytest_asyncio -from packaging.version import Version - import redis.asyncio as redis +from packaging.version import Version +from redis._parsers import _AsyncHiredisParser, _AsyncRESP2Parser from redis.asyncio.client import Monitor -from redis.asyncio.connection import ( - HIREDIS_AVAILABLE, - HiredisParser, - PythonParser, - parse_url, -) +from redis.asyncio.connection import parse_url from redis.asyncio.retry import Retry from redis.backoff import NoBackoff +from redis.utils import HIREDIS_AVAILABLE from tests.conftest import REDIS_INFO from .compat import mock @@ -32,14 +27,14 @@ async def _get_info(redis_url): @pytest_asyncio.fixture( params=[ pytest.param( - (True, PythonParser), + (True, _AsyncRESP2Parser), marks=pytest.mark.skipif( 'config.REDIS_INFO["cluster_enabled"]', reason="cluster mode enabled" ), ), - (False, PythonParser), + (False, _AsyncRESP2Parser), pytest.param( - (True, HiredisParser), + (True, _AsyncHiredisParser), marks=[ pytest.mark.skipif( 'config.REDIS_INFO["cluster_enabled"]', @@ -51,7 +46,7 @@ async def _get_info(redis_url): ], ), pytest.param( - (False, HiredisParser), + (False, _AsyncHiredisParser), marks=pytest.mark.skipif( not HIREDIS_AVAILABLE, reason="hiredis is not installed" ), @@ -74,8 +69,12 @@ async def client_factory( url: str = request.config.getoption("--redis-url"), cls=redis.Redis, flushdb=True, + protocol=request.config.getoption("--protocol"), **kwargs, ): + if "protocol" not in url: + kwargs["protocol"] = request.config.getoption("--protocol") + cluster_mode = REDIS_INFO["cluster_enabled"] if not cluster_mode: single = kwargs.pop("single_connection_client", False) or single_connection @@ -134,10 +133,8 @@ async def r2(create_redis): @pytest_asyncio.fixture() -async def modclient(request, create_redis): - return await create_redis( - url=request.config.getoption("--redismod-url"), decode_responses=True - ) +async def decoded_r(create_redis): + return await create_redis(decode_responses=True) def _gen_cluster_mock_resp(r, response): @@ -157,7 +154,7 @@ async def mock_cluster_resp_ok(create_redis, **kwargs): @pytest_asyncio.fixture() async def mock_cluster_resp_int(create_redis, **kwargs): r = await create_redis(**kwargs) - return _gen_cluster_mock_resp(r, "2") + 
return _gen_cluster_mock_resp(r, 2) @pytest_asyncio.fixture() @@ -209,13 +206,6 @@ async def mock_cluster_resp_slaves(create_redis, **kwargs): return _gen_cluster_mock_resp(r, response) -@pytest_asyncio.fixture(scope="session") -def master_host(request): - url = request.config.getoption("--redis-url") - parts = urlparse(url) - return parts.hostname - - async def wait_for_command( client: redis.Redis, monitor: Monitor, command: str, key: Union[str, None] = None ): @@ -263,3 +253,15 @@ async def __aexit__(self, exc_type, exc_inst, tb): def asynccontextmanager(func): return _asynccontextmanager(func) + + +# helpers to get the connection arguments for this run +@pytest.fixture() +def redis_url(request): + return request.config.getoption("--redis-url") + + +@pytest.fixture() +def connect_args(request): + url = request.config.getoption("--redis-url") + return parse_url(url) diff --git a/tests/test_asyncio/test_bloom.py b/tests/test_asyncio/test_bloom.py index 9f4a805c4c..d0a25e5625 100644 --- a/tests/test_asyncio/test_bloom.py +++ b/tests/test_asyncio/test_bloom.py @@ -1,94 +1,99 @@ from math import inf import pytest - import redis.asyncio as redis from redis.exceptions import ModuleError, RedisError from redis.utils import HIREDIS_AVAILABLE -from tests.conftest import skip_ifmodversion_lt +from tests.conftest import ( + assert_resp_response, + is_resp2_connection, + skip_ifmodversion_lt, +) def intlist(obj): return [int(v) for v in obj] -# @pytest.fixture -# async def client(modclient): -# assert isinstance(modawait modclient.bf(), redis.commands.bf.BFBloom) -# assert isinstance(modawait modclient.cf(), redis.commands.bf.CFBloom) -# assert isinstance(modawait modclient.cms(), redis.commands.bf.CMSBloom) -# assert isinstance(modawait modclient.tdigest(), redis.commands.bf.TDigestBloom) -# assert isinstance(modawait modclient.topk(), redis.commands.bf.TOPKBloom) - -# modawait modclient.flushdb() -# return modclient - - @pytest.mark.redismod -async def test_create(modclient: redis.Redis): +async def test_create(decoded_r: redis.Redis): """Test CREATE/RESERVE calls""" - assert await modclient.bf().create("bloom", 0.01, 1000) - assert await modclient.bf().create("bloom_e", 0.01, 1000, expansion=1) - assert await modclient.bf().create("bloom_ns", 0.01, 1000, noScale=True) - assert await modclient.cf().create("cuckoo", 1000) - assert await modclient.cf().create("cuckoo_e", 1000, expansion=1) - assert await modclient.cf().create("cuckoo_bs", 1000, bucket_size=4) - assert await modclient.cf().create("cuckoo_mi", 1000, max_iterations=10) - assert await modclient.cms().initbydim("cmsDim", 100, 5) - assert await modclient.cms().initbyprob("cmsProb", 0.01, 0.01) - assert await modclient.topk().reserve("topk", 5, 100, 5, 0.9) + assert await decoded_r.bf().create("bloom", 0.01, 1000) + assert await decoded_r.bf().create("bloom_e", 0.01, 1000, expansion=1) + assert await decoded_r.bf().create("bloom_ns", 0.01, 1000, noScale=True) + assert await decoded_r.cf().create("cuckoo", 1000) + assert await decoded_r.cf().create("cuckoo_e", 1000, expansion=1) + assert await decoded_r.cf().create("cuckoo_bs", 1000, bucket_size=4) + assert await decoded_r.cf().create("cuckoo_mi", 1000, max_iterations=10) + assert await decoded_r.cms().initbydim("cmsDim", 100, 5) + assert await decoded_r.cms().initbyprob("cmsProb", 0.01, 0.01) + assert await decoded_r.topk().reserve("topk", 5, 100, 5, 0.9) @pytest.mark.redismod @pytest.mark.experimental -async def test_tdigest_create(modclient: redis.Redis): - assert await 
modclient.tdigest().create("tDigest", 100) +async def test_tdigest_create(decoded_r: redis.Redis): + assert await decoded_r.tdigest().create("tDigest", 100) -# region Test Bloom Filter @pytest.mark.redismod -async def test_bf_add(modclient: redis.Redis): - assert await modclient.bf().create("bloom", 0.01, 1000) - assert 1 == await modclient.bf().add("bloom", "foo") - assert 0 == await modclient.bf().add("bloom", "foo") - assert [0] == intlist(await modclient.bf().madd("bloom", "foo")) - assert [0, 1] == await modclient.bf().madd("bloom", "foo", "bar") - assert [0, 0, 1] == await modclient.bf().madd("bloom", "foo", "bar", "baz") - assert 1 == await modclient.bf().exists("bloom", "foo") - assert 0 == await modclient.bf().exists("bloom", "noexist") - assert [1, 0] == intlist(await modclient.bf().mexists("bloom", "foo", "noexist")) +async def test_bf_add(decoded_r: redis.Redis): + assert await decoded_r.bf().create("bloom", 0.01, 1000) + assert 1 == await decoded_r.bf().add("bloom", "foo") + assert 0 == await decoded_r.bf().add("bloom", "foo") + assert [0] == intlist(await decoded_r.bf().madd("bloom", "foo")) + assert [0, 1] == await decoded_r.bf().madd("bloom", "foo", "bar") + assert [0, 0, 1] == await decoded_r.bf().madd("bloom", "foo", "bar", "baz") + assert 1 == await decoded_r.bf().exists("bloom", "foo") + assert 0 == await decoded_r.bf().exists("bloom", "noexist") + assert [1, 0] == intlist(await decoded_r.bf().mexists("bloom", "foo", "noexist")) @pytest.mark.redismod -async def test_bf_insert(modclient: redis.Redis): - assert await modclient.bf().create("bloom", 0.01, 1000) - assert [1] == intlist(await modclient.bf().insert("bloom", ["foo"])) - assert [0, 1] == intlist(await modclient.bf().insert("bloom", ["foo", "bar"])) - assert [1] == intlist(await modclient.bf().insert("captest", ["foo"], capacity=10)) - assert [1] == intlist(await modclient.bf().insert("errtest", ["foo"], error=0.01)) - assert 1 == await modclient.bf().exists("bloom", "foo") - assert 0 == await modclient.bf().exists("bloom", "noexist") - assert [1, 0] == intlist(await modclient.bf().mexists("bloom", "foo", "noexist")) - info = await modclient.bf().info("bloom") - assert 2 == info.insertedNum - assert 1000 == info.capacity - assert 1 == info.filterNum +async def test_bf_insert(decoded_r: redis.Redis): + assert await decoded_r.bf().create("bloom", 0.01, 1000) + assert [1] == intlist(await decoded_r.bf().insert("bloom", ["foo"])) + assert [0, 1] == intlist(await decoded_r.bf().insert("bloom", ["foo", "bar"])) + assert [1] == intlist(await decoded_r.bf().insert("captest", ["foo"], capacity=10)) + assert [1] == intlist(await decoded_r.bf().insert("errtest", ["foo"], error=0.01)) + assert 1 == await decoded_r.bf().exists("bloom", "foo") + assert 0 == await decoded_r.bf().exists("bloom", "noexist") + assert [1, 0] == intlist(await decoded_r.bf().mexists("bloom", "foo", "noexist")) + info = await decoded_r.bf().info("bloom") + assert_resp_response( + decoded_r, + 2, + info.get("insertedNum"), + info.get("Number of items inserted"), + ) + assert_resp_response( + decoded_r, + 1000, + info.get("capacity"), + info.get("Capacity"), + ) + assert_resp_response( + decoded_r, + 1, + info.get("filterNum"), + info.get("Number of filters"), + ) @pytest.mark.redismod -async def test_bf_scandump_and_loadchunk(modclient: redis.Redis): +async def test_bf_scandump_and_loadchunk(decoded_r: redis.Redis): # Store a filter - await modclient.bf().create("myBloom", "0.0001", "1000") + await decoded_r.bf().create("myBloom", "0.0001", "1000") # 
test is probabilistic and might fail. It is OK to change variables if # certain to not break anything async def do_verify(): res = 0 for x in range(1000): - await modclient.bf().add("myBloom", x) - rv = await modclient.bf().exists("myBloom", x) + await decoded_r.bf().add("myBloom", x) + rv = await decoded_r.bf().exists("myBloom", x) assert rv - rv = await modclient.bf().exists("myBloom", f"nonexist_{x}") + rv = await decoded_r.bf().exists("myBloom", f"nonexist_{x}") res += rv == x assert res < 5 @@ -96,52 +101,62 @@ async def do_verify(): cmds = [] if HIREDIS_AVAILABLE: with pytest.raises(ModuleError): - cur = await modclient.bf().scandump("myBloom", 0) + cur = await decoded_r.bf().scandump("myBloom", 0) return - cur = await modclient.bf().scandump("myBloom", 0) + cur = await decoded_r.bf().scandump("myBloom", 0) first = cur[0] cmds.append(cur) while True: - cur = await modclient.bf().scandump("myBloom", first) + cur = await decoded_r.bf().scandump("myBloom", first) first = cur[0] if first == 0: break else: cmds.append(cur) - prev_info = await modclient.bf().execute_command("bf.debug", "myBloom") + prev_info = await decoded_r.bf().execute_command("bf.debug", "myBloom") # Remove the filter - await modclient.bf().client.delete("myBloom") + await decoded_r.bf().client.delete("myBloom") # Now, load all the commands: for cmd in cmds: - await modclient.bf().loadchunk("myBloom", *cmd) + await decoded_r.bf().loadchunk("myBloom", *cmd) - cur_info = await modclient.bf().execute_command("bf.debug", "myBloom") + cur_info = await decoded_r.bf().execute_command("bf.debug", "myBloom") assert prev_info == cur_info await do_verify() - await modclient.bf().client.delete("myBloom") - await modclient.bf().create("myBloom", "0.0001", "10000000") + await decoded_r.bf().client.delete("myBloom") + await decoded_r.bf().create("myBloom", "0.0001", "10000000") @pytest.mark.redismod -async def test_bf_info(modclient: redis.Redis): +async def test_bf_info(decoded_r: redis.Redis): expansion = 4 # Store a filter - await modclient.bf().create("nonscaling", "0.0001", "1000", noScale=True) - info = await modclient.bf().info("nonscaling") - assert info.expansionRate is None + await decoded_r.bf().create("nonscaling", "0.0001", "1000", noScale=True) + info = await decoded_r.bf().info("nonscaling") + assert_resp_response( + decoded_r, + None, + info.get("expansionRate"), + info.get("Expansion rate"), + ) - await modclient.bf().create("expanding", "0.0001", "1000", expansion=expansion) - info = await modclient.bf().info("expanding") - assert info.expansionRate == 4 + await decoded_r.bf().create("expanding", "0.0001", "1000", expansion=expansion) + info = await decoded_r.bf().info("expanding") + assert_resp_response( + decoded_r, + 4, + info.get("expansionRate"), + info.get("Expansion rate"), + ) try: # noScale mean no expansion - await modclient.bf().create( + await decoded_r.bf().create( "myBloom", "0.0001", "1000", expansion=expansion, noScale=True ) assert False @@ -150,95 +165,96 @@ async def test_bf_info(modclient: redis.Redis): @pytest.mark.redismod -async def test_bf_card(modclient: redis.Redis): +async def test_bf_card(decoded_r: redis.Redis): # return 0 if the key does not exist - assert await modclient.bf().card("not_exist") == 0 + assert await decoded_r.bf().card("not_exist") == 0 # Store a filter - assert await modclient.bf().add("bf1", "item_foo") == 1 - assert await modclient.bf().card("bf1") == 1 + assert await decoded_r.bf().add("bf1", "item_foo") == 1 + assert await decoded_r.bf().card("bf1") == 1 - # Error 
when key is of a type other than Bloom filter. + # Error when key is of a type other than Bloom filter. with pytest.raises(redis.ResponseError): - await modclient.set("setKey", "value") - await modclient.bf().card("setKey") + await decoded_r.set("setKey", "value") + await decoded_r.bf().card("setKey") -# region Test Cuckoo Filter @pytest.mark.redismod -async def test_cf_add_and_insert(modclient: redis.Redis): - assert await modclient.cf().create("cuckoo", 1000) - assert await modclient.cf().add("cuckoo", "filter") - assert not await modclient.cf().addnx("cuckoo", "filter") - assert 1 == await modclient.cf().addnx("cuckoo", "newItem") - assert [1] == await modclient.cf().insert("captest", ["foo"]) - assert [1] == await modclient.cf().insert("captest", ["foo"], capacity=1000) - assert [1] == await modclient.cf().insertnx("captest", ["bar"]) - assert [1] == await modclient.cf().insertnx("captest", ["food"], nocreate="1") - assert [0, 0, 1] == await modclient.cf().insertnx("captest", ["foo", "bar", "baz"]) - assert [0] == await modclient.cf().insertnx("captest", ["bar"], capacity=1000) - assert [1] == await modclient.cf().insert("empty1", ["foo"], capacity=1000) - assert [1] == await modclient.cf().insertnx("empty2", ["bar"], capacity=1000) - info = await modclient.cf().info("captest") - assert 5 == info.insertedNum - assert 0 == info.deletedNum - assert 1 == info.filterNum +async def test_cf_add_and_insert(decoded_r: redis.Redis): + assert await decoded_r.cf().create("cuckoo", 1000) + assert await decoded_r.cf().add("cuckoo", "filter") + assert not await decoded_r.cf().addnx("cuckoo", "filter") + assert 1 == await decoded_r.cf().addnx("cuckoo", "newItem") + assert [1] == await decoded_r.cf().insert("captest", ["foo"]) + assert [1] == await decoded_r.cf().insert("captest", ["foo"], capacity=1000) + assert [1] == await decoded_r.cf().insertnx("captest", ["bar"]) + assert [1] == await decoded_r.cf().insertnx("captest", ["food"], nocreate="1") + assert [0, 0, 1] == await decoded_r.cf().insertnx("captest", ["foo", "bar", "baz"]) + assert [0] == await decoded_r.cf().insertnx("captest", ["bar"], capacity=1000) + assert [1] == await decoded_r.cf().insert("empty1", ["foo"], capacity=1000) + assert [1] == await decoded_r.cf().insertnx("empty2", ["bar"], capacity=1000) + info = await decoded_r.cf().info("captest") + assert_resp_response( + decoded_r, 5, info.get("insertedNum"), info.get("Number of items inserted") + ) + assert_resp_response( + decoded_r, 0, info.get("deletedNum"), info.get("Number of items deleted") + ) + assert_resp_response( + decoded_r, 1, info.get("filterNum"), info.get("Number of filters") + ) @pytest.mark.redismod -async def test_cf_exists_and_del(modclient: redis.Redis): - assert await modclient.cf().create("cuckoo", 1000) - assert await modclient.cf().add("cuckoo", "filter") - assert await modclient.cf().exists("cuckoo", "filter") - assert not await modclient.cf().exists("cuckoo", "notexist") - assert 1 == await modclient.cf().count("cuckoo", "filter") - assert 0 == await modclient.cf().count("cuckoo", "notexist") - assert await modclient.cf().delete("cuckoo", "filter") - assert 0 == await modclient.cf().count("cuckoo", "filter") - - -# region Test Count-Min Sketch +async def test_cf_exists_and_del(decoded_r: redis.Redis): + assert await decoded_r.cf().create("cuckoo", 1000) + assert await decoded_r.cf().add("cuckoo", "filter") + assert await decoded_r.cf().exists("cuckoo", "filter") + assert not await decoded_r.cf().exists("cuckoo", "notexist") + assert 1 == await
decoded_r.cf().count("cuckoo", "filter") + assert 0 == await decoded_r.cf().count("cuckoo", "notexist") + assert await decoded_r.cf().delete("cuckoo", "filter") + assert 0 == await decoded_r.cf().count("cuckoo", "filter") + + @pytest.mark.redismod -async def test_cms(modclient: redis.Redis): - assert await modclient.cms().initbydim("dim", 1000, 5) - assert await modclient.cms().initbyprob("prob", 0.01, 0.01) - assert await modclient.cms().incrby("dim", ["foo"], [5]) - assert [0] == await modclient.cms().query("dim", "notexist") - assert [5] == await modclient.cms().query("dim", "foo") - assert [10, 15] == await modclient.cms().incrby("dim", ["foo", "bar"], [5, 15]) - assert [10, 15] == await modclient.cms().query("dim", "foo", "bar") - info = await modclient.cms().info("dim") - assert 1000 == info.width - assert 5 == info.depth - assert 25 == info.count +async def test_cms(decoded_r: redis.Redis): + assert await decoded_r.cms().initbydim("dim", 1000, 5) + assert await decoded_r.cms().initbyprob("prob", 0.01, 0.01) + assert await decoded_r.cms().incrby("dim", ["foo"], [5]) + assert [0] == await decoded_r.cms().query("dim", "notexist") + assert [5] == await decoded_r.cms().query("dim", "foo") + assert [10, 15] == await decoded_r.cms().incrby("dim", ["foo", "bar"], [5, 15]) + assert [10, 15] == await decoded_r.cms().query("dim", "foo", "bar") + info = await decoded_r.cms().info("dim") + assert info["width"] + assert 1000 == info["width"] + assert 5 == info["depth"] + assert 25 == info["count"] @pytest.mark.redismod @pytest.mark.onlynoncluster -async def test_cms_merge(modclient: redis.Redis): - assert await modclient.cms().initbydim("A", 1000, 5) - assert await modclient.cms().initbydim("B", 1000, 5) - assert await modclient.cms().initbydim("C", 1000, 5) - assert await modclient.cms().incrby("A", ["foo", "bar", "baz"], [5, 3, 9]) - assert await modclient.cms().incrby("B", ["foo", "bar", "baz"], [2, 3, 1]) - assert [5, 3, 9] == await modclient.cms().query("A", "foo", "bar", "baz") - assert [2, 3, 1] == await modclient.cms().query("B", "foo", "bar", "baz") - assert await modclient.cms().merge("C", 2, ["A", "B"]) - assert [7, 6, 10] == await modclient.cms().query("C", "foo", "bar", "baz") - assert await modclient.cms().merge("C", 2, ["A", "B"], ["1", "2"]) - assert [9, 9, 11] == await modclient.cms().query("C", "foo", "bar", "baz") - assert await modclient.cms().merge("C", 2, ["A", "B"], ["2", "3"]) - assert [16, 15, 21] == await modclient.cms().query("C", "foo", "bar", "baz") - - -# endregion - - -# region Test Top-K +async def test_cms_merge(decoded_r: redis.Redis): + assert await decoded_r.cms().initbydim("A", 1000, 5) + assert await decoded_r.cms().initbydim("B", 1000, 5) + assert await decoded_r.cms().initbydim("C", 1000, 5) + assert await decoded_r.cms().incrby("A", ["foo", "bar", "baz"], [5, 3, 9]) + assert await decoded_r.cms().incrby("B", ["foo", "bar", "baz"], [2, 3, 1]) + assert [5, 3, 9] == await decoded_r.cms().query("A", "foo", "bar", "baz") + assert [2, 3, 1] == await decoded_r.cms().query("B", "foo", "bar", "baz") + assert await decoded_r.cms().merge("C", 2, ["A", "B"]) + assert [7, 6, 10] == await decoded_r.cms().query("C", "foo", "bar", "baz") + assert await decoded_r.cms().merge("C", 2, ["A", "B"], ["1", "2"]) + assert [9, 9, 11] == await decoded_r.cms().query("C", "foo", "bar", "baz") + assert await decoded_r.cms().merge("C", 2, ["A", "B"], ["2", "3"]) + assert [16, 15, 21] == await decoded_r.cms().query("C", "foo", "bar", "baz") + + @pytest.mark.redismod -async def 
test_topk(modclient: redis.Redis): +async def test_topk(decoded_r: redis.Redis): # test list with empty buckets - assert await modclient.topk().reserve("topk", 3, 50, 4, 0.9) + assert await decoded_r.topk().reserve("topk", 3, 50, 4, 0.9) assert [ None, None, @@ -257,7 +273,7 @@ async def test_topk(modclient: redis.Redis): None, "D", None, - ] == await modclient.topk().add( + ] == await decoded_r.topk().add( "topk", "A", "B", @@ -277,17 +293,17 @@ async def test_topk(modclient: redis.Redis): "E", 1, ) - assert [1, 1, 0, 0, 1, 0, 0] == await modclient.topk().query( + assert [1, 1, 0, 0, 1, 0, 0] == await decoded_r.topk().query( "topk", "A", "B", "C", "D", "E", "F", "G" ) with pytest.deprecated_call(): - assert [4, 3, 2, 3, 3, 0, 1] == await modclient.topk().count( + assert [4, 3, 2, 3, 3, 0, 1] == await decoded_r.topk().count( "topk", "A", "B", "C", "D", "E", "F", "G" ) # test full list - assert await modclient.topk().reserve("topklist", 3, 50, 3, 0.9) - assert await modclient.topk().add( + assert await decoded_r.topk().reserve("topklist", 3, 50, 3, 0.9) + assert await decoded_r.topk().add( "topklist", "A", "B", @@ -306,192 +322,196 @@ async def test_topk(modclient: redis.Redis): "E", "E", ) - assert ["A", "B", "E"] == await modclient.topk().list("topklist") - res = await modclient.topk().list("topklist", withcount=True) + assert ["A", "B", "E"] == await decoded_r.topk().list("topklist") + res = await decoded_r.topk().list("topklist", withcount=True) assert ["A", 4, "B", 3, "E", 3] == res - info = await modclient.topk().info("topklist") - assert 3 == info.k - assert 50 == info.width - assert 3 == info.depth - assert 0.9 == round(float(info.decay), 1) + info = await decoded_r.topk().info("topklist") + assert 3 == info["k"] + assert 50 == info["width"] + assert 3 == info["depth"] + assert 0.9 == round(float(info["decay"]), 1) @pytest.mark.redismod -async def test_topk_incrby(modclient: redis.Redis): - await modclient.flushdb() - assert await modclient.topk().reserve("topk", 3, 10, 3, 1) - assert [None, None, None] == await modclient.topk().incrby( +async def test_topk_incrby(decoded_r: redis.Redis): + await decoded_r.flushdb() + assert await decoded_r.topk().reserve("topk", 3, 10, 3, 1) + assert [None, None, None] == await decoded_r.topk().incrby( "topk", ["bar", "baz", "42"], [3, 6, 2] ) - res = await modclient.topk().incrby("topk", ["42", "xyzzy"], [8, 4]) + res = await decoded_r.topk().incrby("topk", ["42", "xyzzy"], [8, 4]) assert [None, "bar"] == res with pytest.deprecated_call(): - assert [3, 6, 10, 4, 0] == await modclient.topk().count( + assert [3, 6, 10, 4, 0] == await decoded_r.topk().count( "topk", "bar", "baz", "42", "xyzzy", 4 ) -# region Test T-Digest @pytest.mark.redismod @pytest.mark.experimental -async def test_tdigest_reset(modclient: redis.Redis): - assert await modclient.tdigest().create("tDigest", 10) +async def test_tdigest_reset(decoded_r: redis.Redis): + assert await decoded_r.tdigest().create("tDigest", 10) # reset on empty histogram - assert await modclient.tdigest().reset("tDigest") + assert await decoded_r.tdigest().reset("tDigest") # insert data-points into sketch - assert await modclient.tdigest().add("tDigest", list(range(10))) + assert await decoded_r.tdigest().add("tDigest", list(range(10))) - assert await modclient.tdigest().reset("tDigest") + assert await decoded_r.tdigest().reset("tDigest") # assert we have 0 unmerged nodes - assert 0 == (await modclient.tdigest().info("tDigest")).unmerged_nodes + info = await decoded_r.tdigest().info("tDigest") + 
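# RESP2 and RESP3 name the T-Digest info fields differently
# ("unmerged_nodes" under RESP2, "Unmerged nodes" under RESP3), so the
# check below goes through assert_resp_response from tests/conftest.py:
# when get_protocol_version(r) is in (2, "2", None) it asserts equality
# against the RESP2 candidate, otherwise against the RESP3 candidate;
# the expected constant 0 is passed as the `response` argument.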
assert_resp_response( + decoded_r, 0, info.get("unmerged_nodes"), info.get("Unmerged nodes") + ) @pytest.mark.redismod -@pytest.mark.experimental -async def test_tdigest_merge(modclient: redis.Redis): - assert await modclient.tdigest().create("to-tDigest", 10) - assert await modclient.tdigest().create("from-tDigest", 10) +@pytest.mark.onlynoncluster +async def test_tdigest_merge(decoded_r: redis.Redis): + assert await decoded_r.tdigest().create("to-tDigest", 10) + assert await decoded_r.tdigest().create("from-tDigest", 10) # insert data-points into sketch - assert await modclient.tdigest().add("from-tDigest", [1.0] * 10) - assert await modclient.tdigest().add("to-tDigest", [2.0] * 10) + assert await decoded_r.tdigest().add("from-tDigest", [1.0] * 10) + assert await decoded_r.tdigest().add("to-tDigest", [2.0] * 10) # merge from-tdigest into to-tdigest - assert await modclient.tdigest().merge("to-tDigest", 1, "from-tDigest") + assert await decoded_r.tdigest().merge("to-tDigest", 1, "from-tDigest") # we should now have 20 weight on to-histogram - info = await modclient.tdigest().info("to-tDigest") - total_weight_to = float(info.merged_weight) + float(info.unmerged_weight) - assert 20.0 == total_weight_to + info = await decoded_r.tdigest().info("to-tDigest") + if is_resp2_connection(decoded_r): + assert 20 == float(info["merged_weight"]) + float(info["unmerged_weight"]) + else: + assert 20 == float(info["Merged weight"]) + float(info["Unmerged weight"]) # test override - assert await modclient.tdigest().create("from-override", 10) - assert await modclient.tdigest().create("from-override-2", 10) - assert await modclient.tdigest().add("from-override", [3.0] * 10) - assert await modclient.tdigest().add("from-override-2", [4.0] * 10) - assert await modclient.tdigest().merge( + assert await decoded_r.tdigest().create("from-override", 10) + assert await decoded_r.tdigest().create("from-override-2", 10) + assert await decoded_r.tdigest().add("from-override", [3.0] * 10) + assert await decoded_r.tdigest().add("from-override-2", [4.0] * 10) + assert await decoded_r.tdigest().merge( "to-tDigest", 2, "from-override", "from-override-2", override=True ) - assert 3.0 == await modclient.tdigest().min("to-tDigest") - assert 4.0 == await modclient.tdigest().max("to-tDigest") + assert 3.0 == await decoded_r.tdigest().min("to-tDigest") + assert 4.0 == await decoded_r.tdigest().max("to-tDigest") @pytest.mark.redismod @pytest.mark.experimental -async def test_tdigest_min_and_max(modclient: redis.Redis): - assert await modclient.tdigest().create("tDigest", 100) +async def test_tdigest_min_and_max(decoded_r: redis.Redis): + assert await decoded_r.tdigest().create("tDigest", 100) # insert data-points into sketch - assert await modclient.tdigest().add("tDigest", [1, 2, 3]) + assert await decoded_r.tdigest().add("tDigest", [1, 2, 3]) # min/max - assert 3 == await modclient.tdigest().max("tDigest") - assert 1 == await modclient.tdigest().min("tDigest") + assert 3 == await decoded_r.tdigest().max("tDigest") + assert 1 == await decoded_r.tdigest().min("tDigest") @pytest.mark.redismod @pytest.mark.experimental @skip_ifmodversion_lt("2.4.0", "bf") -async def test_tdigest_quantile(modclient: redis.Redis): - assert await modclient.tdigest().create("tDigest", 500) +async def test_tdigest_quantile(decoded_r: redis.Redis): + assert await decoded_r.tdigest().create("tDigest", 500) # insert data-points into sketch - assert await modclient.tdigest().add( + assert await decoded_r.tdigest().add( "tDigest", list([x * 0.01 for x in
range(1, 10000)]) ) # assert min min/max have same result as quantile 0 and 1 assert ( - await modclient.tdigest().max("tDigest") - == (await modclient.tdigest().quantile("tDigest", 1))[0] + await decoded_r.tdigest().max("tDigest") + == (await decoded_r.tdigest().quantile("tDigest", 1))[0] ) assert ( - await modclient.tdigest().min("tDigest") - == (await modclient.tdigest().quantile("tDigest", 0.0))[0] + await decoded_r.tdigest().min("tDigest") + == (await decoded_r.tdigest().quantile("tDigest", 0.0))[0] ) - assert 1.0 == round((await modclient.tdigest().quantile("tDigest", 0.01))[0], 2) - assert 99.0 == round((await modclient.tdigest().quantile("tDigest", 0.99))[0], 2) + assert 1.0 == round((await decoded_r.tdigest().quantile("tDigest", 0.01))[0], 2) + assert 99.0 == round((await decoded_r.tdigest().quantile("tDigest", 0.99))[0], 2) # test multiple quantiles - assert await modclient.tdigest().create("t-digest", 100) - assert await modclient.tdigest().add("t-digest", [1, 2, 3, 4, 5]) - res = await modclient.tdigest().quantile("t-digest", 0.5, 0.8) + assert await decoded_r.tdigest().create("t-digest", 100) + assert await decoded_r.tdigest().add("t-digest", [1, 2, 3, 4, 5]) + res = await decoded_r.tdigest().quantile("t-digest", 0.5, 0.8) assert [3.0, 5.0] == res @pytest.mark.redismod @pytest.mark.experimental -async def test_tdigest_cdf(modclient: redis.Redis): - assert await modclient.tdigest().create("tDigest", 100) +async def test_tdigest_cdf(decoded_r: redis.Redis): + assert await decoded_r.tdigest().create("tDigest", 100) # insert data-points into sketch - assert await modclient.tdigest().add("tDigest", list(range(1, 10))) - assert 0.1 == round((await modclient.tdigest().cdf("tDigest", 1.0))[0], 1) - assert 0.9 == round((await modclient.tdigest().cdf("tDigest", 9.0))[0], 1) - res = await modclient.tdigest().cdf("tDigest", 1.0, 9.0) + assert await decoded_r.tdigest().add("tDigest", list(range(1, 10))) + assert 0.1 == round((await decoded_r.tdigest().cdf("tDigest", 1.0))[0], 1) + assert 0.9 == round((await decoded_r.tdigest().cdf("tDigest", 9.0))[0], 1) + res = await decoded_r.tdigest().cdf("tDigest", 1.0, 9.0) assert [0.1, 0.9] == [round(x, 1) for x in res] @pytest.mark.redismod @pytest.mark.experimental @skip_ifmodversion_lt("2.4.0", "bf") -async def test_tdigest_trimmed_mean(modclient: redis.Redis): - assert await modclient.tdigest().create("tDigest", 100) +async def test_tdigest_trimmed_mean(decoded_r: redis.Redis): + assert await decoded_r.tdigest().create("tDigest", 100) # insert data-points into sketch - assert await modclient.tdigest().add("tDigest", list(range(1, 10))) - assert 5 == await modclient.tdigest().trimmed_mean("tDigest", 0.1, 0.9) - assert 4.5 == await modclient.tdigest().trimmed_mean("tDigest", 0.4, 0.5) + assert await decoded_r.tdigest().add("tDigest", list(range(1, 10))) + assert 5 == await decoded_r.tdigest().trimmed_mean("tDigest", 0.1, 0.9) + assert 4.5 == await decoded_r.tdigest().trimmed_mean("tDigest", 0.4, 0.5) @pytest.mark.redismod @pytest.mark.experimental -async def test_tdigest_rank(modclient: redis.Redis): - assert await modclient.tdigest().create("t-digest", 500) - assert await modclient.tdigest().add("t-digest", list(range(0, 20))) - assert -1 == (await modclient.tdigest().rank("t-digest", -1))[0] - assert 0 == (await modclient.tdigest().rank("t-digest", 0))[0] - assert 10 == (await modclient.tdigest().rank("t-digest", 10))[0] - assert [-1, 20, 9] == await modclient.tdigest().rank("t-digest", -20, 20, 9) +async def test_tdigest_rank(decoded_r: 
redis.Redis): + assert await decoded_r.tdigest().create("t-digest", 500) + assert await decoded_r.tdigest().add("t-digest", list(range(0, 20))) + assert -1 == (await decoded_r.tdigest().rank("t-digest", -1))[0] + assert 0 == (await decoded_r.tdigest().rank("t-digest", 0))[0] + assert 10 == (await decoded_r.tdigest().rank("t-digest", 10))[0] + assert [-1, 20, 9] == await decoded_r.tdigest().rank("t-digest", -20, 20, 9) @pytest.mark.redismod @pytest.mark.experimental -async def test_tdigest_revrank(modclient: redis.Redis): - assert await modclient.tdigest().create("t-digest", 500) - assert await modclient.tdigest().add("t-digest", list(range(0, 20))) - assert -1 == (await modclient.tdigest().revrank("t-digest", 20))[0] - assert 19 == (await modclient.tdigest().revrank("t-digest", 0))[0] - assert [-1, 19, 9] == await modclient.tdigest().revrank("t-digest", 21, 0, 10) +async def test_tdigest_revrank(decoded_r: redis.Redis): + assert await decoded_r.tdigest().create("t-digest", 500) + assert await decoded_r.tdigest().add("t-digest", list(range(0, 20))) + assert -1 == (await decoded_r.tdigest().revrank("t-digest", 20))[0] + assert 19 == (await decoded_r.tdigest().revrank("t-digest", 0))[0] + assert [-1, 19, 9] == await decoded_r.tdigest().revrank("t-digest", 21, 0, 10) @pytest.mark.redismod @pytest.mark.experimental -async def test_tdigest_byrank(modclient: redis.Redis): - assert await modclient.tdigest().create("t-digest", 500) - assert await modclient.tdigest().add("t-digest", list(range(1, 11))) - assert 1 == (await modclient.tdigest().byrank("t-digest", 0))[0] - assert 10 == (await modclient.tdigest().byrank("t-digest", 9))[0] - assert (await modclient.tdigest().byrank("t-digest", 100))[0] == inf +async def test_tdigest_byrank(decoded_r: redis.Redis): + assert await decoded_r.tdigest().create("t-digest", 500) + assert await decoded_r.tdigest().add("t-digest", list(range(1, 11))) + assert 1 == (await decoded_r.tdigest().byrank("t-digest", 0))[0] + assert 10 == (await decoded_r.tdigest().byrank("t-digest", 9))[0] + assert (await decoded_r.tdigest().byrank("t-digest", 100))[0] == inf with pytest.raises(redis.ResponseError): - (await modclient.tdigest().byrank("t-digest", -1))[0] + (await decoded_r.tdigest().byrank("t-digest", -1))[0] @pytest.mark.redismod @pytest.mark.experimental -async def test_tdigest_byrevrank(modclient: redis.Redis): - assert await modclient.tdigest().create("t-digest", 500) - assert await modclient.tdigest().add("t-digest", list(range(1, 11))) - assert 10 == (await modclient.tdigest().byrevrank("t-digest", 0))[0] - assert 1 == (await modclient.tdigest().byrevrank("t-digest", 9))[0] - assert (await modclient.tdigest().byrevrank("t-digest", 100))[0] == -inf +async def test_tdigest_byrevrank(decoded_r: redis.Redis): + assert await decoded_r.tdigest().create("t-digest", 500) + assert await decoded_r.tdigest().add("t-digest", list(range(1, 11))) + assert 10 == (await decoded_r.tdigest().byrevrank("t-digest", 0))[0] + assert 1 == (await decoded_r.tdigest().byrevrank("t-digest", 9))[0] + assert (await decoded_r.tdigest().byrevrank("t-digest", 100))[0] == -inf with pytest.raises(redis.ResponseError): - (await modclient.tdigest().byrevrank("t-digest", -1))[0] + (await decoded_r.tdigest().byrevrank("t-digest", -1))[0] # @pytest.mark.redismod -# async def test_pipeline(modclient: redis.Redis): -# pipeline = await modclient.bf().pipeline() -# assert not await modclient.bf().execute_command("get pipeline") +# async def test_pipeline(decoded_r: redis.Redis): +# pipeline = await 
decoded_r.bf().pipeline() +# assert not await decoded_r.bf().execute_command("get pipeline") # -# assert await modclient.bf().create("pipeline", 0.01, 1000) +# assert await decoded_r.bf().create("pipeline", 0.01, 1000) # for i in range(100): # pipeline.add("pipeline", i) # for i in range(100): -# assert not (await modclient.bf().exists("pipeline", i)) +# assert not (await decoded_r.bf().exists("pipeline", i)) # # pipeline.execute() # # for i in range(100): -# assert await modclient.bf().exists("pipeline", i) +# assert await decoded_r.bf().exists("pipeline", i) diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 13e5e26ae3..1cb1fa5195 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -1,7 +1,6 @@ import asyncio import binascii import datetime -import os import warnings from typing import Any, Awaitable, Callable, Dict, List, Optional, Type, Union from urllib.parse import urlparse @@ -9,10 +8,9 @@ import pytest import pytest_asyncio from _pytest.fixtures import FixtureRequest - +from redis._parsers import AsyncCommandsParser from redis.asyncio.cluster import ClusterNode, NodesManager, RedisCluster -from redis.asyncio.connection import Connection, SSLConnection -from redis.asyncio.parser import CommandsParser +from redis.asyncio.connection import Connection, SSLConnection, async_timeout from redis.asyncio.retry import Retry from redis.backoff import ExponentialBackoff, NoBackoff, default_backoff from redis.cluster import PIPELINE_BLOCKED_COMMANDS, PRIMARY, REPLICA, get_node_name @@ -31,11 +29,15 @@ ) from redis.utils import str_if_bytes from tests.conftest import ( + assert_resp_response, + is_resp2_connection, skip_if_redis_enterprise, + skip_if_server_version_gte, skip_if_server_version_lt, skip_unless_arch_bits, ) +from ..ssl_utils import get_ssl_filename from .compat import mock pytestmark = pytest.mark.onlycluster @@ -49,6 +51,59 @@ ] +class NodeProxy: + """A class to proxy a node connection to a different port""" + + def __init__(self, addr, redis_addr): + self.addr = addr + self.redis_addr = redis_addr + self.send_event = asyncio.Event() + self.server = None + self.task = None + self.n_connections = 0 + + async def start(self): + # test that we can connect to redis + async with async_timeout(2): + _, redis_writer = await asyncio.open_connection(*self.redis_addr) + redis_writer.close() + self.server = await asyncio.start_server( + self.handle, *self.addr, reuse_address=True + ) + self.task = asyncio.create_task(self.server.serve_forever()) + + async def handle(self, reader, writer): + # establish connection to redis + redis_reader, redis_writer = await asyncio.open_connection(*self.redis_addr) + try: + self.n_connections += 1 + pipe1 = asyncio.create_task(self.pipe(reader, redis_writer)) + pipe2 = asyncio.create_task(self.pipe(redis_reader, writer)) + await asyncio.gather(pipe1, pipe2) + finally: + redis_writer.close() + + async def aclose(self): + self.task.cancel() + try: + await self.task + except asyncio.CancelledError: + pass + await self.server.wait_closed() + + async def pipe( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + ): + while True: + data = await reader.read(1000) + if not data: + break + writer.write(data) + await writer.drain() + + @pytest_asyncio.fixture() async def slowlog(r: RedisCluster) -> None: """ @@ -99,12 +154,12 @@ async def execute_command(*_args, **_kwargs): execute_command_mock.side_effect = execute_command with mock.patch.object( - CommandsParser, 
"initialize", autospec=True + AsyncCommandsParser, "initialize", autospec=True ) as cmd_parser_initialize: def cmd_init_mock(self, r: ClusterNode) -> None: self.commands = { - "GET": { + "get": { "name": "get", "arity": 2, "flags": ["readonly", "fast"], @@ -549,12 +604,12 @@ def map_7007(self): mocks["send_packed_command"].return_value = "MOCK_OK" mocks["connect"].return_value = None with mock.patch.object( - CommandsParser, "initialize", autospec=True + AsyncCommandsParser, "initialize", autospec=True ) as cmd_parser_initialize: def cmd_init_mock(self, r: ClusterNode) -> None: self.commands = { - "GET": { + "get": { "name": "get", "arity": 2, "flags": ["readonly", "fast"], @@ -765,6 +820,8 @@ async def test_not_require_full_coverage_cluster_down_error( assert all(await r.cluster_delslots(missing_slot)) with pytest.raises(ClusterDownError): await r.exists("foo") + except ResponseError as e: + assert "CLUSTERDOWN" in str(e) finally: try: # Add back the missing slot @@ -809,6 +866,49 @@ async def test_default_node_is_replaced_after_exception(self, r): # Rollback to the old default node r.replace_default_node(curr_default_node) + async def test_address_remap(self, create_redis, master_host): + """Test that we can create a rediscluster object with + a host-port remapper and map connections through proxy objects + """ + + # we remap the first n nodes + offset = 1000 + n = 6 + hostname, master_port = master_host + ports = [master_port + i for i in range(n)] + + def address_remap(address): + # remap first three nodes to our local proxy + # old = host, port + host, port = address + if int(port) in ports: + host, port = "127.0.0.1", int(port) + offset + # print(f"{old} {host, port}") + return host, port + + # create the proxies + proxies = [ + NodeProxy(("127.0.0.1", port + offset), (hostname, port)) for port in ports + ] + await asyncio.gather(*[p.start() for p in proxies]) + try: + # create cluster: + r = await create_redis( + cls=RedisCluster, flushdb=False, address_remap=address_remap + ) + try: + assert await r.ping() is True + assert await r.set("byte_string", b"giraffe") + assert await r.get("byte_string") == b"giraffe" + finally: + await r.close() + finally: + await asyncio.gather(*[p.aclose() for p in proxies]) + + # verify that the proxies were indeed used + n_used = sum((1 if p.n_connections else 0) for p in proxies) + assert n_used > 1 + class TestClusterRedisCommands: """ @@ -868,7 +968,7 @@ async def test_client_setname(self, r: RedisCluster) -> None: node = r.get_random_node() await r.client_setname("redis_py_test", target_nodes=node) client_name = await r.client_getname(target_nodes=node) - assert client_name == "redis_py_test" + assert_resp_response(r, client_name, "redis_py_test", b"redis_py_test") async def test_exists(self, r: RedisCluster) -> None: d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} @@ -910,6 +1010,13 @@ async def test_cluster_myid(self, r: RedisCluster) -> None: myid = await r.cluster_myid(node) assert len(myid) == 40 + @skip_if_server_version_lt("7.2.0") + @skip_if_redis_enterprise() + async def test_cluster_myshardid(self, r: RedisCluster) -> None: + node = r.get_random_node() + myshardid = await r.cluster_myshardid(node) + assert len(myshardid) == 40 + @skip_if_redis_enterprise() async def test_cluster_slots(self, r: RedisCluster) -> None: mock_all_nodes_resp(r, default_cluster_slots) @@ -962,11 +1069,14 @@ async def test_cluster_delslots(self) -> None: @skip_if_server_version_lt("7.0.0") @skip_if_redis_enterprise() - async def test_cluster_delslotsrange(self, 
r: RedisCluster): + async def test_cluster_delslotsrange(self): + r = await get_mocked_redis_client(host=default_host, port=default_port) + mock_all_nodes_resp(r, "OK") node = r.get_random_node() - mock_node_resp(node, "OK") await r.cluster_addslots(node, 1, 2, 3, 4, 5) assert await r.cluster_delslotsrange(1, 5) + assert node._free.pop().read_response.called + await r.close() @skip_if_redis_enterprise() async def test_cluster_failover(self, r: RedisCluster) -> None: @@ -1152,11 +1262,18 @@ async def test_cluster_replicas(self, r: RedisCluster) -> None: async def test_cluster_links(self, r: RedisCluster): node = r.get_random_node() res = await r.cluster_links(node) - links_to = sum(x.count("to") for x in res) - links_for = sum(x.count("from") for x in res) - assert links_to == links_for - for i in range(0, len(res) - 1, 2): - assert res[i][3] == res[i + 1][3] + if is_resp2_connection(r): + links_to = sum(x.count(b"to") for x in res) + links_for = sum(x.count(b"from") for x in res) + assert links_to == links_for + for i in range(0, len(res) - 1, 2): + assert res[i][3] == res[i + 1][3] + else: + links_to = len(list(filter(lambda x: x[b"direction"] == b"to", res))) + links_for = len(list(filter(lambda x: x[b"direction"] == b"from", res))) + assert links_to == links_for + for i in range(0, len(res) - 1, 2): + assert res[i][b"node"] == res[i + 1][b"node"] @skip_if_redis_enterprise() async def test_readonly(self) -> None: @@ -1340,7 +1457,7 @@ async def test_client_trackinginfo(self, r: RedisCluster) -> None: node = r.get_primaries()[0] res = await r.client_trackinginfo(target_nodes=node) assert len(res) > 2 - assert "prefixes" in res + assert "prefixes" in res or b"prefixes" in res @skip_if_server_version_lt("2.9.50") async def test_client_pause(self, r: RedisCluster) -> None: @@ -1506,24 +1623,68 @@ async def test_cluster_renamenx(self, r: RedisCluster) -> None: async def test_cluster_blpop(self, r: RedisCluster) -> None: await r.rpush("{foo}a", "1", "2") await r.rpush("{foo}b", "3", "4") - assert await r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"3") - assert await r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"4") - assert await r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"1") - assert await r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"2") + assert_resp_response( + r, + await r.blpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"3"), + [b"{foo}b", b"3"], + ) + assert_resp_response( + r, + await r.blpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"4"), + [b"{foo}b", b"4"], + ) + assert_resp_response( + r, + await r.blpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"1"), + [b"{foo}a", b"1"], + ) + assert_resp_response( + r, + await r.blpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"2"), + [b"{foo}a", b"2"], + ) assert await r.blpop(["{foo}b", "{foo}a"], timeout=1) is None await r.rpush("{foo}c", "1") - assert await r.blpop("{foo}c", timeout=1) == (b"{foo}c", b"1") + assert_resp_response( + r, await r.blpop("{foo}c", timeout=1), (b"{foo}c", b"1"), [b"{foo}c", b"1"] + ) async def test_cluster_brpop(self, r: RedisCluster) -> None: await r.rpush("{foo}a", "1", "2") await r.rpush("{foo}b", "3", "4") - assert await r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"4") - assert await r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"3") - assert await r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"2") - assert await r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"1") + assert_resp_response( + r, 
+ await r.brpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"4"), + [b"{foo}b", b"4"], + ) + assert_resp_response( + r, + await r.brpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"3"), + [b"{foo}b", b"3"], + ) + assert_resp_response( + r, + await r.brpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"2"), + [b"{foo}a", b"2"], + ) + assert_resp_response( + r, + await r.brpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"1"), + [b"{foo}a", b"1"], + ) assert await r.brpop(["{foo}b", "{foo}a"], timeout=1) is None await r.rpush("{foo}c", "1") - assert await r.brpop("{foo}c", timeout=1) == (b"{foo}c", b"1") + assert_resp_response( + r, await r.brpop("{foo}c", timeout=1), (b"{foo}c", b"1"), [b"{foo}c", b"1"] + ) async def test_cluster_brpoplpush(self, r: RedisCluster) -> None: await r.rpush("{foo}a", "1", "2") @@ -1596,7 +1757,8 @@ async def test_cluster_zdiff(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) await r.zadd("{foo}b", {"a1": 1, "a2": 2}) assert await r.zdiff(["{foo}a", "{foo}b"]) == [b"a3"] - assert await r.zdiff(["{foo}a", "{foo}b"], withscores=True) == [b"a3", b"3"] + response = await r.zdiff(["{foo}a", "{foo}b"], withscores=True) + assert_resp_response(r, response, [b"a3", b"3"], [[b"a3", 3.0]]) @skip_if_server_version_lt("6.2.0") async def test_cluster_zdiffstore(self, r: RedisCluster) -> None: @@ -1604,7 +1766,8 @@ async def test_cluster_zdiffstore(self, r: RedisCluster) -> None: await r.zadd("{foo}b", {"a1": 1, "a2": 2}) assert await r.zdiffstore("{foo}out", ["{foo}a", "{foo}b"]) assert await r.zrange("{foo}out", 0, -1) == [b"a3"] - assert await r.zrange("{foo}out", 0, -1, withscores=True) == [(b"a3", 3.0)] + response = await r.zrange("{foo}out", 0, -1, withscores=True) + assert_resp_response(r, response, [(b"a3", 3.0)], [[b"a3", 3.0]]) @skip_if_server_version_lt("6.2.0") async def test_cluster_zinter(self, r: RedisCluster) -> None: @@ -1618,32 +1781,41 @@ async def test_cluster_zinter(self, r: RedisCluster) -> None: ["{foo}a", "{foo}b", "{foo}c"], aggregate="foo", withscores=True ) # aggregate with SUM - assert await r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ - (b"a3", 8), - (b"a1", 9), - ] + response = await r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True) + assert_resp_response( + r, response, [(b"a3", 8), (b"a1", 9)], [[b"a3", 8], [b"a1", 9]] + ) # aggregate with MAX - assert await r.zinter( + response = await r.zinter( ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True - ) == [(b"a3", 5), (b"a1", 6)] + ) + assert_resp_response( + r, response, [(b"a3", 5), (b"a1", 6)], [[b"a3", 5], [b"a1", 6]] + ) # aggregate with MIN - assert await r.zinter( + response = await r.zinter( ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True - ) == [(b"a1", 1), (b"a3", 1)] + ) + assert_resp_response( + r, response, [(b"a1", 1), (b"a3", 1)], [[b"a1", 1], [b"a3", 1]] + ) # with weights - assert await r.zinter( - {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True - ) == [(b"a3", 20), (b"a1", 23)] + res = await r.zinter({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True) + assert_resp_response( + r, res, [(b"a3", 20), (b"a1", 23)], [[b"a3", 20], [b"a1", 23]] + ) async def test_cluster_zinterstore_sum(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zinterstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"]) == 2 - assert await 
r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a3", 8), - (b"a1", 9), - ] + assert_resp_response( + r, + await r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a3", 8), (b"a1", 9)], + [[b"a3", 8.0], [b"a1", 9.0]], + ) async def test_cluster_zinterstore_max(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) @@ -1655,10 +1827,12 @@ async def test_cluster_zinterstore_max(self, r: RedisCluster) -> None: ) == 2 ) - assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a3", 5), - (b"a1", 6), - ] + assert_resp_response( + r, + await r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a3", 5), (b"a1", 6)], + [[b"a3", 5.0], [b"a1", 6.0]], + ) async def test_cluster_zinterstore_min(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) @@ -1670,10 +1844,12 @@ async def test_cluster_zinterstore_min(self, r: RedisCluster) -> None: ) == 2 ) - assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a1", 1), - (b"a3", 3), - ] + assert_resp_response( + r, + await r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a1", 1), (b"a3", 3)], + [[b"a1", 1.0], [b"a3", 3.0]], + ) async def test_cluster_zinterstore_with_weight(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) @@ -1682,66 +1858,86 @@ async def test_cluster_zinterstore_with_weight(self, r: RedisCluster) -> None: assert ( await r.zinterstore("{foo}d", {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}) == 2 ) - assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a3", 20), - (b"a1", 23), - ] + assert_resp_response( + r, + await r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a3", 20), (b"a1", 23)], + [[b"a3", 20.0], [b"a1", 23.0]], + ) @skip_if_server_version_lt("4.9.0") async def test_cluster_bzpopmax(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 2}) await r.zadd("{foo}b", {"b1": 10, "b2": 20}) - assert await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == ( - b"{foo}b", - b"b2", - 20, - ) - assert await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == ( - b"{foo}b", - b"b1", - 10, - ) - assert await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == ( - b"{foo}a", - b"a2", - 2, + assert_resp_response( + r, + await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"b2", 20), + [b"{foo}b", b"b2", 20], ) - assert await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == ( - b"{foo}a", - b"a1", - 1, + assert_resp_response( + r, + await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"b1", 10), + [b"{foo}b", b"b1", 10], + ) + assert_resp_response( + r, + await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"a2", 2), + [b"{foo}a", b"a2", 2], + ) + assert_resp_response( + r, + await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"a1", 1), + [b"{foo}a", b"a1", 1], ) assert await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) is None await r.zadd("{foo}c", {"c1": 100}) - assert await r.bzpopmax("{foo}c", timeout=1) == (b"{foo}c", b"c1", 100) + assert_resp_response( + r, + await r.bzpopmax("{foo}c", timeout=1), + (b"{foo}c", b"c1", 100), + [b"{foo}c", b"c1", 100], + ) @skip_if_server_version_lt("4.9.0") async def test_cluster_bzpopmin(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 2}) await r.zadd("{foo}b", {"b1": 10, "b2": 20}) - assert await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == ( - b"{foo}b", - b"b1", - 10, - ) - assert await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == ( - b"{foo}b", - b"b2", - 20, - ) - assert await r.bzpopmin(["{foo}b", 
"{foo}a"], timeout=1) == ( - b"{foo}a", - b"a1", - 1, + assert_resp_response( + r, + await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"b1", 10), + [b"{foo}b", b"b1", 10], ) - assert await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == ( - b"{foo}a", - b"a2", - 2, + assert_resp_response( + r, + await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"b2", 20), + [b"{foo}b", b"b2", 20], + ) + assert_resp_response( + r, + await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"a1", 1), + [b"{foo}a", b"a1", 1], + ) + assert_resp_response( + r, + await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"a2", 2), + [b"{foo}a", b"a2", 2], ) assert await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) is None await r.zadd("{foo}c", {"c1": 100}) - assert await r.bzpopmin("{foo}c", timeout=1) == (b"{foo}c", b"c1", 100) + assert_resp_response( + r, + await r.bzpopmin("{foo}c", timeout=1), + (b"{foo}c", b"c1", 100), + [b"{foo}c", b"c1", 100], + ) @skip_if_server_version_lt("6.2.0") async def test_cluster_zrangestore(self, r: RedisCluster) -> None: @@ -1750,10 +1946,12 @@ async def test_cluster_zrangestore(self, r: RedisCluster) -> None: assert await r.zrange("{foo}b", 0, -1) == [b"a1", b"a2"] assert await r.zrangestore("{foo}b", "{foo}a", 1, 2) assert await r.zrange("{foo}b", 0, -1) == [b"a2", b"a3"] - assert await r.zrange("{foo}b", 0, -1, withscores=True) == [ - (b"a2", 2), - (b"a3", 3), - ] + assert_resp_response( + r, + await r.zrange("{foo}b", 0, -1, withscores=True), + [(b"a2", 2), (b"a3", 3)], + [[b"a2", 2.0], [b"a3", 3.0]], + ) # reversed order assert await r.zrangestore("{foo}b", "{foo}a", 1, 2, desc=True) assert await r.zrange("{foo}b", 0, -1) == [b"a1", b"a2"] @@ -1780,36 +1978,49 @@ async def test_cluster_zunion(self, r: RedisCluster) -> None: b"a3", b"a1", ] - assert await r.zunion(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ - (b"a2", 3), - (b"a4", 4), - (b"a3", 8), - (b"a1", 9), - ] + assert_resp_response( + r, + await r.zunion(["{foo}a", "{foo}b", "{foo}c"], withscores=True), + [(b"a2", 3), (b"a4", 4), (b"a3", 8), (b"a1", 9)], + [[b"a2", 3.0], [b"a4", 4.0], [b"a3", 8.0], [b"a1", 9.0]], + ) # max - assert await r.zunion( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True - ) == [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)] + assert_resp_response( + r, + await r.zunion( + ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True + ), + [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)], + [[b"a2", 2.0], [b"a4", 4.0], [b"a3", 5.0], [b"a1", 6.0]], + ) # min - assert await r.zunion( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True - ) == [(b"a1", 1), (b"a2", 1), (b"a3", 1), (b"a4", 4)] + assert_resp_response( + r, + await r.zunion( + ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True + ), + [(b"a1", 1), (b"a2", 1), (b"a3", 1), (b"a4", 4)], + [[b"a1", 1.0], [b"a2", 1.0], [b"a3", 1.0], [b"a4", 4.0]], + ) # with weight - assert await r.zunion( - {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True - ) == [(b"a2", 5), (b"a4", 12), (b"a3", 20), (b"a1", 23)] + assert_resp_response( + r, + await r.zunion({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True), + [(b"a2", 5), (b"a4", 12), (b"a3", 20), (b"a1", 23)], + [[b"a2", 5.0], [b"a4", 12.0], [b"a3", 20.0], [b"a1", 23.0]], + ) async def test_cluster_zunionstore_sum(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) await r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("{foo}c", {"a1": 6, 
"a3": 5, "a4": 4}) assert await r.zunionstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"]) == 4 - assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a2", 3), - (b"a4", 4), - (b"a3", 8), - (b"a1", 9), - ] + assert_resp_response( + r, + await r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a2", 3), (b"a4", 4), (b"a3", 8), (b"a1", 9)], + [[b"a2", 3.0], [b"a4", 4.0], [b"a3", 8.0], [b"a1", 9.0]], + ) async def test_cluster_zunionstore_max(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) @@ -1821,12 +2032,12 @@ async def test_cluster_zunionstore_max(self, r: RedisCluster) -> None: ) == 4 ) - assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a2", 2), - (b"a4", 4), - (b"a3", 5), - (b"a1", 6), - ] + assert_resp_response( + r, + await r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)], + [[b"a2", 2.0], [b"a4", 4.0], [b"a3", 5.0], [b"a1", 6.0]], + ) async def test_cluster_zunionstore_min(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) @@ -1838,12 +2049,12 @@ async def test_cluster_zunionstore_min(self, r: RedisCluster) -> None: ) == 4 ) - assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a1", 1), - (b"a2", 2), - (b"a3", 3), - (b"a4", 4), - ] + assert_resp_response( + r, + await r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a1", 1), (b"a2", 2), (b"a3", 3), (b"a4", 4)], + [[b"a1", 1.0], [b"a2", 2.0], [b"a3", 3.0], [b"a4", 4.0]], + ) async def test_cluster_zunionstore_with_weight(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) @@ -1852,12 +2063,12 @@ async def test_cluster_zunionstore_with_weight(self, r: RedisCluster) -> None: assert ( await r.zunionstore("{foo}d", {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}) == 4 ) - assert await r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a2", 5), - (b"a4", 12), - (b"a3", 20), - (b"a1", 23), - ] + assert_resp_response( + r, + await r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a2", 5), (b"a4", 12), (b"a3", 20), (b"a1", 23)], + [[b"a2", 5.0], [b"a4", 12.0], [b"a3", 20.0], [b"a1", 23.0]], + ) @skip_if_server_version_lt("2.8.9") async def test_cluster_pfcount(self, r: RedisCluster) -> None: @@ -2083,7 +2294,7 @@ async def test_acl_log( await user_client.hset("{cache}:0", "hkey", "hval") assert isinstance(await r.acl_log(target_nodes=node), list) - assert len(await r.acl_log(target_nodes=node)) == 2 + assert len(await r.acl_log(target_nodes=node)) == 3 assert len(await r.acl_log(count=1, target_nodes=node)) == 1 assert isinstance((await r.acl_log(target_nodes=node))[0], dict) assert "client-info" in (await r.acl_log(count=1, target_nodes=node))[0] @@ -2341,7 +2552,7 @@ async def mocked_execute_command(self, *args, **kwargs): assert "Redis Cluster cannot be connected" in str(e.value) with mock.patch.object( - CommandsParser, "initialize", autospec=True + AsyncCommandsParser, "initialize", autospec=True ) as cmd_parser_initialize: def cmd_init_mock(self, r: ClusterNode) -> None: @@ -2547,6 +2758,7 @@ async def test_asking_error(self, r: RedisCluster) -> None: assert ask_node._free.pop().read_response.await_count assert res == ["MOCK_OK"] + @skip_if_server_version_gte("7.0.0") async def test_moved_redirection_on_slave_with_default( self, r: RedisCluster ) -> None: @@ -2641,17 +2853,8 @@ class TestSSL: appropriate port. 
""" - ROOT = os.path.join(os.path.dirname(__file__), "../..") - CERT_DIR = os.path.abspath(os.path.join(ROOT, "docker", "stunnel", "keys")) - if not os.path.isdir(CERT_DIR): # github actions package validation case - CERT_DIR = os.path.abspath( - os.path.join(ROOT, "..", "docker", "stunnel", "keys") - ) - if not os.path.isdir(CERT_DIR): - raise IOError(f"No SSL certificates found. They should be in {CERT_DIR}") - - SERVER_CERT = os.path.join(CERT_DIR, "server-cert.pem") - SERVER_KEY = os.path.join(CERT_DIR, "server-key.pem") + SERVER_CERT = get_ssl_filename("server-cert.pem") + SERVER_KEY = get_ssl_filename("server-key.pem") @pytest_asyncio.fixture() def create_client(self, request: FixtureRequest) -> Callable[..., RedisCluster]: diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py index 7c6fd45ab9..7808d171fa 100644 --- a/tests/test_asyncio/test_commands.py +++ b/tests/test_asyncio/test_commands.py @@ -1,23 +1,38 @@ """ Tests async overrides of commands from their mixins """ +import asyncio import binascii import datetime import re +import sys from string import ascii_letters import pytest import pytest_asyncio - import redis from redis import exceptions -from redis.client import EMPTY_RESPONSE, NEVER_DECODE, parse_info +from redis._parsers.helpers import ( + _RedisCallbacks, + _RedisCallbacksRESP2, + _RedisCallbacksRESP3, + parse_info, +) +from redis.client import EMPTY_RESPONSE, NEVER_DECODE from tests.conftest import ( + assert_resp_response, + assert_resp_response_in, + is_resp2_connection, skip_if_server_version_gte, skip_if_server_version_lt, skip_unless_arch_bits, ) +if sys.version_info >= (3, 11, 3): + from asyncio import timeout as async_timeout +else: + from async_timeout import timeout as async_timeout + REDIS_6_VERSION = "5.9.0" @@ -71,14 +86,19 @@ class TestResponseCallbacks: """Tests for the response callback system""" async def test_response_callbacks(self, r: redis.Redis): - assert r.response_callbacks == redis.Redis.RESPONSE_CALLBACKS - assert id(r.response_callbacks) != id(redis.Redis.RESPONSE_CALLBACKS) + callbacks = _RedisCallbacks + if is_resp2_connection(r): + callbacks.update(_RedisCallbacksRESP2) + else: + callbacks.update(_RedisCallbacksRESP3) + assert r.response_callbacks == callbacks + assert id(r.response_callbacks) != id(_RedisCallbacks) r.set_response_callback("GET", lambda x: "static") await r.set("a", "foo") assert await r.get("a") == "static" async def test_case_insensitive_command_names(self, r: redis.Redis): - assert r.response_callbacks["del"] == r.response_callbacks["DEL"] + assert r.response_callbacks["ping"] == r.response_callbacks["PING"] class TestRedisCommands: @@ -92,13 +112,13 @@ async def test_command_on_invalid_key_type(self, r: redis.Redis): async def test_acl_cat_no_category(self, r: redis.Redis): categories = await r.acl_cat() assert isinstance(categories, list) - assert "read" in categories + assert "read" in categories or b"read" in categories @skip_if_server_version_lt(REDIS_6_VERSION) async def test_acl_cat_with_category(self, r: redis.Redis): commands = await r.acl_cat("read") assert isinstance(commands, list) - assert "get" in commands + assert "get" in commands or b"get" in commands @skip_if_server_version_lt(REDIS_6_VERSION) async def test_acl_deluser(self, r_teardown): @@ -112,36 +132,32 @@ async def test_acl_deluser(self, r_teardown): @skip_if_server_version_lt(REDIS_6_VERSION) async def test_acl_genpass(self, r: redis.Redis): password = await r.acl_genpass() - assert isinstance(password, str) + 
assert isinstance(password, (str, bytes)) - @skip_if_server_version_lt(REDIS_6_VERSION) - @skip_if_server_version_gte("7.0.0") + @skip_if_server_version_lt("7.0.0") async def test_acl_getuser_setuser(self, r_teardown): username = "redis-py-user" r = r_teardown(username) # test enabled=False assert await r.acl_setuser(username, enabled=False, reset=True) - assert await r.acl_getuser(username) == { - "categories": ["-@all"], - "commands": [], - "channels": [b"*"], - "enabled": False, - "flags": ["off", "allchannels", "sanitize-payload"], - "keys": [], - "passwords": [], - } + acl = await r.acl_getuser(username) + assert acl["categories"] == ["-@all"] + assert acl["commands"] == [] + assert acl["keys"] == [] + assert acl["passwords"] == [] + assert "off" in acl["flags"] + assert acl["enabled"] is False # test nopass=True assert await r.acl_setuser(username, enabled=True, reset=True, nopass=True) - assert await r.acl_getuser(username) == { - "categories": ["-@all"], - "commands": [], - "channels": [b"*"], - "enabled": True, - "flags": ["on", "allchannels", "nopass", "sanitize-payload"], - "keys": [], - "passwords": [], - } + acl = await r.acl_getuser(username) + assert acl["categories"] == ["-@all"] + assert acl["commands"] == [] + assert acl["keys"] == [] + assert acl["passwords"] == [] + assert "on" in acl["flags"] + assert "nopass" in acl["flags"] + assert acl["enabled"] is True # test all args assert await r.acl_setuser( @@ -154,12 +170,11 @@ async def test_acl_getuser_setuser(self, r_teardown): keys=["cache:*", "objects:*"], ) acl = await r.acl_getuser(username) - assert set(acl["categories"]) == {"-@all", "+@set", "+@hash"} + assert set(acl["categories"]) == {"-@all", "+@set", "+@hash", "-@geo"} assert set(acl["commands"]) == {"+get", "+mget", "-hset"} assert acl["enabled"] is True - assert acl["channels"] == [b"*"] - assert acl["flags"] == ["on", "allchannels", "sanitize-payload"] - assert set(acl["keys"]) == {b"cache:*", b"objects:*"} + assert "on" in acl["flags"] + assert set(acl["keys"]) == {"~cache:*", "~objects:*"} assert len(acl["passwords"]) == 2 # test reset=False keeps existing ACL and applies new ACL on top @@ -181,12 +196,10 @@ async def test_acl_getuser_setuser(self, r_teardown): keys=["objects:*"], ) acl = await r.acl_getuser(username) - assert set(acl["categories"]) == {"-@all", "+@set", "+@hash"} assert set(acl["commands"]) == {"+get", "+mget"} assert acl["enabled"] is True - assert acl["channels"] == [b"*"] - assert acl["flags"] == ["on", "allchannels", "sanitize-payload"] - assert set(acl["keys"]) == {b"cache:*", b"objects:*"} + assert "on" in acl["flags"] + assert set(acl["keys"]) == {"~cache:*", "~objects:*"} assert len(acl["passwords"]) == 2 # test removal of passwords @@ -222,14 +235,13 @@ async def test_acl_getuser_setuser(self, r_teardown): assert len((await r.acl_getuser(username))["passwords"]) == 1 @skip_if_server_version_lt(REDIS_6_VERSION) - @skip_if_server_version_gte("7.0.0") async def test_acl_list(self, r_teardown): username = "redis-py-user" r = r_teardown(username) - + start = await r.acl_list() assert await r.acl_setuser(username, enabled=False, reset=True) users = await r.acl_list() - assert f"user {username} off sanitize-payload &* -@all" in users + assert len(users) == len(start) + 1 @skip_if_server_version_lt(REDIS_6_VERSION) @pytest.mark.onlynoncluster @@ -261,10 +273,11 @@ async def test_acl_log(self, r_teardown, create_redis): await user_client.hset("cache:0", "hkey", "hval") assert isinstance(await r.acl_log(), list) - assert len(await 
r.acl_log()) == 2 + assert len(await r.acl_log()) == 3 assert len(await r.acl_log(count=1)) == 1 assert isinstance((await r.acl_log())[0], dict) - assert "client-info" in (await r.acl_log(count=1))[0] + expected = (await r.acl_log(count=1))[0] + assert_resp_response_in(r, "client-info", expected, expected.keys()) assert await r.acl_log_reset() @skip_if_server_version_lt(REDIS_6_VERSION) @@ -300,7 +313,7 @@ async def test_acl_users(self, r: redis.Redis): @skip_if_server_version_lt(REDIS_6_VERSION) async def test_acl_whoami(self, r: redis.Redis): username = await r.acl_whoami() - assert isinstance(username, str) + assert isinstance(username, (str, bytes)) @pytest.mark.onlynoncluster async def test_client_list(self, r: redis.Redis): @@ -338,7 +351,29 @@ async def test_client_getname(self, r: redis.Redis): @pytest.mark.onlynoncluster async def test_client_setname(self, r: redis.Redis): assert await r.client_setname("redis_py_test") - assert await r.client_getname() == "redis_py_test" + assert_resp_response( + r, await r.client_getname(), "redis_py_test", b"redis_py_test" + ) + + @skip_if_server_version_lt("7.2.0") + async def test_client_setinfo(self, r: redis.Redis): + await r.ping() + info = await r.client_info() + assert info["lib-name"] == "redis-py" + assert info["lib-ver"] == redis.__version__ + assert await r.client_setinfo("lib-name", "test") + assert await r.client_setinfo("lib-ver", "123") + info = await r.client_info() + assert info["lib-name"] == "test" + assert info["lib-ver"] == "123" + r2 = redis.asyncio.Redis(lib_name="test2", lib_version="1234") + info = await r2.client_info() + assert info["lib-name"] == "test2" + assert info["lib-ver"] == "1234" + r3 = redis.asyncio.Redis(lib_name=None, lib_version=None) + info = await r3.client_info() + assert info["lib-name"] == "" + assert info["lib-ver"] == "" @skip_if_server_version_lt("2.6.9") @pytest.mark.onlynoncluster @@ -446,6 +481,28 @@ async def test_client_pause(self, r: redis.Redis): with pytest.raises(exceptions.RedisError): await r.client_pause(timeout="not an integer") + @skip_if_server_version_lt("7.2.0") + @pytest.mark.onlynoncluster + async def test_client_no_touch(self, r: redis.Redis): + assert await r.client_no_touch("ON") == b"OK" + assert await r.client_no_touch("OFF") == b"OK" + with pytest.raises(TypeError): + await r.client_no_touch() + + @skip_if_server_version_lt("7.2.0") + @pytest.mark.onlycluster + async def test_waitaof(self, r): + # must return a list of 2 elements + assert len(await r.waitaof(0, 0, 0)) == 2 + assert len(await r.waitaof(1, 0, 0)) == 2 + assert len(await r.waitaof(1, 0, 1000)) == 2 + + # value is out of range, value must be between 0 and 1 + with pytest.raises(exceptions.ResponseError): + await r.waitaof(2, 0, 0) + with pytest.raises(exceptions.ResponseError): + await r.waitaof(-1, 0, 0) + async def test_config_get(self, r: redis.Redis): data = await r.config_get() assert "maxmemory" in data @@ -915,6 +972,19 @@ async def test_pttl_no_key(self, r: redis.Redis): """PTTL on servers 2.8 and after return -2 when the key doesn't exist""" assert await r.pttl("a") == -2 + @skip_if_server_version_lt("6.2.0") + async def test_hrandfield(self, r): + assert await r.hrandfield("key") is None + await r.hset("key", mapping={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}) + assert await r.hrandfield("key") is not None + assert len(await r.hrandfield("key", 2)) == 2 + # with values + assert_resp_response(r, len(await r.hrandfield("key", 2, True)), 4, 2) + # without duplications + assert len(await r.hrandfield("key",
10)) == 5 + # with duplications + assert len(await r.hrandfield("key", -10)) == 10 + @pytest.mark.onlynoncluster async def test_randomkey(self, r: redis.Redis): assert await r.randomkey() is None @@ -1051,25 +1121,45 @@ async def test_type(self, r: redis.Redis): async def test_blpop(self, r: redis.Redis): await r.rpush("a", "1", "2") await r.rpush("b", "3", "4") - assert await r.blpop(["b", "a"], timeout=1) == (b"b", b"3") - assert await r.blpop(["b", "a"], timeout=1) == (b"b", b"4") - assert await r.blpop(["b", "a"], timeout=1) == (b"a", b"1") - assert await r.blpop(["b", "a"], timeout=1) == (b"a", b"2") + assert_resp_response( + r, await r.blpop(["b", "a"], timeout=1), (b"b", b"3"), [b"b", b"3"] + ) + assert_resp_response( + r, await r.blpop(["b", "a"], timeout=1), (b"b", b"4"), [b"b", b"4"] + ) + assert_resp_response( + r, await r.blpop(["b", "a"], timeout=1), (b"a", b"1"), [b"a", b"1"] + ) + assert_resp_response( + r, await r.blpop(["b", "a"], timeout=1), (b"a", b"2"), [b"a", b"2"] + ) assert await r.blpop(["b", "a"], timeout=1) is None await r.rpush("c", "1") - assert await r.blpop("c", timeout=1) == (b"c", b"1") + assert_resp_response( + r, await r.blpop("c", timeout=1), (b"c", b"1"), [b"c", b"1"] + ) @pytest.mark.onlynoncluster async def test_brpop(self, r: redis.Redis): await r.rpush("a", "1", "2") await r.rpush("b", "3", "4") - assert await r.brpop(["b", "a"], timeout=1) == (b"b", b"4") - assert await r.brpop(["b", "a"], timeout=1) == (b"b", b"3") - assert await r.brpop(["b", "a"], timeout=1) == (b"a", b"2") - assert await r.brpop(["b", "a"], timeout=1) == (b"a", b"1") + assert_resp_response( + r, await r.brpop(["b", "a"], timeout=1), (b"b", b"4"), [b"b", b"4"] + ) + assert_resp_response( + r, await r.brpop(["b", "a"], timeout=1), (b"b", b"3"), [b"b", b"3"] + ) + assert_resp_response( + r, await r.brpop(["b", "a"], timeout=1), (b"a", b"2"), [b"a", b"2"] + ) + assert_resp_response( + r, await r.brpop(["b", "a"], timeout=1), (b"a", b"1"), [b"a", b"1"] + ) assert await r.brpop(["b", "a"], timeout=1) is None await r.rpush("c", "1") - assert await r.brpop("c", timeout=1) == (b"c", b"1") + assert_resp_response( + r, await r.brpop("c", timeout=1), (b"c", b"1"), [b"c", b"1"] + ) @pytest.mark.onlynoncluster async def test_brpoplpush(self, r: redis.Redis): @@ -1374,7 +1464,10 @@ async def test_spop_multi_value(self, r: redis.Redis): for value in values: assert value in s - assert await r.spop("a", 1) == list(set(s) - set(values)) + response = await r.spop("a", 1) + assert_resp_response( + r, response, list(set(s) - set(values)), set(s) - set(values) + ) async def test_srandmember(self, r: redis.Redis): s = [b"1", b"2", b"3"] @@ -1412,11 +1505,13 @@ async def test_sunionstore(self, r: redis.Redis): async def test_zadd(self, r: redis.Redis): mapping = {"a1": 1.0, "a2": 2.0, "a3": 3.0} await r.zadd("a", mapping) - assert await r.zrange("a", 0, -1, withscores=True) == [ - (b"a1", 1.0), - (b"a2", 2.0), - (b"a3", 3.0), - ] + response = await r.zrange("a", 0, -1, withscores=True) + assert_resp_response( + r, + response, + [(b"a1", 1.0), (b"a2", 2.0), (b"a3", 3.0)], + [[b"a1", 1.0], [b"a2", 2.0], [b"a3", 3.0]], + ) # error cases with pytest.raises(exceptions.DataError): @@ -1433,23 +1528,24 @@ async def test_zadd(self, r: redis.Redis): async def test_zadd_nx(self, r: redis.Redis): assert await r.zadd("a", {"a1": 1}) == 1 assert await r.zadd("a", {"a1": 99, "a2": 2}, nx=True) == 1 - assert await r.zrange("a", 0, -1, withscores=True) == [ - (b"a1", 1.0), - (b"a2", 2.0), - ] + response = await 
r.zrange("a", 0, -1, withscores=True) + assert_resp_response( + r, response, [(b"a1", 1.0), (b"a2", 2.0)], [[b"a1", 1.0], [b"a2", 2.0]] + ) async def test_zadd_xx(self, r: redis.Redis): assert await r.zadd("a", {"a1": 1}) == 1 assert await r.zadd("a", {"a1": 99, "a2": 2}, xx=True) == 0 - assert await r.zrange("a", 0, -1, withscores=True) == [(b"a1", 99.0)] + response = await r.zrange("a", 0, -1, withscores=True) + assert_resp_response(r, response, [(b"a1", 99.0)], [[b"a1", 99.0]]) async def test_zadd_ch(self, r: redis.Redis): assert await r.zadd("a", {"a1": 1}) == 1 assert await r.zadd("a", {"a1": 99, "a2": 2}, ch=True) == 2 - assert await r.zrange("a", 0, -1, withscores=True) == [ - (b"a2", 2.0), - (b"a1", 99.0), - ] + response = await r.zrange("a", 0, -1, withscores=True) + assert_resp_response( + r, response, [(b"a2", 2.0), (b"a1", 99.0)], [[b"a2", 2.0], [b"a1", 99.0]] + ) async def test_zadd_incr(self, r: redis.Redis): assert await r.zadd("a", {"a1": 1}) == 1 @@ -1473,6 +1569,25 @@ async def test_zcount(self, r: redis.Redis): assert await r.zcount("a", 1, "(" + str(2)) == 1 assert await r.zcount("a", 10, 20) == 0 + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("6.2.0") + async def test_zdiff(self, r): + await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + await r.zadd("b", {"a1": 1, "a2": 2}) + assert await r.zdiff(["a", "b"]) == [b"a3"] + response = await r.zdiff(["a", "b"], withscores=True) + assert_resp_response(r, response, [b"a3", b"3"], [[b"a3", 3.0]]) + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("6.2.0") + async def test_zdiffstore(self, r): + await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + await r.zadd("b", {"a1": 1, "a2": 2}) + assert await r.zdiffstore("out", ["a", "b"]) + assert await r.zrange("out", 0, -1) == [b"a3"] + response = await r.zrange("out", 0, -1, withscores=True) + assert_resp_response(r, response, [(b"a3", 3.0)], [[b"a3", 3.0]]) + async def test_zincrby(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) assert await r.zincrby("a", 1, "a2") == 3.0 @@ -1492,7 +1607,10 @@ async def test_zinterstore_sum(self, r: redis.Redis): await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zinterstore("d", ["a", "b", "c"]) == 2 - assert await r.zrange("d", 0, -1, withscores=True) == [(b"a3", 8), (b"a1", 9)] + response = await r.zrange("d", 0, -1, withscores=True) + assert_resp_response( + r, response, [(b"a3", 8), (b"a1", 9)], [[b"a3", 8.0], [b"a1", 9.0]] + ) @pytest.mark.onlynoncluster async def test_zinterstore_max(self, r: redis.Redis): @@ -1500,7 +1618,10 @@ async def test_zinterstore_max(self, r: redis.Redis): await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zinterstore("d", ["a", "b", "c"], aggregate="MAX") == 2 - assert await r.zrange("d", 0, -1, withscores=True) == [(b"a3", 5), (b"a1", 6)] + response = await r.zrange("d", 0, -1, withscores=True) + assert_resp_response( + r, response, [(b"a3", 5), (b"a1", 6)], [[b"a3", 5], [b"a1", 6]] + ) @pytest.mark.onlynoncluster async def test_zinterstore_min(self, r: redis.Redis): @@ -1508,7 +1629,10 @@ async def test_zinterstore_min(self, r: redis.Redis): await r.zadd("b", {"a1": 2, "a2": 3, "a3": 5}) await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zinterstore("d", ["a", "b", "c"], aggregate="MIN") == 2 - assert await r.zrange("d", 0, -1, withscores=True) == [(b"a1", 1), (b"a3", 3)] + response = await r.zrange("d", 0, -1, withscores=True) + 
assert_resp_response( + r, response, [(b"a1", 1), (b"a3", 3)], [[b"a1", 1], [b"a3", 3]] + ) @pytest.mark.onlynoncluster async def test_zinterstore_with_weight(self, r: redis.Redis): @@ -1516,49 +1640,104 @@ async def test_zinterstore_with_weight(self, r: redis.Redis): await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zinterstore("d", {"a": 1, "b": 2, "c": 3}) == 2 - assert await r.zrange("d", 0, -1, withscores=True) == [(b"a3", 20), (b"a1", 23)] + response = await r.zrange("d", 0, -1, withscores=True) + assert_resp_response( + r, response, [(b"a3", 20), (b"a1", 23)], [[b"a3", 20], [b"a1", 23]] + ) @skip_if_server_version_lt("4.9.0") async def test_zpopmax(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) - assert await r.zpopmax("a") == [(b"a3", 3)] + response = await r.zpopmax("a") + assert_resp_response(r, response, [(b"a3", 3)], [b"a3", 3.0]) # with count - assert await r.zpopmax("a", count=2) == [(b"a2", 2), (b"a1", 1)] + response = await r.zpopmax("a", count=2) + assert_resp_response( + r, response, [(b"a2", 2), (b"a1", 1)], [[b"a2", 2], [b"a1", 1]] + ) @skip_if_server_version_lt("4.9.0") async def test_zpopmin(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) - assert await r.zpopmin("a") == [(b"a1", 1)] + response = await r.zpopmin("a") + assert_resp_response(r, response, [(b"a1", 1)], [b"a1", 1.0]) # with count - assert await r.zpopmin("a", count=2) == [(b"a2", 2), (b"a3", 3)] + response = await r.zpopmin("a", count=2) + assert_resp_response( + r, response, [(b"a2", 2), (b"a3", 3)], [[b"a2", 2], [b"a3", 3]] + ) @skip_if_server_version_lt("4.9.0") @pytest.mark.onlynoncluster async def test_bzpopmax(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2}) await r.zadd("b", {"b1": 10, "b2": 20}) - assert await r.bzpopmax(["b", "a"], timeout=1) == (b"b", b"b2", 20) - assert await r.bzpopmax(["b", "a"], timeout=1) == (b"b", b"b1", 10) - assert await r.bzpopmax(["b", "a"], timeout=1) == (b"a", b"a2", 2) - assert await r.bzpopmax(["b", "a"], timeout=1) == (b"a", b"a1", 1) + assert_resp_response( + r, + await r.bzpopmax(["b", "a"], timeout=1), + (b"b", b"b2", 20), + [b"b", b"b2", 20], + ) + assert_resp_response( + r, + await r.bzpopmax(["b", "a"], timeout=1), + (b"b", b"b1", 10), + [b"b", b"b1", 10], + ) + assert_resp_response( + r, + await r.bzpopmax(["b", "a"], timeout=1), + (b"a", b"a2", 2), + [b"a", b"a2", 2], + ) + assert_resp_response( + r, + await r.bzpopmax(["b", "a"], timeout=1), + (b"a", b"a1", 1), + [b"a", b"a1", 1], + ) assert await r.bzpopmax(["b", "a"], timeout=1) is None await r.zadd("c", {"c1": 100}) - assert await r.bzpopmax("c", timeout=1) == (b"c", b"c1", 100) + assert_resp_response( + r, await r.bzpopmax("c", timeout=1), (b"c", b"c1", 100), [b"c", b"c1", 100] + ) @skip_if_server_version_lt("4.9.0") @pytest.mark.onlynoncluster async def test_bzpopmin(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2}) await r.zadd("b", {"b1": 10, "b2": 20}) - assert await r.bzpopmin(["b", "a"], timeout=1) == (b"b", b"b1", 10) - assert await r.bzpopmin(["b", "a"], timeout=1) == (b"b", b"b2", 20) - assert await r.bzpopmin(["b", "a"], timeout=1) == (b"a", b"a1", 1) - assert await r.bzpopmin(["b", "a"], timeout=1) == (b"a", b"a2", 2) + assert_resp_response( + r, + await r.bzpopmin(["b", "a"], timeout=1), + (b"b", b"b1", 10), + [b"b", b"b1", 10], + ) + assert_resp_response( + r, + await r.bzpopmin(["b", "a"], timeout=1), + (b"b", b"b2", 20), + [b"b", b"b2", 20], + ) + 
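# assert_resp_response, used throughout these tests, compares one reply
# against a RESP2 expectation and a RESP3 expectation, because RESP3 often
# returns doubles and flat lists where RESP2 returned strings and tuples.
# A minimal sketch of such a helper: the real one lives in tests/conftest.py,
# and reading the protocol off connection_kwargs is an assumption of this
# sketch, not a documented redis-py API.
def assert_resp_response(r, response, resp2_expected, resp3_expected):
    protocol = r.connection_pool.connection_kwargs.get("protocol")
    if protocol in (2, "2", None):
        assert response == resp2_expected
    else:
        assert response == resp3_expected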
assert_resp_response( + r, + await r.bzpopmin(["b", "a"], timeout=1), + (b"a", b"a1", 1), + [b"a", b"a1", 1], + ) + assert_resp_response( + r, + await r.bzpopmin(["b", "a"], timeout=1), + (b"a", b"a2", 2), + [b"a", b"a2", 2], + ) assert await r.bzpopmin(["b", "a"], timeout=1) is None await r.zadd("c", {"c1": 100}) - assert await r.bzpopmin("c", timeout=1) == (b"c", b"c1", 100) + assert_resp_response( + r, await r.bzpopmin("c", timeout=1), (b"c", b"c1", 100), [b"c", b"c1", 100] + ) async def test_zrange(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) @@ -1566,20 +1745,20 @@ async def test_zrange(self, r: redis.Redis): assert await r.zrange("a", 1, 2) == [b"a2", b"a3"] # withscores - assert await r.zrange("a", 0, 1, withscores=True) == [ - (b"a1", 1.0), - (b"a2", 2.0), - ] - assert await r.zrange("a", 1, 2, withscores=True) == [ - (b"a2", 2.0), - (b"a3", 3.0), - ] + response = await r.zrange("a", 0, 1, withscores=True) + assert_resp_response( + r, response, [(b"a1", 1.0), (b"a2", 2.0)], [[b"a1", 1.0], [b"a2", 2.0]] + ) + response = await r.zrange("a", 1, 2, withscores=True) + assert_resp_response( + r, response, [(b"a2", 2.0), (b"a3", 3.0)], [[b"a2", 2.0], [b"a3", 3.0]] + ) # custom score function - assert await r.zrange("a", 0, 1, withscores=True, score_cast_func=int) == [ - (b"a1", 1), - (b"a2", 2), - ] + # assert await r.zrange("a", 0, 1, withscores=True, score_cast_func=int) == [ + # (b"a1", 1), + # (b"a2", 2), + # ] @skip_if_server_version_lt("2.8.9") async def test_zrangebylex(self, r: redis.Redis): @@ -1613,16 +1792,24 @@ async def test_zrangebyscore(self, r: redis.Redis): assert await r.zrangebyscore("a", 2, 4, start=1, num=2) == [b"a3", b"a4"] # withscores - assert await r.zrangebyscore("a", 2, 4, withscores=True) == [ - (b"a2", 2.0), - (b"a3", 3.0), - (b"a4", 4.0), - ] + response = await r.zrangebyscore("a", 2, 4, withscores=True) + assert_resp_response( + r, + response, + [(b"a2", 2.0), (b"a3", 3.0), (b"a4", 4.0)], + [[b"a2", 2.0], [b"a3", 3.0], [b"a4", 4.0]], + ) # custom score function - assert await r.zrangebyscore( + response = await r.zrangebyscore( "a", 2, 4, withscores=True, score_cast_func=int - ) == [(b"a2", 2), (b"a3", 3), (b"a4", 4)] + ) + assert_resp_response( + r, + response, + [(b"a2", 2), (b"a3", 3), (b"a4", 4)], + [[b"a2", 2], [b"a3", 3], [b"a4", 4]], + ) async def test_zrank(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) @@ -1630,6 +1817,15 @@ async def test_zrank(self, r: redis.Redis): assert await r.zrank("a", "a2") == 1 assert await r.zrank("a", "a6") is None + @skip_if_server_version_lt("7.2.0") + async def test_zrank_withscore(self, r: redis.Redis): + await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) + assert await r.zrank("a", "a1") == 0 + assert await r.zrank("a", "a2") == 1 + assert await r.zrank("a", "a6") is None + assert await r.zrank("a", "a3", withscore=True) == [2, "3"] + assert await r.zrank("a", "a6", withscore=True) is None + async def test_zrem(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) assert await r.zrem("a", "a2") == 1 @@ -1670,20 +1866,20 @@ async def test_zrevrange(self, r: redis.Redis): assert await r.zrevrange("a", 1, 2) == [b"a2", b"a1"] # withscores - assert await r.zrevrange("a", 0, 1, withscores=True) == [ - (b"a3", 3.0), - (b"a2", 2.0), - ] - assert await r.zrevrange("a", 1, 2, withscores=True) == [ - (b"a2", 2.0), - (b"a1", 1.0), - ] + response = await r.zrevrange("a", 0, 1, withscores=True) + assert_resp_response( + r, response,
[(b"a3", 3.0), (b"a2", 2.0)], [[b"a3", 3.0], [b"a2", 2.0]] + ) + response = await r.zrevrange("a", 1, 2, withscores=True) + assert_resp_response( + r, response, [(b"a2", 2.0), (b"a1", 1.0)], [[b"a2", 2.0], [b"a1", 1.0]] + ) # custom score function - assert await r.zrevrange("a", 0, 1, withscores=True, score_cast_func=int) == [ - (b"a3", 3.0), - (b"a2", 2.0), - ] + response = await r.zrevrange("a", 0, 1, withscores=True, score_cast_func=int) + assert_resp_response( + r, response, [(b"a3", 3), (b"a2", 2)], [[b"a3", 3], [b"a2", 2]] + ) async def test_zrevrangebyscore(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) @@ -1693,16 +1889,24 @@ async def test_zrevrangebyscore(self, r: redis.Redis): assert await r.zrevrangebyscore("a", 4, 2, start=1, num=2) == [b"a3", b"a2"] # withscores - assert await r.zrevrangebyscore("a", 4, 2, withscores=True) == [ - (b"a4", 4.0), - (b"a3", 3.0), - (b"a2", 2.0), - ] + response = await r.zrevrangebyscore("a", 4, 2, withscores=True) + assert_resp_response( + r, + response, + [(b"a4", 4.0), (b"a3", 3.0), (b"a2", 2.0)], + [[b"a4", 4.0], [b"a3", 3.0], [b"a2", 2.0]], + ) # custom score function - assert await r.zrevrangebyscore( + response = await r.zrevrangebyscore( "a", 4, 2, withscores=True, score_cast_func=int - ) == [(b"a4", 4), (b"a3", 3), (b"a2", 2)] + ) + assert_resp_response( + r, + response, + [(b"a4", 4), (b"a3", 3), (b"a2", 2)], + [[b"a4", 4], [b"a3", 3], [b"a2", 2]], + ) async def test_zrevrank(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) @@ -1710,6 +1914,15 @@ async def test_zrevrank(self, r: redis.Redis): assert await r.zrevrank("a", "a2") == 3 assert await r.zrevrank("a", "a6") is None + @skip_if_server_version_lt("7.2.0") + async def test_zrevrank_withscore(self, r: redis.Redis): + await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) + assert await r.zrevrank("a", "a1") == 4 + assert await r.zrevrank("a", "a2") == 3 + assert await r.zrevrank("a", "a6") is None + assert await r.zrevrank("a", "a3", withscore=True) == [2, "3"] + assert await r.zrevrank("a", "a6", withscore=True) is None + async def test_zscore(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) assert await r.zscore("a", "a1") == 1.0 @@ -1722,12 +1935,13 @@ async def test_zunionstore_sum(self, r: redis.Redis): await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zunionstore("d", ["a", "b", "c"]) == 4 - assert await r.zrange("d", 0, -1, withscores=True) == [ - (b"a2", 3), - (b"a4", 4), - (b"a3", 8), - (b"a1", 9), - ] + response = await r.zrange("d", 0, -1, withscores=True) + assert_resp_response( + r, + response, + [(b"a2", 3.0), (b"a4", 4.0), (b"a3", 8.0), (b"a1", 9.0)], + [[b"a2", 3.0], [b"a4", 4.0], [b"a3", 8.0], [b"a1", 9.0]], + ) @pytest.mark.onlynoncluster async def test_zunionstore_max(self, r: redis.Redis): @@ -1735,12 +1949,13 @@ async def test_zunionstore_max(self, r: redis.Redis): await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zunionstore("d", ["a", "b", "c"], aggregate="MAX") == 4 - assert await r.zrange("d", 0, -1, withscores=True) == [ - (b"a2", 2), - (b"a4", 4), - (b"a3", 5), - (b"a1", 6), - ] + respponse = await r.zrange("d", 0, -1, withscores=True) + assert_resp_response( + r, + respponse, + [(b"a2", 2.0), (b"a4", 4.0), (b"a3", 5.0), (b"a1", 6.0)], + [[b"a2", 2.0], [b"a4", 4.0], [b"a3", 5.0], [b"a1", 6.0]], + ) @pytest.mark.onlynoncluster async 
@pytest.mark.onlynoncluster async def test_zunionstore_min(self, r: redis.Redis): @@ -1748,12 +1963,13 @@ async def test_zunionstore_min(self, r: redis.Redis): await r.zadd("b", {"a1": 2, "a2": 2, "a3": 4}) await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zunionstore("d", ["a", "b", "c"], aggregate="MIN") == 4 - assert await r.zrange("d", 0, -1, withscores=True) == [ - (b"a1", 1), - (b"a2", 2), - (b"a3", 3), - (b"a4", 4), - ] + response = await r.zrange("d", 0, -1, withscores=True) + assert_resp_response( + r, + response, + [(b"a1", 1.0), (b"a2", 2.0), (b"a3", 3.0), (b"a4", 4.0)], + [[b"a1", 1.0], [b"a2", 2.0], [b"a3", 3.0], [b"a4", 4.0]], + ) @pytest.mark.onlynoncluster async def test_zunionstore_with_weight(self, r: redis.Redis): @@ -1761,12 +1977,13 @@ async def test_zunionstore_with_weight(self, r: redis.Redis): await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zunionstore("d", {"a": 1, "b": 2, "c": 3}) == 4 - assert await r.zrange("d", 0, -1, withscores=True) == [ - (b"a2", 5), - (b"a4", 12), - (b"a3", 20), - (b"a1", 23), - ] + response = await r.zrange("d", 0, -1, withscores=True) + assert_resp_response( + r, + response, + [(b"a2", 5.0), (b"a4", 12.0), (b"a3", 20.0), (b"a1", 23.0)], + [[b"a2", 5.0], [b"a4", 12.0], [b"a3", 20.0], [b"a1", 23.0]], + ) # HYPERLOGLOG TESTS @skip_if_server_version_lt("2.8.9") @@ -2207,11 +2424,12 @@ async def test_geohash(self, r: redis.Redis): ) await r.geoadd("barcelona", values) - assert await r.geohash("barcelona", "place1", "place2", "place3") == [ - "sp3e9yg3kd0", - "sp3e9cbc3t0", - None, - ] + assert_resp_response( + r, + await r.geohash("barcelona", "place1", "place2", "place3"), + ["sp3e9yg3kd0", "sp3e9cbc3t0", None], + [b"sp3e9yg3kd0", b"sp3e9cbc3t0", None], + ) @skip_if_server_version_lt("3.2.0") async def test_geopos(self, r: redis.Redis): @@ -2223,10 +2441,18 @@ async def test_geopos(self, r: redis.Redis): await r.geoadd("barcelona", values) # redis uses 52 bits precision, so small errors may be introduced.
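# Because of that 52-bit geohash encoding, exact equality on read-back
# coordinates is fragile; comparing within a tolerance is the usual
# alternative. A small sketch using pytest.approx (assert_geopos_close is a
# hypothetical helper, not part of this test suite):
import pytest

def assert_geopos_close(actual, expected, tol=1e-5):
    # compare (longitude, latitude) pairs within an absolute tolerance
    for got, want in zip(actual, expected):
        assert got[0] == pytest.approx(want[0], abs=tol)
        assert got[1] == pytest.approx(want[1], abs=tol)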
- assert await r.geopos("barcelona", "place1", "place2") == [ - (2.19093829393386841, 41.43379028184083523), - (2.18737632036209106, 41.40634178640635099), - ] + assert_resp_response( + r, + await r.geopos("barcelona", "place1", "place2"), + [ + (2.19093829393386841, 41.43379028184083523), + (2.18737632036209106, 41.40634178640635099), + ], + [ + [2.19093829393386841, 41.43379028184083523], + [2.18737632036209106, 41.40634178640635099], + ], + ) @skip_if_server_version_lt("4.0.0") async def test_geopos_no_value(self, r: redis.Redis): @@ -2635,7 +2861,7 @@ async def test_xgroup_setid(self, r: redis.Redis): ] assert await r.xinfo_groups(stream) == expected - @skip_if_server_version_lt("5.0.0") + @skip_if_server_version_lt("7.2.0") async def test_xinfo_consumers(self, r: redis.Redis): stream = "stream" group = "group" @@ -2651,8 +2877,8 @@ async def test_xinfo_consumers(self, r: redis.Redis): info = await r.xinfo_consumers(stream, group) assert len(info) == 2 expected = [ - {"name": consumer1.encode(), "pending": 1}, - {"name": consumer2.encode(), "pending": 2}, + {"name": consumer1.encode(), "pending": 1, "inactive": 2}, + {"name": consumer2.encode(), "pending": 2, "inactive": 2}, ] # we can't determine the idle time, so just make sure it's an int @@ -2761,28 +2987,30 @@ async def test_xread(self, r: redis.Redis): m1 = await r.xadd(stream, {"foo": "bar"}) m2 = await r.xadd(stream, {"bing": "baz"}) - expected = [ - [ - stream.encode(), - [ - await get_stream_message(r, stream, m1), - await get_stream_message(r, stream, m2), - ], - ] + stream_name = stream.encode() + expected_entries = [ + await get_stream_message(r, stream, m1), + await get_stream_message(r, stream, m2), ] # xread starting at 0 returns both messages - assert await r.xread(streams={stream: 0}) == expected + res = await r.xread(streams={stream: 0}) + assert_resp_response( + r, res, [[stream_name, expected_entries]], {stream_name: [expected_entries]} + ) - expected = [[stream.encode(), [await get_stream_message(r, stream, m1)]]] + expected_entries = [await get_stream_message(r, stream, m1)] # xread starting at 0 and count=1 returns only the first message - assert await r.xread(streams={stream: 0}, count=1) == expected + res = await r.xread(streams={stream: 0}, count=1) + assert_resp_response( + r, res, [[stream_name, expected_entries]], {stream_name: [expected_entries]} + ) - expected = [[stream.encode(), [await get_stream_message(r, stream, m2)]]] + expected_entries = [await get_stream_message(r, stream, m2)] # xread starting at m1 returns only the second message - assert await r.xread(streams={stream: m1}) == expected - - # xread starting at the last message returns an empty list - assert await r.xread(streams={stream: m2}) == [] + res = await r.xread(streams={stream: m1}) + assert_resp_response( + r, res, [[stream_name, expected_entries]], {stream_name: [expected_entries]} + ) @skip_if_server_version_lt("5.0.0") async def test_xreadgroup(self, r: redis.Redis): @@ -2793,26 +3021,27 @@ async def test_xreadgroup(self, r: redis.Redis): m2 = await r.xadd(stream, {"bing": "baz"}) await r.xgroup_create(stream, group, 0) - expected = [ - [ - stream.encode(), - [ - await get_stream_message(r, stream, m1), - await get_stream_message(r, stream, m2), - ], - ] + stream_name = stream.encode() + expected_entries = [ + await get_stream_message(r, stream, m1), + await get_stream_message(r, stream, m2), ] + # xread starting at 0 returns both messages - assert await r.xreadgroup(group, consumer, streams={stream: ">"}) == expected + res = await
r.xreadgroup(group, consumer, streams={stream: ">"}) + assert_resp_response( + r, res, [[stream_name, expected_entries]], {stream_name: [expected_entries]} + ) await r.xgroup_destroy(stream, group) await r.xgroup_create(stream, group, 0) - expected = [[stream.encode(), [await get_stream_message(r, stream, m1)]]] + expected_entries = [await get_stream_message(r, stream, m1)] + # xread with count=1 returns only the first message - assert ( - await r.xreadgroup(group, consumer, streams={stream: ">"}, count=1) - == expected + res = await r.xreadgroup(group, consumer, streams={stream: ">"}, count=1) + assert_resp_response( + r, res, [[stream_name, expected_entries]], {stream_name: [expected_entries]} ) await r.xgroup_destroy(stream, group) @@ -2821,35 +3050,34 @@ async def test_xreadgroup(self, r: redis.Redis): # will only find messages added after this await r.xgroup_create(stream, group, "$") - expected = [] # xread starting after the last message returns an empty message list - assert await r.xreadgroup(group, consumer, streams={stream: ">"}) == expected + res = await r.xreadgroup(group, consumer, streams={stream: ">"}) + assert_resp_response(r, res, [], {}) # xreadgroup with noack does not have any items in the PEL await r.xgroup_destroy(stream, group) await r.xgroup_create(stream, group, "0") - assert ( - len( - ( - await r.xreadgroup( - group, consumer, streams={stream: ">"}, noack=True - ) - )[0][1] - ) - == 2 - ) - # now there should be nothing pending - assert ( - len((await r.xreadgroup(group, consumer, streams={stream: "0"}))[0][1]) == 0 - ) + res = await r.xreadgroup(group, consumer, streams={stream: ">"}, noack=True) + empty_res = await r.xreadgroup(group, consumer, streams={stream: "0"}) + if is_resp2_connection(r): + assert len(res[0][1]) == 2 + # now there should be nothing pending + assert len(empty_res[0][1]) == 0 + else: + assert len(res[stream_name][0]) == 2 + # now there should be nothing pending + assert len(empty_res[stream_name][0]) == 0 await r.xgroup_destroy(stream, group) await r.xgroup_create(stream, group, "0") # delete all the messages in the stream - expected = [[stream.encode(), [(m1, {}), (m2, {})]]] + expected_entries = [(m1, {}), (m2, {})] await r.xreadgroup(group, consumer, streams={stream: ">"}) await r.xtrim(stream, 0) - assert await r.xreadgroup(group, consumer, streams={stream: "0"}) == expected + res = await r.xreadgroup(group, consumer, streams={stream: "0"}) + assert_resp_response( + r, res, [[stream_name, expected_entries]], {stream_name: [expected_entries]} ) @skip_if_server_version_lt("5.0.0") async def test_xrevrange(self, r: redis.Redis): @@ -2999,6 +3227,37 @@ async def test_module_list(self, r: redis.Redis): for x in await r.module_list(): assert isinstance(x, dict) + @pytest.mark.onlynoncluster + async def test_interrupted_command(self, r: redis.Redis): + """ + Regression test for issue #1128: an unhandled BaseException + will leave the socket with an unread response to a previous + command. + """ + ready = asyncio.Event() + + async def helper(): + with pytest.raises(asyncio.CancelledError): + # blocking pop + ready.set() + await r.brpop(["nonexist"]) + # If the following is not done, further timeout operations will fail, + # because the timeout won't catch its CancelledError if the task + # has a pending cancel. Python documentation probably should reflect this. + if sys.version_info >= (3, 11): + asyncio.current_task().uncancel()
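The uncancel() call above deserves a note: on Python 3.11+, a task that absorbs a CancelledError still carries a pending cancellation count, and asyncio.timeout() will then re-raise CancelledError instead of converting its internal cancellation into TimeoutError. A self-contained sketch of that behaviour, assuming Python 3.11+ and using nothing from this repository:

    import asyncio
    import sys

    async def worker():
        try:
            await asyncio.sleep(10)  # stands in for the blocking BRPOP
        except asyncio.CancelledError:
            if sys.version_info >= (3, 11):
                # Clear the pending cancel; without this, the timeout
                # below re-raises CancelledError, not TimeoutError.
                asyncio.current_task().uncancel()
        try:
            async with asyncio.timeout(0.01):
                await asyncio.sleep(1)
        except TimeoutError:
            pass  # the expected, well-behaved outcome

    async def main():
        task = asyncio.create_task(worker())
        await asyncio.sleep(0.05)
        task.cancel()
        await task  # completes normally because the cancel was absorbed

    asyncio.run(main())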
+ await r.set("status", "down") + + task = asyncio.create_task(helper()) + await ready.wait() + await asyncio.sleep(0.01) + # the task is now sleeping, lets send it an exception + task.cancel() + # If all is well, the task should finish right away, otherwise fail with Timeout + async with async_timeout(1.0): + await task + @pytest.mark.onlynoncluster class TestBinarySave: diff --git a/tests/test_asyncio/test_connect.py b/tests/test_asyncio/test_connect.py new file mode 100644 index 0000000000..bead7208f5 --- /dev/null +++ b/tests/test_asyncio/test_connect.py @@ -0,0 +1,144 @@ +import asyncio +import logging +import re +import socket +import ssl + +import pytest +from redis.asyncio.connection import ( + Connection, + SSLConnection, + UnixDomainSocketConnection, +) + +from ..ssl_utils import get_ssl_filename + +_logger = logging.getLogger(__name__) + + +_CLIENT_NAME = "test-suite-client" +_CMD_SEP = b"\r\n" +_SUCCESS_RESP = b"+OK" + _CMD_SEP +_ERROR_RESP = b"-ERR" + _CMD_SEP +_SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP} + + +@pytest.fixture +def tcp_address(): + with socket.socket() as sock: + sock.bind(("127.0.0.1", 0)) + return sock.getsockname() + + +@pytest.fixture +def uds_address(tmpdir): + return tmpdir / "uds.sock" + + +async def test_tcp_connect(tcp_address): + host, port = tcp_address + conn = Connection(host=host, port=port, client_name=_CLIENT_NAME, socket_timeout=10) + await _assert_connect(conn, tcp_address) + + +async def test_uds_connect(uds_address): + path = str(uds_address) + conn = UnixDomainSocketConnection( + path=path, client_name=_CLIENT_NAME, socket_timeout=10 + ) + await _assert_connect(conn, path) + + +@pytest.mark.ssl +async def test_tcp_ssl_connect(tcp_address): + host, port = tcp_address + certfile = get_ssl_filename("server-cert.pem") + keyfile = get_ssl_filename("server-key.pem") + conn = SSLConnection( + host=host, + port=port, + client_name=_CLIENT_NAME, + ssl_ca_certs=certfile, + socket_timeout=10, + ) + await _assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile) + + +async def _assert_connect(conn, server_address, certfile=None, keyfile=None): + stop_event = asyncio.Event() + finished = asyncio.Event() + + async def _handler(reader, writer): + try: + return await _redis_request_handler(reader, writer, stop_event) + finally: + finished.set() + + if isinstance(server_address, str): + server = await asyncio.start_unix_server(_handler, path=server_address) + elif certfile: + host, port = server_address + context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + context.minimum_version = ssl.TLSVersion.TLSv1_2 + context.load_cert_chain(certfile=certfile, keyfile=keyfile) + server = await asyncio.start_server(_handler, host=host, port=port, ssl=context) + else: + host, port = server_address + server = await asyncio.start_server(_handler, host=host, port=port) + + async with server as aserver: + await aserver.start_serving() + try: + await conn.connect() + await conn.disconnect() + finally: + stop_event.set() + aserver.close() + await aserver.wait_closed() + await finished.wait() + + +async def _redis_request_handler(reader, writer, stop_event): + buffer = b"" + command = None + command_ptr = None + fragment_length = None + while not stop_event.is_set() or buffer: + _logger.info(str(stop_event.is_set())) + try: + buffer += await asyncio.wait_for(reader.read(1024), timeout=0.5) + except TimeoutError: + continue + if not buffer: + continue + parts = re.split(_CMD_SEP, buffer) + buffer = parts[-1] + for fragment in 
parts[:-1]: + fragment = fragment.decode() + _logger.info("Command fragment: %s", fragment) + + if fragment.startswith("*") and command is None: + command = [None for _ in range(int(fragment[1:]))] + command_ptr = 0 + fragment_length = None + continue + + if fragment.startswith("$") and command[command_ptr] is None: + fragment_length = int(fragment[1:]) + continue + + assert len(fragment) == fragment_length + command[command_ptr] = fragment + command_ptr += 1 + + if command_ptr < len(command): + continue + + command = " ".join(command) + _logger.info("Command %s", command) + resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP) + _logger.info("Response from %s", resp) + writer.write(resp) + await writer.drain() + command = None + _logger.info("Exit handler") diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index bf59dbe6b0..d1aad796e7 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -4,17 +4,19 @@ from unittest.mock import patch import pytest - import redis -from redis.asyncio.connection import ( - BaseParser, - Connection, - PythonParser, - UnixDomainSocketConnection, +from redis._parsers import ( + _AsyncHiredisParser, + _AsyncRESP2Parser, + _AsyncRESP3Parser, + _AsyncRESPBase, ) +from redis.asyncio import Redis +from redis.asyncio.connection import Connection, UnixDomainSocketConnection, parse_url from redis.asyncio.retry import Retry from redis.backoff import NoBackoff from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError +from redis.utils import HIREDIS_AVAILABLE from tests.conftest import skip_if_server_version_lt from .compat import mock @@ -28,11 +30,11 @@ async def test_invalid_response(create_redis): raw = b"x" fake_stream = MockStream(raw + b"\r\n") - parser: BaseParser = r.connection._parser + parser: _AsyncRESPBase = r.connection._parser with mock.patch.object(parser, "_stream", fake_stream): with pytest.raises(InvalidResponse) as cm: await parser.read_response() - if isinstance(parser, PythonParser): + if isinstance(parser, _AsyncRESPBase): assert str(cm.value) == f"Protocol Error: {raw!r}" else: assert ( @@ -41,25 +43,69 @@ async def test_invalid_response(create_redis): await r.connection.disconnect() +@pytest.mark.onlynoncluster +async def test_single_connection(): + """Test that concurrent requests on a single client are synchronised.""" + r = Redis(single_connection_client=True) + + init_call_count = 0 + command_call_count = 0 + in_use = False + + class Retry_: + async def call_with_retry(self, _, __): + # If we remove the single-client lock, this error gets raised as two + # coroutines will be vying for the `in_use` flag due to the two + # asymmetric sleep calls + nonlocal command_call_count + nonlocal in_use + if in_use is True: + raise ValueError("Commands should be executed one at a time.") + in_use = True + await asyncio.sleep(0.01) + command_call_count += 1 + await asyncio.sleep(0.03) + in_use = False + return "foo" + + mock_conn = mock.MagicMock() + mock_conn.retry = Retry_() + + async def get_conn(_): + # Validate only one client is created in single-client mode when + # concurrent requests are made + nonlocal init_call_count + await asyncio.sleep(0.01) + init_call_count += 1 + return mock_conn + + with mock.patch.object(r.connection_pool, "get_connection", get_conn): + with mock.patch.object(r.connection_pool, "release"): + await asyncio.gather(r.set("a", "b"), r.set("c", "d")) + + assert init_call_count == 1 + assert command_call_count == 2 + + 
@skip_if_server_version_lt("4.0.0") @pytest.mark.redismod @pytest.mark.onlynoncluster -async def test_loading_external_modules(modclient): +async def test_loading_external_modules(r): def inner(): pass - modclient.load_external_module("myfuncname", inner) - assert getattr(modclient, "myfuncname") == inner - assert isinstance(getattr(modclient, "myfuncname"), types.FunctionType) + r.load_external_module("myfuncname", inner) + assert getattr(r, "myfuncname") == inner + assert isinstance(getattr(r, "myfuncname"), types.FunctionType) # and call it from redis.commands import RedisModuleCommands j = RedisModuleCommands.json - modclient.load_external_module("sometestfuncname", j) + r.load_external_module("sometestfuncname", j) # d = {'hello': 'world!'} - # mod = j(modclient) + # mod = j(r) # mod.set("fookey", ".", d) # assert mod.get('fookey') == d @@ -79,9 +125,11 @@ async def test_can_run_concurrent_commands(r): assert all(await asyncio.gather(*(r.ping() for _ in range(10)))) -async def test_connect_retry_on_timeout_error(): +async def test_connect_retry_on_timeout_error(connect_args): """Test that the _connect function is retried in case of a timeout""" - conn = Connection(retry_on_timeout=True, retry=Retry(NoBackoff(), 3)) + conn = Connection( + retry_on_timeout=True, retry=Retry(NoBackoff(), 3), **connect_args + ) origin_connect = conn._connect conn._connect = mock.AsyncMock() @@ -137,7 +185,7 @@ async def test_connection_parse_response_resume(r: redis.Redis): conn._parser._stream = MockStream(message, interrupt_every=2) for i in range(100): try: - response = await conn.read_response() + response = await conn.read_response(disconnect_on_error=False) break except MockStream.TestError: pass @@ -146,3 +194,132 @@ async def test_connection_parse_response_resume(r: redis.Redis): pytest.fail("didn't receive a response") assert response assert i > 0 + + +@pytest.mark.onlynoncluster +@pytest.mark.parametrize( + "parser_class", + [_AsyncRESP2Parser, _AsyncRESP3Parser, _AsyncHiredisParser], + ids=["AsyncRESP2Parser", "AsyncRESP3Parser", "AsyncHiredisParser"], +) +async def test_connection_disconect_race(parser_class, connect_args): + """ + This test reproduces the case in issue #2349 + where a connection is closed while the parser is reading to feed the + internal buffer.The stream `read()` will succeed, but when it returns, + another task has already called `disconnect()` and is waiting for + close to finish. When we attempts to feed the buffer, we will fail + since the buffer is no longer there. + + This test verifies that a read in progress can finish even + if the `disconnect()` method is called. 
+ """ + if parser_class == _AsyncHiredisParser and not HIREDIS_AVAILABLE: + pytest.skip("Hiredis not available") + + connect_args["parser_class"] = parser_class + + conn = Connection(**connect_args) + + cond = asyncio.Condition() + # 0 == initial + # 1 == reader is reading + # 2 == closer has closed and is waiting for close to finish + state = 0 + + # Mock read function, which wait for a close to happen before returning + # Can either be invoked as two `read()` calls (HiredisParser) + # or as a `readline()` followed by `readexact()` (PythonParser) + chunks = [b"$13\r\n", b"Hello, World!\r\n"] + + async def read(_=None): + nonlocal state + async with cond: + if state == 0: + state = 1 # we are reading + cond.notify() + # wait until the closing task has done + await cond.wait_for(lambda: state == 2) + return chunks.pop(0) + + # function closes the connection while reader is still blocked reading + async def do_close(): + nonlocal state + async with cond: + await cond.wait_for(lambda: state == 1) + state = 2 + cond.notify() + await conn.disconnect() + + async def do_read(): + return await conn.read_response() + + reader = mock.AsyncMock() + writer = mock.AsyncMock() + writer.transport = mock.Mock() + writer.transport.get_extra_info.side_effect = None + + # for HiredisParser + reader.read.side_effect = read + # for PythonParser + reader.readline.side_effect = read + reader.readexactly.side_effect = read + + async def open_connection(*args, **kwargs): + return reader, writer + + async def dummy_method(*args, **kwargs): + pass + + # get dummy stream objects for the connection + with patch.object(asyncio, "open_connection", open_connection): + # disable the initial version handshake + with patch.multiple( + conn, send_command=dummy_method, read_response=dummy_method + ): + await conn.connect() + + vals = await asyncio.gather(do_read(), do_close()) + assert vals == [b"Hello, World!", None] + + +@pytest.mark.onlynoncluster +def test_create_single_connection_client_from_url(): + client = Redis.from_url("redis://localhost:6379/0?", single_connection_client=True) + assert client.single_connection_client is True + + +@pytest.mark.parametrize("from_url", (True, False)) +async def test_pool_auto_close(request, from_url): + """Verify that basic Redis instances have auto_close_connection_pool set to True""" + + url: str = request.config.getoption("--redis-url") + url_args = parse_url(url) + + async def get_redis_connection(): + if from_url: + return Redis.from_url(url) + return Redis(**url_args) + + r1 = await get_redis_connection() + assert r1.auto_close_connection_pool is True + await r1.close() + + +@pytest.mark.parametrize("from_url", (True, False)) +async def test_pool_auto_close_disable(request, from_url): + """Verify that auto_close_connection_pool can be disabled""" + + url: str = request.config.getoption("--redis-url") + url_args = parse_url(url) + + async def get_redis_connection(): + if from_url: + return Redis.from_url(url, auto_close_connection_pool=False) + url_args["auto_close_connection_pool"] = False + return Redis(**url_args) + + r1 = await get_redis_connection() + assert r1.auto_close_connection_pool is False + await r1.connection_pool.disconnect() + await r1.close() diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py index d1e52bd2a3..7672dc74b4 100644 --- a/tests/test_asyncio/test_connection_pool.py +++ b/tests/test_asyncio/test_connection_pool.py @@ -4,7 +4,6 @@ import pytest import pytest_asyncio - import redis.asyncio as redis from 
redis.asyncio.connection import Connection, to_bool from tests.conftest import skip_if_redis_enterprise, skip_if_server_version_lt @@ -136,14 +135,14 @@ async def test_connection_creation(self): assert connection.kwargs == connection_kwargs async def test_multiple_connections(self, master_host): - connection_kwargs = {"host": master_host} + connection_kwargs = {"host": master_host[0]} async with self.get_pool(connection_kwargs=connection_kwargs) as pool: c1 = await pool.get_connection("_") c2 = await pool.get_connection("_") assert c1 != c2 async def test_max_connections(self, master_host): - connection_kwargs = {"host": master_host} + connection_kwargs = {"host": master_host[0]} async with self.get_pool( max_connections=2, connection_kwargs=connection_kwargs ) as pool: @@ -153,7 +152,7 @@ async def test_max_connections(self, master_host): await pool.get_connection("_") async def test_reuse_previously_released_connection(self, master_host): - connection_kwargs = {"host": master_host} + connection_kwargs = {"host": master_host[0]} async with self.get_pool(connection_kwargs=connection_kwargs) as pool: c1 = await pool.get_connection("_") await pool.release(c1) @@ -237,7 +236,7 @@ async def test_multiple_connections(self, master_host): async def test_connection_pool_blocks_until_timeout(self, master_host): """When out of connections, block for timeout seconds, then raise""" - connection_kwargs = {"host": master_host} + connection_kwargs = {"host": master_host[0]} async with self.get_pool( max_connections=1, timeout=0.1, connection_kwargs=connection_kwargs ) as pool: @@ -246,8 +245,9 @@ async def test_connection_pool_blocks_until_timeout(self, master_host): start = asyncio.get_running_loop().time() with pytest.raises(redis.ConnectionError): await pool.get_connection("_") - # we should have waited at least 0.1 seconds - assert asyncio.get_running_loop().time() - start >= 0.1 + + # we should have waited at least some period of time + assert asyncio.get_running_loop().time() - start >= 0.05 await c1.disconnect() async def test_connection_pool_blocks_until_conn_available(self, master_host): @@ -267,10 +267,11 @@ async def target(): start = asyncio.get_running_loop().time() await asyncio.gather(target(), pool.get_connection("_")) - assert asyncio.get_running_loop().time() - start >= 0.1 + stop = asyncio.get_running_loop().time() + assert (stop - start) <= 0.2 async def test_reuse_previously_released_connection(self, master_host): - connection_kwargs = {"host": master_host} + connection_kwargs = {"host": master_host[0]} async with self.get_pool(connection_kwargs=connection_kwargs) as pool: c1 = await pool.get_connection("_") await pool.release(c1) @@ -606,10 +607,18 @@ async def test_busy_loading_from_pipeline(self, r): @skip_if_server_version_lt("2.8.8") @skip_if_redis_enterprise() async def test_read_only_error(self, r): - """READONLY errors get turned in ReadOnlyError exceptions""" + """READONLY errors get turned into ReadOnlyError exceptions""" with pytest.raises(redis.ReadOnlyError): await r.execute_command("DEBUG", "ERROR", "READONLY blah blah") + @skip_if_redis_enterprise() + async def test_oom_error(self, r): + """OOM errors get turned into OutOfMemoryError exceptions""" + with pytest.raises(redis.OutOfMemoryError): + # note: don't use the DEBUG OOM command since it's not the same + # as the db being full + await r.execute_command("DEBUG", "ERROR", "OOM blah blah") + def test_connect_from_url_tcp(self): connection = redis.Redis.from_url("redis://localhost") pool = connection.connection_pool 
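The loosened timing assertions above exercise redis.asyncio.BlockingConnectionPool: once max_connections are checked out, get_connection() waits up to timeout seconds and then raises ConnectionError. A minimal usage sketch, assuming a reachable Redis on localhost (not part of the patch):

    import asyncio
    import redis.asyncio as redis

    async def main():
        pool = redis.BlockingConnectionPool(max_connections=1, timeout=0.1)
        c1 = await pool.get_connection("_")
        try:
            # The only connection is checked out, so this blocks for
            # roughly `timeout` seconds and then raises ConnectionError.
            await pool.get_connection("_")
        except redis.ConnectionError:
            pass
        finally:
            await pool.release(c1)
            await pool.disconnect()

    asyncio.run(main())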
@@ -658,6 +667,7 @@ async def r(self, create_redis, server): @pytest.mark.onlynoncluster +@pytest.mark.xfail(strict=False) class TestHealthCheck: interval = 60 diff --git a/tests/test_asyncio/test_credentials.py b/tests/test_asyncio/test_credentials.py index 8e213cdb26..4429f7453b 100644 --- a/tests/test_asyncio/test_credentials.py +++ b/tests/test_asyncio/test_credentials.py @@ -5,7 +5,6 @@ import pytest import pytest_asyncio - import redis from redis import AuthenticationError, DataError, ResponseError from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider diff --git a/tests/test_asyncio/test_cwe_404.py b/tests/test_asyncio/test_cwe_404.py new file mode 100644 index 0000000000..ff588861e4 --- /dev/null +++ b/tests/test_asyncio/test_cwe_404.py @@ -0,0 +1,250 @@ +import asyncio +import contextlib + +import pytest +from redis.asyncio import Redis +from redis.asyncio.cluster import RedisCluster +from redis.asyncio.connection import async_timeout + + +class DelayProxy: + def __init__(self, addr, redis_addr, delay: float = 0.0): + self.addr = addr + self.redis_addr = redis_addr + self.delay = delay + self.send_event = asyncio.Event() + self.server = None + self.task = None + + async def __aenter__(self): + await self.start() + return self + + async def __aexit__(self, *args): + await self.stop() + + async def start(self): + # test that we can connect to redis + async with async_timeout(2): + _, redis_writer = await asyncio.open_connection(*self.redis_addr) + redis_writer.close() + self.server = await asyncio.start_server( + self.handle, *self.addr, reuse_address=True + ) + self.task = asyncio.create_task(self.server.serve_forever()) + + @contextlib.contextmanager + def set_delay(self, delay: float = 0.0): + """ + Allow overriding the delay for parts of tests which aren't time dependent, + to speed up execution.
+ """ + old_delay = self.delay + self.delay = delay + try: + yield + finally: + self.delay = old_delay + + async def handle(self, reader, writer): + # establish connection to redis + redis_reader, redis_writer = await asyncio.open_connection(*self.redis_addr) + try: + pipe1 = asyncio.create_task( + self.pipe(reader, redis_writer, "to redis:", self.send_event) + ) + pipe2 = asyncio.create_task(self.pipe(redis_reader, writer, "from redis:")) + await asyncio.gather(pipe1, pipe2) + finally: + redis_writer.close() + + async def stop(self): + # clean up enough so that we can reuse the looper + self.task.cancel() + try: + await self.task + except asyncio.CancelledError: + pass + loop = self.server.get_loop() + await loop.shutdown_asyncgens() + + async def pipe( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + name="", + event: asyncio.Event = None, + ): + while True: + data = await reader.read(1000) + if not data: + break + # print(f"{name} read {len(data)} delay {self.delay}") + if event: + event.set() + await asyncio.sleep(self.delay) + writer.write(data) + await writer.drain() + + +@pytest.mark.onlynoncluster +@pytest.mark.parametrize("delay", argvalues=[0.05, 0.5, 1, 2]) +async def test_standalone(delay, master_host): + + # create a tcp socket proxy that relays data to Redis and back, + # inserting 0.1 seconds of delay + async with DelayProxy(addr=("127.0.0.1", 5380), redis_addr=master_host) as dp: + + for b in [True, False]: + # note that we connect to proxy, rather than to Redis directly + async with Redis( + host="127.0.0.1", port=5380, single_connection_client=b + ) as r: + + await r.set("foo", "foo") + await r.set("bar", "bar") + + async def op(r): + with dp.set_delay(delay * 2): + return await r.get( + "foo" + ) # <-- this is the operation we want to cancel + + dp.send_event.clear() + t = asyncio.create_task(op(r)) + # Wait until the task has sent, and then some, to make sure it has + # settled on the read. 
+ # Wait until the task has sent, and then some, to make sure it has + # settled on the read. + await dp.send_event.wait() + await asyncio.sleep(0.01) # a little extra time for prudence + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + # make sure that our previous request, cancelled while waiting for + # a response, didn't leave the connection open and in a bad state + assert await r.get("bar") == b"bar" + assert await r.ping() + assert await r.get("foo") == b"foo" + + +@pytest.mark.onlynoncluster +@pytest.mark.parametrize("delay", argvalues=[0.05, 0.5, 1, 2]) +async def test_standalone_pipeline(delay, master_host): + async with DelayProxy(addr=("127.0.0.1", 5380), redis_addr=master_host) as dp: + for b in [True, False]: + async with Redis( + host="127.0.0.1", port=5380, single_connection_client=b + ) as r: + await r.set("foo", "foo") + await r.set("bar", "bar") + + pipe = r.pipeline() + + pipe2 = r.pipeline() + pipe2.get("bar") + pipe2.ping() + pipe2.get("foo") + + async def op(pipe): + with dp.set_delay(delay * 2): + return await pipe.get( + "foo" + ).execute() # <-- this is the operation we want to cancel + + dp.send_event.clear() + t = asyncio.create_task(op(pipe)) + # wait until task has settled on the read + await dp.send_event.wait() + await asyncio.sleep(0.01) + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + # we have now cancelled the pipeline in the middle of a request, + # make sure that the connection is still usable + pipe.get("bar") + pipe.ping() + pipe.get("foo") + await pipe.reset() + + # check that the pipeline is empty after reset + assert await pipe.execute() == [] + + # validate that the pipeline can be used as before + pipe.get("bar") + pipe.ping() + pipe.get("foo") + assert await pipe.execute() == [b"bar", True, b"foo"] + assert await pipe2.execute() == [b"bar", True, b"foo"] + + +@pytest.mark.onlycluster +async def test_cluster(master_host): + + delay = 0.1 + cluster_port = 16379 + remap_base = 7372 + n_nodes = 6 + hostname, _ = master_host + + def remap(address): + host, port = address + return host, remap_base + port - cluster_port + + proxies = [] + for i in range(n_nodes): + port = cluster_port + i + remapped = remap_base + i + forward_addr = hostname, port + proxy = DelayProxy(addr=("127.0.0.1", remapped), redis_addr=forward_addr) + proxies.append(proxy) + + def all_clear(): + for p in proxies: + p.send_event.clear() + + async def wait_for_send(): + await asyncio.wait( + [asyncio.create_task(p.send_event.wait()) for p in proxies], + return_when=asyncio.FIRST_COMPLETED, + ) + + @contextlib.contextmanager + def set_delay(delay: float): + with contextlib.ExitStack() as stack: + for p in proxies: + stack.enter_context(p.set_delay(delay)) + yield + + async with contextlib.AsyncExitStack() as stack: + for p in proxies: + await stack.enter_async_context(p) + + with contextlib.closing( + RedisCluster.from_url( + f"redis://127.0.0.1:{remap_base}", address_remap=remap + ) + ) as r: + await r.initialize() + await r.set("foo", "foo") + await r.set("bar", "bar") + + async def op(r): + with set_delay(delay): + return await r.get("foo") + + all_clear() + t = asyncio.create_task(op(r)) + # Wait for whichever DelayProxy gets the request first + await wait_for_send() + await asyncio.sleep(0.01) + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + # try a number of requests to exercise all the connections + async def doit(): + assert await r.get("bar") == b"bar" + assert await r.ping() + assert await r.get("foo") == b"foo" + + await asyncio.gather(*[doit() for _ in range(10)]) diff --git
a/tests/test_asyncio/test_encoding.py b/tests/test_asyncio/test_encoding.py index 3efcf69e5b..162ccb367d 100644 --- a/tests/test_asyncio/test_encoding.py +++ b/tests/test_asyncio/test_encoding.py @@ -1,6 +1,5 @@ import pytest import pytest_asyncio - import redis.asyncio as redis from redis.exceptions import DataError @@ -90,6 +89,7 @@ async def r(self, create_redis): yield redis await redis.flushall() + @pytest.mark.xfail async def test_basic_command(self, r: redis.Redis): await r.set("hello", "world") diff --git a/tests/test_asyncio/test_graph.py b/tests/test_asyncio/test_graph.py index 7e70baae89..22195901e6 100644 --- a/tests/test_asyncio/test_graph.py +++ b/tests/test_asyncio/test_graph.py @@ -1,5 +1,4 @@ import pytest - import redis.asyncio as redis from redis.commands.graph import Edge, Node, Path from redis.commands.graph.execution_plan import Operation @@ -8,15 +7,15 @@ @pytest.mark.redismod -async def test_bulk(modclient): +async def test_bulk(decoded_r): with pytest.raises(NotImplementedError): - await modclient.graph().bulk() - await modclient.graph().bulk(foo="bar!") + await decoded_r.graph().bulk() + await decoded_r.graph().bulk(foo="bar!") @pytest.mark.redismod -async def test_graph_creation(modclient: redis.Redis): - graph = modclient.graph() +async def test_graph_creation(decoded_r: redis.Redis): + graph = decoded_r.graph() john = Node( label="person", @@ -60,8 +59,8 @@ async def test_graph_creation(modclient: redis.Redis): @pytest.mark.redismod -async def test_array_functions(modclient: redis.Redis): - graph = modclient.graph() +async def test_array_functions(decoded_r: redis.Redis): + graph = decoded_r.graph() query = """CREATE (p:person{name:'a',age:32, array:[0,1,2]})""" await graph.query(query) @@ -83,12 +82,12 @@ async def test_array_functions(modclient: redis.Redis): @pytest.mark.redismod -async def test_path(modclient: redis.Redis): +async def test_path(decoded_r: redis.Redis): node0 = Node(node_id=0, label="L1") node1 = Node(node_id=1, label="L1") edge01 = Edge(node0, "R1", node1, edge_id=0, properties={"value": 1}) - graph = modclient.graph() + graph = decoded_r.graph() graph.add_node(node0) graph.add_node(node1) graph.add_edge(edge01) @@ -103,20 +102,20 @@ async def test_path(modclient: redis.Redis): @pytest.mark.redismod -async def test_param(modclient: redis.Redis): +async def test_param(decoded_r: redis.Redis): params = [1, 2.3, "str", True, False, None, [0, 1, 2]] query = "RETURN $param" for param in params: - result = await modclient.graph().query(query, {"param": param}) + result = await decoded_r.graph().query(query, {"param": param}) expected_results = [[param]] assert expected_results == result.result_set @pytest.mark.redismod -async def test_map(modclient: redis.Redis): +async def test_map(decoded_r: redis.Redis): query = "RETURN {a:1, b:'str', c:NULL, d:[1,2,3], e:True, f:{x:1, y:2}}" - actual = (await modclient.graph().query(query)).result_set[0][0] + actual = (await decoded_r.graph().query(query)).result_set[0][0] expected = { "a": 1, "b": "str", @@ -130,40 +129,40 @@ async def test_map(modclient: redis.Redis): @pytest.mark.redismod -async def test_point(modclient: redis.Redis): +async def test_point(decoded_r: redis.Redis): query = "RETURN point({latitude: 32.070794860, longitude: 34.820751118})" expected_lat = 32.070794860 expected_lon = 34.820751118 - actual = (await modclient.graph().query(query)).result_set[0][0] + actual = (await decoded_r.graph().query(query)).result_set[0][0] assert abs(actual["latitude"] - expected_lat) < 0.001 assert 
abs(actual["longitude"] - expected_lon) < 0.001 query = "RETURN point({latitude: 32, longitude: 34.0})" expected_lat = 32 expected_lon = 34 - actual = (await modclient.graph().query(query)).result_set[0][0] + actual = (await decoded_r.graph().query(query)).result_set[0][0] assert abs(actual["latitude"] - expected_lat) < 0.001 assert abs(actual["longitude"] - expected_lon) < 0.001 @pytest.mark.redismod -async def test_index_response(modclient: redis.Redis): - result_set = await modclient.graph().query("CREATE INDEX ON :person(age)") +async def test_index_response(decoded_r: redis.Redis): + result_set = await decoded_r.graph().query("CREATE INDEX ON :person(age)") assert 1 == result_set.indices_created - result_set = await modclient.graph().query("CREATE INDEX ON :person(age)") + result_set = await decoded_r.graph().query("CREATE INDEX ON :person(age)") assert 0 == result_set.indices_created - result_set = await modclient.graph().query("DROP INDEX ON :person(age)") + result_set = await decoded_r.graph().query("DROP INDEX ON :person(age)") assert 1 == result_set.indices_deleted with pytest.raises(ResponseError): - await modclient.graph().query("DROP INDEX ON :person(age)") + await decoded_r.graph().query("DROP INDEX ON :person(age)") @pytest.mark.redismod -async def test_stringify_query_result(modclient: redis.Redis): - graph = modclient.graph() +async def test_stringify_query_result(decoded_r: redis.Redis): + graph = decoded_r.graph() john = Node( alias="a", @@ -216,14 +215,14 @@ async def test_stringify_query_result(modclient: redis.Redis): @pytest.mark.redismod -async def test_optional_match(modclient: redis.Redis): +async def test_optional_match(decoded_r: redis.Redis): # Build a graph of form (a)-[R]->(b) node0 = Node(node_id=0, label="L1", properties={"value": "a"}) node1 = Node(node_id=1, label="L1", properties={"value": "b"}) edge01 = Edge(node0, "R", node1, edge_id=0) - graph = modclient.graph() + graph = decoded_r.graph() graph.add_node(node0) graph.add_node(node1) graph.add_edge(edge01) @@ -241,17 +240,17 @@ async def test_optional_match(modclient: redis.Redis): @pytest.mark.redismod -async def test_cached_execution(modclient: redis.Redis): - await modclient.graph().query("CREATE ()") +async def test_cached_execution(decoded_r: redis.Redis): + await decoded_r.graph().query("CREATE ()") - uncached_result = await modclient.graph().query( + uncached_result = await decoded_r.graph().query( "MATCH (n) RETURN n, $param", {"param": [0]} ) assert uncached_result.cached_execution is False # loop to make sure the query is cached on each thread on server for x in range(0, 64): - cached_result = await modclient.graph().query( + cached_result = await decoded_r.graph().query( "MATCH (n) RETURN n, $param", {"param": [0]} ) assert uncached_result.result_set == cached_result.result_set @@ -261,50 +260,51 @@ async def test_cached_execution(modclient: redis.Redis): @pytest.mark.redismod -async def test_slowlog(modclient: redis.Redis): +async def test_slowlog(decoded_r: redis.Redis): create_query = """CREATE (:Rider {name:'Valentino Rossi'})-[:rides]->(:Team {name:'Yamaha'}), (:Rider {name:'Dani Pedrosa'})-[:rides]->(:Team {name:'Honda'}), (:Rider {name:'Andrea Dovizioso'})-[:rides]->(:Team {name:'Ducati'})""" - await modclient.graph().query(create_query) + await decoded_r.graph().query(create_query) - results = await modclient.graph().slowlog() + results = await decoded_r.graph().slowlog() assert results[0][1] == "GRAPH.QUERY" assert results[0][2] == create_query @pytest.mark.redismod -async def 
test_query_timeout(modclient: redis.Redis): +@pytest.mark.xfail(strict=False) +async def test_query_timeout(decoded_r: redis.Redis): # Build a sample graph with 1000 nodes. - await modclient.graph().query("UNWIND range(0,1000) as val CREATE ({v: val})") + await decoded_r.graph().query("UNWIND range(0,1000) as val CREATE ({v: val})") # Issue a long-running query with a 1-millisecond timeout. with pytest.raises(ResponseError): - await modclient.graph().query("MATCH (a), (b), (c), (d) RETURN *", timeout=1) + await decoded_r.graph().query("MATCH (a), (b), (c), (d) RETURN *", timeout=1) assert False is False with pytest.raises(Exception): - await modclient.graph().query("RETURN 1", timeout="str") + await decoded_r.graph().query("RETURN 1", timeout="str") assert False is False @pytest.mark.redismod -async def test_read_only_query(modclient: redis.Redis): +async def test_read_only_query(decoded_r: redis.Redis): with pytest.raises(Exception): # Issue a write query, specifying read-only true, # this call should fail. - await modclient.graph().query("CREATE (p:person {name:'a'})", read_only=True) + await decoded_r.graph().query("CREATE (p:person {name:'a'})", read_only=True) assert False is False @pytest.mark.redismod -async def test_profile(modclient: redis.Redis): +async def test_profile(decoded_r: redis.Redis): q = """UNWIND range(1, 3) AS x CREATE (p:Person {v:x})""" - profile = (await modclient.graph().profile(q)).result_set + profile = (await decoded_r.graph().profile(q)).result_set assert "Create | Records produced: 3" in profile assert "Unwind | Records produced: 3" in profile q = "MATCH (p:Person) WHERE p.v > 1 RETURN p" - profile = (await modclient.graph().profile(q)).result_set + profile = (await decoded_r.graph().profile(q)).result_set assert "Results | Records produced: 2" in profile assert "Project | Records produced: 2" in profile assert "Filter | Records produced: 2" in profile @@ -313,16 +313,16 @@ async def test_profile(modclient: redis.Redis): @pytest.mark.redismod @skip_if_redis_enterprise() -async def test_config(modclient: redis.Redis): +async def test_config(decoded_r: redis.Redis): config_name = "RESULTSET_SIZE" config_value = 3 # Set configuration - response = await modclient.graph().config(config_name, config_value, set=True) + response = await decoded_r.graph().config(config_name, config_value, set=True) assert response == "OK" # Make sure config been updated. - response = await modclient.graph().config(config_name, set=False) + response = await decoded_r.graph().config(config_name, set=False) expected_response = [config_name, config_value] assert response == expected_response @@ -330,46 +330,46 @@ async def test_config(modclient: redis.Redis): config_value = 1 << 20 # 1MB # Set configuration - response = await modclient.graph().config(config_name, config_value, set=True) + response = await decoded_r.graph().config(config_name, config_value, set=True) assert response == "OK" # Make sure config been updated. 
- response = await modclient.graph().config(config_name, set=False) + response = await decoded_r.graph().config(config_name, set=False) expected_response = [config_name, config_value] assert response == expected_response # reset to default - await modclient.graph().config("QUERY_MEM_CAPACITY", 0, set=True) - await modclient.graph().config("RESULTSET_SIZE", -100, set=True) + await decoded_r.graph().config("QUERY_MEM_CAPACITY", 0, set=True) + await decoded_r.graph().config("RESULTSET_SIZE", -100, set=True) @pytest.mark.redismod @pytest.mark.onlynoncluster -async def test_list_keys(modclient: redis.Redis): - result = await modclient.graph().list_keys() +async def test_list_keys(decoded_r: redis.Redis): + result = await decoded_r.graph().list_keys() assert result == [] - await modclient.graph("G").query("CREATE (n)") - result = await modclient.graph().list_keys() + await decoded_r.graph("G").query("CREATE (n)") + result = await decoded_r.graph().list_keys() assert result == ["G"] - await modclient.graph("X").query("CREATE (m)") - result = await modclient.graph().list_keys() + await decoded_r.graph("X").query("CREATE (m)") + result = await decoded_r.graph().list_keys() assert result == ["G", "X"] - await modclient.delete("G") - await modclient.rename("X", "Z") - result = await modclient.graph().list_keys() + await decoded_r.delete("G") + await decoded_r.rename("X", "Z") + result = await decoded_r.graph().list_keys() assert result == ["Z"] - await modclient.delete("Z") - result = await modclient.graph().list_keys() + await decoded_r.delete("Z") + result = await decoded_r.graph().list_keys() assert result == [] @pytest.mark.redismod -async def test_multi_label(modclient: redis.Redis): - redis_graph = modclient.graph("g") +async def test_multi_label(decoded_r: redis.Redis): + redis_graph = decoded_r.graph("g") node = Node(label=["l", "ll"]) redis_graph.add_node(node) @@ -394,8 +394,8 @@ async def test_multi_label(modclient: redis.Redis): @pytest.mark.redismod -async def test_execution_plan(modclient: redis.Redis): - redis_graph = modclient.graph("execution_plan") +async def test_execution_plan(decoded_r: redis.Redis): + redis_graph = decoded_r.graph("execution_plan") create_query = """CREATE (:Rider {name:'Valentino Rossi'})-[:rides]->(:Team {name:'Yamaha'}), (:Rider {name:'Dani Pedrosa'})-[:rides]->(:Team {name:'Honda'}), @@ -413,8 +413,8 @@ async def test_execution_plan(modclient: redis.Redis): @pytest.mark.redismod -async def test_explain(modclient: redis.Redis): - redis_graph = modclient.graph("execution_plan") +async def test_explain(decoded_r: redis.Redis): + redis_graph = decoded_r.graph("execution_plan") # graph creation / population create_query = """CREATE (:Rider {name:'Valentino Rossi'})-[:rides]->(:Team {name:'Yamaha'}), diff --git a/tests/test_asyncio/test_json.py b/tests/test_asyncio/test_json.py index b8854d20cd..ed651cd903 100644 --- a/tests/test_asyncio/test_json.py +++ b/tests/test_asyncio/test_json.py @@ -1,266 +1,339 @@ import pytest - import redis.asyncio as redis from redis import exceptions from redis.commands.json.path import Path -from tests.conftest import skip_ifmodversion_lt +from tests.conftest import assert_resp_response, skip_ifmodversion_lt @pytest.mark.redismod -async def test_json_setbinarykey(modclient: redis.Redis): +async def test_json_setbinarykey(decoded_r: redis.Redis): d = {"hello": "world", b"some": "value"} with pytest.raises(TypeError): - modclient.json().set("somekey", Path.root_path(), d) - assert await modclient.json().set("somekey", 
Path.root_path(), d, decode_keys=True) + decoded_r.json().set("somekey", Path.root_path(), d) + assert await decoded_r.json().set("somekey", Path.root_path(), d, decode_keys=True) @pytest.mark.redismod -async def test_json_setgetdeleteforget(modclient: redis.Redis): - assert await modclient.json().set("foo", Path.root_path(), "bar") - assert await modclient.json().get("foo") == "bar" - assert await modclient.json().get("baz") is None - assert await modclient.json().delete("foo") == 1 - assert await modclient.json().forget("foo") == 0 # second delete - assert await modclient.exists("foo") == 0 +async def test_json_setgetdeleteforget(decoded_r: redis.Redis): + assert await decoded_r.json().set("foo", Path.root_path(), "bar") + assert_resp_response(decoded_r, await decoded_r.json().get("foo"), "bar", [["bar"]]) + assert await decoded_r.json().get("baz") is None + assert await decoded_r.json().delete("foo") == 1 + assert await decoded_r.json().forget("foo") == 0 # second delete + assert await decoded_r.exists("foo") == 0 @pytest.mark.redismod -async def test_jsonget(modclient: redis.Redis): - await modclient.json().set("foo", Path.root_path(), "bar") - assert await modclient.json().get("foo") == "bar" +async def test_jsonget(decoded_r: redis.Redis): + await decoded_r.json().set("foo", Path.root_path(), "bar") + assert_resp_response(decoded_r, await decoded_r.json().get("foo"), "bar", [["bar"]]) @pytest.mark.redismod -async def test_json_get_jset(modclient: redis.Redis): - assert await modclient.json().set("foo", Path.root_path(), "bar") - assert "bar" == await modclient.json().get("foo") - assert await modclient.json().get("baz") is None - assert 1 == await modclient.json().delete("foo") - assert await modclient.exists("foo") == 0 +async def test_json_get_jset(decoded_r: redis.Redis): + assert await decoded_r.json().set("foo", Path.root_path(), "bar") + assert_resp_response(decoded_r, await decoded_r.json().get("foo"), "bar", [["bar"]]) + assert await decoded_r.json().get("baz") is None + assert 1 == await decoded_r.json().delete("foo") + assert await decoded_r.exists("foo") == 0 + + +@pytest.mark.redismod +async def test_nonascii_setgetdelete(decoded_r: redis.Redis): + assert await decoded_r.json().set("notascii", Path.root_path(), "hyvää-élève") + res = "hyvää-élève" + assert_resp_response( + decoded_r, await decoded_r.json().get("notascii", no_escape=True), res, [[res]] + ) + assert 1 == await decoded_r.json().delete("notascii") + assert await decoded_r.exists("notascii") == 0 @pytest.mark.redismod -async def test_nonascii_setgetdelete(modclient: redis.Redis): - assert await modclient.json().set("notascii", Path.root_path(), "hyvää-élève") - assert "hyvää-élève" == await modclient.json().get("notascii", no_escape=True) - assert 1 == await modclient.json().delete("notascii") - assert await modclient.exists("notascii") == 0 +@skip_ifmodversion_lt("2.6.0", "ReJSON") +async def test_json_merge(decoded_r: redis.Redis): + # Test with root path $ + assert await decoded_r.json().set( + "person_data", + "$", + {"person1": {"personal_data": {"name": "John"}}}, + ) + assert await decoded_r.json().merge( + "person_data", "$", {"person1": {"personal_data": {"hobbies": "reading"}}} + ) + assert await decoded_r.json().get("person_data") == { + "person1": {"personal_data": {"name": "John", "hobbies": "reading"}} + } + + # Test with nested path $.person1.personal_data + assert await decoded_r.json().merge( + "person_data", "$.person1.personal_data", {"country": "Israel"} + ) + assert await
decoded_r.json().get("person_data") == { + "person1": { + "personal_data": {"name": "John", "hobbies": "reading", "country": "Israel"} + } + } + + # Test with null value to delete a value + assert await decoded_r.json().merge( + "person_data", "$.person1.personal_data", {"name": None} + ) + assert await decoded_r.json().get("person_data") == { + "person1": {"personal_data": {"country": "Israel", "hobbies": "reading"}} + } @pytest.mark.redismod -async def test_jsonsetexistentialmodifiersshouldsucceed(modclient: redis.Redis): +async def test_jsonsetexistentialmodifiersshouldsucceed(decoded_r: redis.Redis): obj = {"foo": "bar"} - assert await modclient.json().set("obj", Path.root_path(), obj) + assert await decoded_r.json().set("obj", Path.root_path(), obj) # Test that flags prevent updates when conditions are unmet - assert await modclient.json().set("obj", Path("foo"), "baz", nx=True) is None - assert await modclient.json().set("obj", Path("qaz"), "baz", xx=True) is None + assert await decoded_r.json().set("obj", Path("foo"), "baz", nx=True) is None + assert await decoded_r.json().set("obj", Path("qaz"), "baz", xx=True) is None # Test that flags allow updates when conditions are met - assert await modclient.json().set("obj", Path("foo"), "baz", xx=True) - assert await modclient.json().set("obj", Path("qaz"), "baz", nx=True) + assert await decoded_r.json().set("obj", Path("foo"), "baz", xx=True) + assert await decoded_r.json().set("obj", Path("qaz"), "baz", nx=True) # Test that flags are mutually exlusive with pytest.raises(Exception): - await modclient.json().set("obj", Path("foo"), "baz", nx=True, xx=True) + await decoded_r.json().set("obj", Path("foo"), "baz", nx=True, xx=True) + + +@pytest.mark.redismod +async def test_mgetshouldsucceed(decoded_r: redis.Redis): + await decoded_r.json().set("1", Path.root_path(), 1) + await decoded_r.json().set("2", Path.root_path(), 2) + assert await decoded_r.json().mget(["1"], Path.root_path()) == [1] + + assert await decoded_r.json().mget([1, 2], Path.root_path()) == [1, 2] @pytest.mark.redismod -async def test_mgetshouldsucceed(modclient: redis.Redis): - await modclient.json().set("1", Path.root_path(), 1) - await modclient.json().set("2", Path.root_path(), 2) - assert await modclient.json().mget(["1"], Path.root_path()) == [1] +@pytest.mark.onlynoncluster +@skip_ifmodversion_lt("2.6.0", "ReJSON") +async def test_mset(decoded_r: redis.Redis): + await decoded_r.json().mset( + [("1", Path.root_path(), 1), ("2", Path.root_path(), 2)] + ) - assert await modclient.json().mget([1, 2], Path.root_path()) == [1, 2] + assert await decoded_r.json().mget(["1"], Path.root_path()) == [1] + assert await decoded_r.json().mget(["1", "2"], Path.root_path()) == [1, 2] @pytest.mark.redismod @skip_ifmodversion_lt("99.99.99", "ReJSON") # todo: update after the release -async def test_clear(modclient: redis.Redis): - await modclient.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) - assert 1 == await modclient.json().clear("arr", Path.root_path()) - assert [] == await modclient.json().get("arr") +async def test_clear(decoded_r: redis.Redis): + await decoded_r.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) + assert 1 == await decoded_r.json().clear("arr", Path.root_path()) + assert_resp_response(decoded_r, await decoded_r.json().get("arr"), [], [[[]]]) @pytest.mark.redismod -async def test_type(modclient: redis.Redis): - await modclient.json().set("1", Path.root_path(), 1) - assert "integer" == await modclient.json().type("1", Path.root_path()) - assert "integer" 
== await modclient.json().type("1") +async def test_type(decoded_r: redis.Redis): + await decoded_r.json().set("1", Path.root_path(), 1) + assert_resp_response( + decoded_r, + await decoded_r.json().type("1", Path.root_path()), + "integer", + ["integer"], + ) + assert_resp_response( + decoded_r, await decoded_r.json().type("1"), "integer", ["integer"] + ) @pytest.mark.redismod -async def test_numincrby(modclient): - await modclient.json().set("num", Path.root_path(), 1) - assert 2 == await modclient.json().numincrby("num", Path.root_path(), 1) - assert 2.5 == await modclient.json().numincrby("num", Path.root_path(), 0.5) - assert 1.25 == await modclient.json().numincrby("num", Path.root_path(), -1.25) +async def test_numincrby(decoded_r): + await decoded_r.json().set("num", Path.root_path(), 1) + assert_resp_response( + decoded_r, await decoded_r.json().numincrby("num", Path.root_path(), 1), 2, [2] + ) + res = await decoded_r.json().numincrby("num", Path.root_path(), 0.5) + assert_resp_response(decoded_r, res, 2.5, [2.5]) + res = await decoded_r.json().numincrby("num", Path.root_path(), -1.25) + assert_resp_response(decoded_r, res, 1.25, [1.25]) @pytest.mark.redismod -async def test_nummultby(modclient: redis.Redis): - await modclient.json().set("num", Path.root_path(), 1) +async def test_nummultby(decoded_r: redis.Redis): + await decoded_r.json().set("num", Path.root_path(), 1) with pytest.deprecated_call(): - assert 2 == await modclient.json().nummultby("num", Path.root_path(), 2) - assert 5 == await modclient.json().nummultby("num", Path.root_path(), 2.5) - assert 2.5 == await modclient.json().nummultby("num", Path.root_path(), 0.5) + res = await decoded_r.json().nummultby("num", Path.root_path(), 2) + assert_resp_response(decoded_r, res, 2, [2]) + res = await decoded_r.json().nummultby("num", Path.root_path(), 2.5) + assert_resp_response(decoded_r, res, 5, [5]) + res = await decoded_r.json().nummultby("num", Path.root_path(), 0.5) + assert_resp_response(decoded_r, res, 2.5, [2.5]) @pytest.mark.redismod @skip_ifmodversion_lt("99.99.99", "ReJSON") # todo: update after the release -async def test_toggle(modclient: redis.Redis): - await modclient.json().set("bool", Path.root_path(), False) - assert await modclient.json().toggle("bool", Path.root_path()) - assert await modclient.json().toggle("bool", Path.root_path()) is False +async def test_toggle(decoded_r: redis.Redis): + await decoded_r.json().set("bool", Path.root_path(), False) + assert await decoded_r.json().toggle("bool", Path.root_path()) + assert await decoded_r.json().toggle("bool", Path.root_path()) is False # check non-boolean value - await modclient.json().set("num", Path.root_path(), 1) + await decoded_r.json().set("num", Path.root_path(), 1) with pytest.raises(exceptions.ResponseError): - await modclient.json().toggle("num", Path.root_path()) + await decoded_r.json().toggle("num", Path.root_path()) @pytest.mark.redismod -async def test_strappend(modclient: redis.Redis): - await modclient.json().set("jsonkey", Path.root_path(), "foo") - assert 6 == await modclient.json().strappend("jsonkey", "bar") - assert "foobar" == await modclient.json().get("jsonkey", Path.root_path()) +async def test_strappend(decoded_r: redis.Redis): + await decoded_r.json().set("jsonkey", Path.root_path(), "foo") + assert 6 == await decoded_r.json().strappend("jsonkey", "bar") + res = await decoded_r.json().get("jsonkey", Path.root_path()) + assert_resp_response(decoded_r, res, "foobar", [["foobar"]]) @pytest.mark.redismod -async def 
test_strlen(modclient: redis.Redis): - await modclient.json().set("str", Path.root_path(), "foo") - assert 3 == await modclient.json().strlen("str", Path.root_path()) - await modclient.json().strappend("str", "bar", Path.root_path()) - assert 6 == await modclient.json().strlen("str", Path.root_path()) - assert 6 == await modclient.json().strlen("str") +async def test_strlen(decoded_r: redis.Redis): + await decoded_r.json().set("str", Path.root_path(), "foo") + assert 3 == await decoded_r.json().strlen("str", Path.root_path()) + await decoded_r.json().strappend("str", "bar", Path.root_path()) + assert 6 == await decoded_r.json().strlen("str", Path.root_path()) + assert 6 == await decoded_r.json().strlen("str") @pytest.mark.redismod -async def test_arrappend(modclient: redis.Redis): - await modclient.json().set("arr", Path.root_path(), [1]) - assert 2 == await modclient.json().arrappend("arr", Path.root_path(), 2) - assert 4 == await modclient.json().arrappend("arr", Path.root_path(), 3, 4) - assert 7 == await modclient.json().arrappend("arr", Path.root_path(), *[5, 6, 7]) +async def test_arrappend(decoded_r: redis.Redis): + await decoded_r.json().set("arr", Path.root_path(), [1]) + assert 2 == await decoded_r.json().arrappend("arr", Path.root_path(), 2) + assert 4 == await decoded_r.json().arrappend("arr", Path.root_path(), 3, 4) + assert 7 == await decoded_r.json().arrappend("arr", Path.root_path(), *[5, 6, 7]) @pytest.mark.redismod -async def test_arrindex(modclient: redis.Redis): - await modclient.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) - assert 1 == await modclient.json().arrindex("arr", Path.root_path(), 1) - assert -1 == await modclient.json().arrindex("arr", Path.root_path(), 1, 2) +async def test_arrindex(decoded_r: redis.Redis): + r_path = Path.root_path() + await decoded_r.json().set("arr", r_path, [0, 1, 2, 3, 4]) + assert 1 == await decoded_r.json().arrindex("arr", r_path, 1) + assert -1 == await decoded_r.json().arrindex("arr", r_path, 1, 2) + assert 4 == await decoded_r.json().arrindex("arr", r_path, 4) + assert 4 == await decoded_r.json().arrindex("arr", r_path, 4, start=0) + assert 4 == await decoded_r.json().arrindex("arr", r_path, 4, start=0, stop=5000) + assert -1 == await decoded_r.json().arrindex("arr", r_path, 4, start=0, stop=-1) + assert -1 == await decoded_r.json().arrindex("arr", r_path, 4, start=1, stop=3) @pytest.mark.redismod -async def test_arrinsert(modclient: redis.Redis): - await modclient.json().set("arr", Path.root_path(), [0, 4]) - assert 5 - -await modclient.json().arrinsert("arr", Path.root_path(), 1, *[1, 2, 3]) - assert [0, 1, 2, 3, 4] == await modclient.json().get("arr") +async def test_arrinsert(decoded_r: redis.Redis): + await decoded_r.json().set("arr", Path.root_path(), [0, 4]) + assert 5 == await decoded_r.json().arrinsert("arr", Path.root_path(), 1, *[1, 2, 3]) + res = [0, 1, 2, 3, 4] + assert_resp_response(decoded_r, await decoded_r.json().get("arr"), res, [[res]]) # test prepends - await modclient.json().set("val2", Path.root_path(), [5, 6, 7, 8, 9]) - await modclient.json().arrinsert("val2", Path.root_path(), 0, ["some", "thing"]) - assert await modclient.json().get("val2") == [["some", "thing"], 5, 6, 7, 8, 9] + await decoded_r.json().set("val2", Path.root_path(), [5, 6, 7, 8, 9]) + await decoded_r.json().arrinsert("val2", Path.root_path(), 0, ["some", "thing"]) + res = [["some", "thing"], 5, 6, 7, 8, 9] + assert_resp_response(decoded_r, await decoded_r.json().get("val2"), res, [[res]]) @pytest.mark.redismod -async def 
test_arrlen(modclient: redis.Redis): - await modclient.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) - assert 5 == await modclient.json().arrlen("arr", Path.root_path()) - assert 5 == await modclient.json().arrlen("arr") - assert await modclient.json().arrlen("fakekey") is None +async def test_arrlen(decoded_r: redis.Redis): + await decoded_r.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) + assert 5 == await decoded_r.json().arrlen("arr", Path.root_path()) + assert 5 == await decoded_r.json().arrlen("arr") + assert await decoded_r.json().arrlen("fakekey") is None @pytest.mark.redismod -async def test_arrpop(modclient: redis.Redis): - await modclient.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) - assert 4 == await modclient.json().arrpop("arr", Path.root_path(), 4) - assert 3 == await modclient.json().arrpop("arr", Path.root_path(), -1) - assert 2 == await modclient.json().arrpop("arr", Path.root_path()) - assert 0 == await modclient.json().arrpop("arr", Path.root_path(), 0) - assert [1] == await modclient.json().get("arr") +async def test_arrpop(decoded_r: redis.Redis): + await decoded_r.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) + assert 4 == await decoded_r.json().arrpop("arr", Path.root_path(), 4) + assert 3 == await decoded_r.json().arrpop("arr", Path.root_path(), -1) + assert 2 == await decoded_r.json().arrpop("arr", Path.root_path()) + assert 0 == await decoded_r.json().arrpop("arr", Path.root_path(), 0) + assert_resp_response(decoded_r, await decoded_r.json().get("arr"), [1], [[[1]]]) # test out of bounds - await modclient.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) - assert 4 == await modclient.json().arrpop("arr", Path.root_path(), 99) + await decoded_r.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) + assert 4 == await decoded_r.json().arrpop("arr", Path.root_path(), 99) # none test - await modclient.json().set("arr", Path.root_path(), []) - assert await modclient.json().arrpop("arr") is None + await decoded_r.json().set("arr", Path.root_path(), []) + assert await decoded_r.json().arrpop("arr") is None @pytest.mark.redismod -async def test_arrtrim(modclient: redis.Redis): - await modclient.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) - assert 3 == await modclient.json().arrtrim("arr", Path.root_path(), 1, 3) - assert [1, 2, 3] == await modclient.json().get("arr") +async def test_arrtrim(decoded_r: redis.Redis): + await decoded_r.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) + assert 3 == await decoded_r.json().arrtrim("arr", Path.root_path(), 1, 3) + res = await decoded_r.json().get("arr") + assert_resp_response(decoded_r, res, [1, 2, 3], [[[1, 2, 3]]]) # <0 test, should be 0 equivalent - await modclient.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) - assert 0 == await modclient.json().arrtrim("arr", Path.root_path(), -1, 3) + await decoded_r.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) + assert 0 == await decoded_r.json().arrtrim("arr", Path.root_path(), -1, 3) # testing stop > end - await modclient.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) - assert 2 == await modclient.json().arrtrim("arr", Path.root_path(), 3, 99) + await decoded_r.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) + assert 2 == await decoded_r.json().arrtrim("arr", Path.root_path(), 3, 99) # start > array size and stop - await modclient.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) - assert 0 == await modclient.json().arrtrim("arr", Path.root_path(), 9, 1) + await decoded_r.json().set("arr", Path.root_path(), [0, 1, 
2, 3, 4]) + assert 0 == await decoded_r.json().arrtrim("arr", Path.root_path(), 9, 1) # all larger - await modclient.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) - assert 0 == await modclient.json().arrtrim("arr", Path.root_path(), 9, 11) + await decoded_r.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) + assert 0 == await decoded_r.json().arrtrim("arr", Path.root_path(), 9, 11) @pytest.mark.redismod -async def test_resp(modclient: redis.Redis): +async def test_resp(decoded_r: redis.Redis): obj = {"foo": "bar", "baz": 1, "qaz": True} - await modclient.json().set("obj", Path.root_path(), obj) - assert "bar" == await modclient.json().resp("obj", Path("foo")) - assert 1 == await modclient.json().resp("obj", Path("baz")) - assert await modclient.json().resp("obj", Path("qaz")) - assert isinstance(await modclient.json().resp("obj"), list) + await decoded_r.json().set("obj", Path.root_path(), obj) + assert "bar" == await decoded_r.json().resp("obj", Path("foo")) + assert 1 == await decoded_r.json().resp("obj", Path("baz")) + assert await decoded_r.json().resp("obj", Path("qaz")) + assert isinstance(await decoded_r.json().resp("obj"), list) @pytest.mark.redismod -async def test_objkeys(modclient: redis.Redis): +async def test_objkeys(decoded_r: redis.Redis): obj = {"foo": "bar", "baz": "qaz"} - await modclient.json().set("obj", Path.root_path(), obj) - keys = await modclient.json().objkeys("obj", Path.root_path()) + await decoded_r.json().set("obj", Path.root_path(), obj) + keys = await decoded_r.json().objkeys("obj", Path.root_path()) keys.sort() exp = list(obj.keys()) exp.sort() assert exp == keys - await modclient.json().set("obj", Path.root_path(), obj) - keys = await modclient.json().objkeys("obj") + await decoded_r.json().set("obj", Path.root_path(), obj) + keys = await decoded_r.json().objkeys("obj") assert keys == list(obj.keys()) - assert await modclient.json().objkeys("fakekey") is None + assert await decoded_r.json().objkeys("fakekey") is None @pytest.mark.redismod -async def test_objlen(modclient: redis.Redis): +async def test_objlen(decoded_r: redis.Redis): obj = {"foo": "bar", "baz": "qaz"} - await modclient.json().set("obj", Path.root_path(), obj) - assert len(obj) == await modclient.json().objlen("obj", Path.root_path()) + await decoded_r.json().set("obj", Path.root_path(), obj) + assert len(obj) == await decoded_r.json().objlen("obj", Path.root_path()) - await modclient.json().set("obj", Path.root_path(), obj) - assert len(obj) == await modclient.json().objlen("obj") + await decoded_r.json().set("obj", Path.root_path(), obj) + assert len(obj) == await decoded_r.json().objlen("obj") # @pytest.mark.redismod -# async def test_json_commands_in_pipeline(modclient: redis.Redis): -# async with modclient.json().pipeline() as p: +# async def test_json_commands_in_pipeline(decoded_r: redis.Redis): +# async with decoded_r.json().pipeline() as p: # p.set("foo", Path.root_path(), "bar") # p.get("foo") # p.delete("foo") # assert [True, "bar", 1] == await p.execute() -# assert await modclient.keys() == [] -# assert await modclient.get("foo") is None +# assert await decoded_r.keys() == [] +# assert await decoded_r.get("foo") is None # # now with a true, json object -# await modclient.flushdb() -# p = await modclient.json().pipeline() +# await decoded_r.flushdb() +# p = await decoded_r.json().pipeline() # d = {"hello": "world", "oh": "snap"} # with pytest.deprecated_call(): # p.jsonset("foo", Path.root_path(), d) @@ -268,23 +341,24 @@ async def test_objlen(modclient: redis.Redis): # 
p.exists("notarealkey") # p.delete("foo") # assert [True, d, 0, 1] == p.execute() -# assert await modclient.keys() == [] -# assert await modclient.get("foo") is None +# assert await decoded_r.keys() == [] +# assert await decoded_r.get("foo") is None @pytest.mark.redismod -async def test_json_delete_with_dollar(modclient: redis.Redis): +async def test_json_delete_with_dollar(decoded_r: redis.Redis): doc1 = {"a": 1, "nested": {"a": 2, "b": 3}} - assert await modclient.json().set("doc1", "$", doc1) - assert await modclient.json().delete("doc1", "$..a") == 2 - r = await modclient.json().get("doc1", "$") - assert r == [{"nested": {"b": 3}}] + assert await decoded_r.json().set("doc1", "$", doc1) + assert await decoded_r.json().delete("doc1", "$..a") == 2 + res = [{"nested": {"b": 3}}] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) doc2 = {"a": {"a": 2, "b": 3}, "b": ["a", "b"], "nested": {"b": [True, "a", "b"]}} - assert await modclient.json().set("doc2", "$", doc2) - assert await modclient.json().delete("doc2", "$..a") == 1 - res = await modclient.json().get("doc2", "$") - assert res == [{"nested": {"b": [True, "a", "b"]}, "b": ["a", "b"]}] + assert await decoded_r.json().set("doc2", "$", doc2) + assert await decoded_r.json().delete("doc2", "$..a") == 1 + res = await decoded_r.json().get("doc2", "$") + res = [{"nested": {"b": [True, "a", "b"]}, "b": ["a", "b"]}] + assert_resp_response(decoded_r, await decoded_r.json().get("doc2", "$"), res, [res]) doc3 = [ { @@ -298,8 +372,8 @@ async def test_json_delete_with_dollar(modclient: redis.Redis): ], } ] - assert await modclient.json().set("doc3", "$", doc3) - assert await modclient.json().delete("doc3", '$.[0]["nested"]..ciao') == 3 + assert await decoded_r.json().set("doc3", "$", doc3) + assert await decoded_r.json().delete("doc3", '$.[0]["nested"]..ciao') == 3 doc3val = [ [ @@ -315,29 +389,29 @@ async def test_json_delete_with_dollar(modclient: redis.Redis): } ] ] - res = await modclient.json().get("doc3", "$") - assert res == doc3val + res = await decoded_r.json().get("doc3", "$") + assert_resp_response(decoded_r, res, doc3val, [doc3val]) # Test async default path - assert await modclient.json().delete("doc3") == 1 - assert await modclient.json().get("doc3", "$") is None + assert await decoded_r.json().delete("doc3") == 1 + assert await decoded_r.json().get("doc3", "$") is None - await modclient.json().delete("not_a_document", "..a") + await decoded_r.json().delete("not_a_document", "..a") @pytest.mark.redismod -async def test_json_forget_with_dollar(modclient: redis.Redis): +async def test_json_forget_with_dollar(decoded_r: redis.Redis): doc1 = {"a": 1, "nested": {"a": 2, "b": 3}} - assert await modclient.json().set("doc1", "$", doc1) - assert await modclient.json().forget("doc1", "$..a") == 2 - r = await modclient.json().get("doc1", "$") - assert r == [{"nested": {"b": 3}}] + assert await decoded_r.json().set("doc1", "$", doc1) + assert await decoded_r.json().forget("doc1", "$..a") == 2 + res = [{"nested": {"b": 3}}] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) doc2 = {"a": {"a": 2, "b": 3}, "b": ["a", "b"], "nested": {"b": [True, "a", "b"]}} - assert await modclient.json().set("doc2", "$", doc2) - assert await modclient.json().forget("doc2", "$..a") == 1 - res = await modclient.json().get("doc2", "$") - assert res == [{"nested": {"b": [True, "a", "b"]}, "b": ["a", "b"]}] + assert await decoded_r.json().set("doc2", "$", doc2) + assert await 
decoded_r.json().forget("doc2", "$..a") == 1 + res = [{"nested": {"b": [True, "a", "b"]}, "b": ["a", "b"]}] + assert_resp_response(decoded_r, await decoded_r.json().get("doc2", "$"), res, [res]) doc3 = [ { @@ -351,8 +425,8 @@ async def test_json_forget_with_dollar(modclient: redis.Redis): ], } ] - assert await modclient.json().set("doc3", "$", doc3) - assert await modclient.json().forget("doc3", '$.[0]["nested"]..ciao') == 3 + assert await decoded_r.json().set("doc3", "$", doc3) + assert await decoded_r.json().forget("doc3", '$.[0]["nested"]..ciao') == 3 doc3val = [ [ @@ -368,161 +442,165 @@ async def test_json_forget_with_dollar(modclient: redis.Redis): } ] ] - res = await modclient.json().get("doc3", "$") - assert res == doc3val + res = await decoded_r.json().get("doc3", "$") + assert_resp_response(decoded_r, res, doc3val, [doc3val]) # Test async default path - assert await modclient.json().forget("doc3") == 1 - assert await modclient.json().get("doc3", "$") is None + assert await decoded_r.json().forget("doc3") == 1 + assert await decoded_r.json().get("doc3", "$") is None - await modclient.json().forget("not_a_document", "..a") + await decoded_r.json().forget("not_a_document", "..a") @pytest.mark.redismod -async def test_json_mget_dollar(modclient: redis.Redis): +async def test_json_mget_dollar(decoded_r: redis.Redis): # Test mget with multi paths - await modclient.json().set( + await decoded_r.json().set( "doc1", "$", {"a": 1, "b": 2, "nested": {"a": 3}, "c": None, "nested2": {"a": None}}, ) - await modclient.json().set( + await decoded_r.json().set( "doc2", "$", {"a": 4, "b": 5, "nested": {"a": 6}, "c": None, "nested2": {"a": [None]}}, ) # Compare also to single JSON.GET - assert await modclient.json().get("doc1", "$..a") == [1, 3, None] - assert await modclient.json().get("doc2", "$..a") == [4, 6, [None]] + res = [1, 3, None] + assert_resp_response( + decoded_r, await decoded_r.json().get("doc1", "$..a"), res, [res] + ) + res = [4, 6, [None]] + assert_resp_response( + decoded_r, await decoded_r.json().get("doc2", "$..a"), res, [res] + ) # Test mget with single path - await modclient.json().mget("doc1", "$..a") == [1, 3, None] + await decoded_r.json().mget("doc1", "$..a") == [1, 3, None] # Test mget with multi path - res = await modclient.json().mget(["doc1", "doc2"], "$..a") + res = await decoded_r.json().mget(["doc1", "doc2"], "$..a") assert res == [[1, 3, None], [4, 6, [None]]] # Test missing key - res = await modclient.json().mget(["doc1", "missing_doc"], "$..a") + res = await decoded_r.json().mget(["doc1", "missing_doc"], "$..a") assert res == [[1, 3, None], None] - res = await modclient.json().mget(["missing_doc1", "missing_doc2"], "$..a") + res = await decoded_r.json().mget(["missing_doc1", "missing_doc2"], "$..a") assert res == [None, None] @pytest.mark.redismod -async def test_numby_commands_dollar(modclient: redis.Redis): +async def test_numby_commands_dollar(decoded_r: redis.Redis): # Test NUMINCRBY - await modclient.json().set( + await decoded_r.json().set( "doc1", "$", {"a": "b", "b": [{"a": 2}, {"a": 5.0}, {"a": "c"}]} ) # Test multi - assert await modclient.json().numincrby("doc1", "$..a", 2) == [None, 4, 7.0, None] + assert await decoded_r.json().numincrby("doc1", "$..a", 2) == [None, 4, 7.0, None] - res = await modclient.json().numincrby("doc1", "$..a", 2.5) + res = await decoded_r.json().numincrby("doc1", "$..a", 2.5) assert res == [None, 6.5, 9.5, None] # Test single - assert await modclient.json().numincrby("doc1", "$.b[1].a", 2) == [11.5] + assert await 
decoded_r.json().numincrby("doc1", "$.b[1].a", 2) == [11.5] - assert await modclient.json().numincrby("doc1", "$.b[2].a", 2) == [None] - assert await modclient.json().numincrby("doc1", "$.b[1].a", 3.5) == [15.0] + assert await decoded_r.json().numincrby("doc1", "$.b[2].a", 2) == [None] + assert await decoded_r.json().numincrby("doc1", "$.b[1].a", 3.5) == [15.0] # Test NUMMULTBY - await modclient.json().set( + await decoded_r.json().set( "doc1", "$", {"a": "b", "b": [{"a": 2}, {"a": 5.0}, {"a": "c"}]} ) # test list with pytest.deprecated_call(): - res = await modclient.json().nummultby("doc1", "$..a", 2) + res = await decoded_r.json().nummultby("doc1", "$..a", 2) assert res == [None, 4, 10, None] - res = await modclient.json().nummultby("doc1", "$..a", 2.5) + res = await decoded_r.json().nummultby("doc1", "$..a", 2.5) assert res == [None, 10.0, 25.0, None] # Test single with pytest.deprecated_call(): - assert await modclient.json().nummultby("doc1", "$.b[1].a", 2) == [50.0] - assert await modclient.json().nummultby("doc1", "$.b[2].a", 2) == [None] - assert await modclient.json().nummultby("doc1", "$.b[1].a", 3) == [150.0] + assert await decoded_r.json().nummultby("doc1", "$.b[1].a", 2) == [50.0] + assert await decoded_r.json().nummultby("doc1", "$.b[2].a", 2) == [None] + assert await decoded_r.json().nummultby("doc1", "$.b[1].a", 3) == [150.0] # test missing keys with pytest.raises(exceptions.ResponseError): - await modclient.json().numincrby("non_existing_doc", "$..a", 2) - await modclient.json().nummultby("non_existing_doc", "$..a", 2) + await decoded_r.json().numincrby("non_existing_doc", "$..a", 2) + await decoded_r.json().nummultby("non_existing_doc", "$..a", 2) # Test legacy NUMINCRBY - await modclient.json().set( + await decoded_r.json().set( "doc1", "$", {"a": "b", "b": [{"a": 2}, {"a": 5.0}, {"a": "c"}]} ) - await modclient.json().numincrby("doc1", ".b[0].a", 3) == 5 + await decoded_r.json().numincrby("doc1", ".b[0].a", 3) == 5 # Test legacy NUMMULTBY - await modclient.json().set( + await decoded_r.json().set( "doc1", "$", {"a": "b", "b": [{"a": 2}, {"a": 5.0}, {"a": "c"}]} ) with pytest.deprecated_call(): - await modclient.json().nummultby("doc1", ".b[0].a", 3) == 6 + await decoded_r.json().nummultby("doc1", ".b[0].a", 3) == 6 @pytest.mark.redismod -async def test_strappend_dollar(modclient: redis.Redis): +async def test_strappend_dollar(decoded_r: redis.Redis): - await modclient.json().set( + await decoded_r.json().set( "doc1", "$", {"a": "foo", "nested1": {"a": "hello"}, "nested2": {"a": 31}} ) # Test multi - await modclient.json().strappend("doc1", "bar", "$..a") == [6, 8, None] + await decoded_r.json().strappend("doc1", "bar", "$..a") == [6, 8, None] + + res = [{"a": "foobar", "nested1": {"a": "hellobar"}, "nested2": {"a": 31}}] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) - await modclient.json().get("doc1", "$") == [ - {"a": "foobar", "nested1": {"a": "hellobar"}, "nested2": {"a": 31}} - ] # Test single - await modclient.json().strappend("doc1", "baz", "$.nested1.a") == [11] + await decoded_r.json().strappend("doc1", "baz", "$.nested1.a") == [11] - await modclient.json().get("doc1", "$") == [ - {"a": "foobar", "nested1": {"a": "hellobarbaz"}, "nested2": {"a": 31}} - ] + res = [{"a": "foobar", "nested1": {"a": "hellobarbaz"}, "nested2": {"a": 31}}] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): - await 
modclient.json().strappend("non_existing_doc", "$..a", "err") + await decoded_r.json().strappend("non_existing_doc", "$..a", "err") # Test multi - await modclient.json().strappend("doc1", "bar", ".*.a") == 8 - await modclient.json().get("doc1", "$") == [ - {"a": "foo", "nested1": {"a": "hellobar"}, "nested2": {"a": 31}} - ] + await decoded_r.json().strappend("doc1", "bar", ".*.a") == 8 + res = [{"a": "foobar", "nested1": {"a": "hellobarbazbar"}, "nested2": {"a": 31}}] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) # Test missing path with pytest.raises(exceptions.ResponseError): - await modclient.json().strappend("doc1", "piu") + await decoded_r.json().strappend("doc1", "piu") @pytest.mark.redismod -async def test_strlen_dollar(modclient: redis.Redis): +async def test_strlen_dollar(decoded_r: redis.Redis): # Test multi - await modclient.json().set( + await decoded_r.json().set( "doc1", "$", {"a": "foo", "nested1": {"a": "hello"}, "nested2": {"a": 31}} ) - assert await modclient.json().strlen("doc1", "$..a") == [3, 5, None] + assert await decoded_r.json().strlen("doc1", "$..a") == [3, 5, None] - res2 = await modclient.json().strappend("doc1", "bar", "$..a") - res1 = await modclient.json().strlen("doc1", "$..a") + res2 = await decoded_r.json().strappend("doc1", "bar", "$..a") + res1 = await decoded_r.json().strlen("doc1", "$..a") assert res1 == res2 # Test single - await modclient.json().strlen("doc1", "$.nested1.a") == [8] - await modclient.json().strlen("doc1", "$.nested2.a") == [None] + await decoded_r.json().strlen("doc1", "$.nested1.a") == [8] + await decoded_r.json().strlen("doc1", "$.nested2.a") == [None] # Test missing key with pytest.raises(exceptions.ResponseError): - await modclient.json().strlen("non_existing_doc", "$..a") + await decoded_r.json().strlen("non_existing_doc", "$..a") @pytest.mark.redismod -async def test_arrappend_dollar(modclient: redis.Redis): - await modclient.json().set( +async def test_arrappend_dollar(decoded_r: redis.Redis): + await decoded_r.json().set( "doc1", "$", { @@ -532,31 +610,33 @@ async def test_arrappend_dollar(modclient: redis.Redis): }, ) # Test multi - await modclient.json().arrappend("doc1", "$..a", "bar", "racuda") == [3, 5, None] - assert await modclient.json().get("doc1", "$") == [ + await decoded_r.json().arrappend("doc1", "$..a", "bar", "racuda") == [3, 5, None] + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", None, "world", "bar", "racuda"]}, "nested2": {"a": 31}, } ] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) # Test single - assert await modclient.json().arrappend("doc1", "$.nested1.a", "baz") == [6] - assert await modclient.json().get("doc1", "$") == [ + assert await decoded_r.json().arrappend("doc1", "$.nested1.a", "baz") == [6] + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", None, "world", "bar", "racuda", "baz"]}, "nested2": {"a": 31}, } ] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): - await modclient.json().arrappend("non_existing_doc", "$..a") + await decoded_r.json().arrappend("non_existing_doc", "$..a") # Test legacy - await modclient.json().set( + await decoded_r.json().set( "doc1", "$", { @@ -566,33 +646,35 @@ async def test_arrappend_dollar(modclient: redis.Redis): }, ) # Test multi (all paths are updated, but return result of last path) - assert await modclient.json().arrappend("doc1", "..a", 
"bar", "racuda") == 5 + assert await decoded_r.json().arrappend("doc1", "..a", "bar", "racuda") == 5 - assert await modclient.json().get("doc1", "$") == [ + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", None, "world", "bar", "racuda"]}, "nested2": {"a": 31}, } ] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) # Test single - assert await modclient.json().arrappend("doc1", ".nested1.a", "baz") == 6 - assert await modclient.json().get("doc1", "$") == [ + assert await decoded_r.json().arrappend("doc1", ".nested1.a", "baz") == 6 + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", None, "world", "bar", "racuda", "baz"]}, "nested2": {"a": 31}, } ] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): - await modclient.json().arrappend("non_existing_doc", "$..a") + await decoded_r.json().arrappend("non_existing_doc", "$..a") @pytest.mark.redismod -async def test_arrinsert_dollar(modclient: redis.Redis): - await modclient.json().set( +async def test_arrinsert_dollar(decoded_r: redis.Redis): + await decoded_r.json().set( "doc1", "$", { @@ -602,35 +684,37 @@ async def test_arrinsert_dollar(modclient: redis.Redis): }, ) # Test multi - res = await modclient.json().arrinsert("doc1", "$..a", "1", "bar", "racuda") + res = await decoded_r.json().arrinsert("doc1", "$..a", "1", "bar", "racuda") assert res == [3, 5, None] - assert await modclient.json().get("doc1", "$") == [ + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", "bar", "racuda", None, "world"]}, "nested2": {"a": 31}, } ] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) # Test single - assert await modclient.json().arrinsert("doc1", "$.nested1.a", -2, "baz") == [6] - assert await modclient.json().get("doc1", "$") == [ + assert await decoded_r.json().arrinsert("doc1", "$.nested1.a", -2, "baz") == [6] + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", "bar", "racuda", "baz", None, "world"]}, "nested2": {"a": 31}, } ] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): - await modclient.json().arrappend("non_existing_doc", "$..a") + await decoded_r.json().arrappend("non_existing_doc", "$..a") @pytest.mark.redismod -async def test_arrlen_dollar(modclient: redis.Redis): +async def test_arrlen_dollar(decoded_r: redis.Redis): - await modclient.json().set( + await decoded_r.json().set( "doc1", "$", { @@ -641,20 +725,20 @@ async def test_arrlen_dollar(modclient: redis.Redis): ) # Test multi - assert await modclient.json().arrlen("doc1", "$..a") == [1, 3, None] - res = await modclient.json().arrappend("doc1", "$..a", "non", "abba", "stanza") + assert await decoded_r.json().arrlen("doc1", "$..a") == [1, 3, None] + res = await decoded_r.json().arrappend("doc1", "$..a", "non", "abba", "stanza") assert res == [4, 6, None] - await modclient.json().clear("doc1", "$.a") - assert await modclient.json().arrlen("doc1", "$..a") == [0, 6, None] + await decoded_r.json().clear("doc1", "$.a") + assert await decoded_r.json().arrlen("doc1", "$..a") == [0, 6, None] # Test single - assert await modclient.json().arrlen("doc1", "$.nested1.a") == [6] + assert await decoded_r.json().arrlen("doc1", "$.nested1.a") == [6] # Test missing key with pytest.raises(exceptions.ResponseError): - await 
modclient.json().arrappend("non_existing_doc", "$..a") + await decoded_r.json().arrappend("non_existing_doc", "$..a") - await modclient.json().set( + await decoded_r.json().set( "doc1", "$", { @@ -664,19 +748,19 @@ async def test_arrlen_dollar(modclient: redis.Redis): }, ) # Test multi (return result of last path) - assert await modclient.json().arrlen("doc1", "$..a") == [1, 3, None] - assert await modclient.json().arrappend("doc1", "..a", "non", "abba", "stanza") == 6 + assert await decoded_r.json().arrlen("doc1", "$..a") == [1, 3, None] + assert await decoded_r.json().arrappend("doc1", "..a", "non", "abba", "stanza") == 6 # Test single - assert await modclient.json().arrlen("doc1", ".nested1.a") == 6 + assert await decoded_r.json().arrlen("doc1", ".nested1.a") == 6 # Test missing key - assert await modclient.json().arrlen("non_existing_doc", "..a") is None + assert await decoded_r.json().arrlen("non_existing_doc", "..a") is None @pytest.mark.redismod -async def test_arrpop_dollar(modclient: redis.Redis): - await modclient.json().set( +async def test_arrpop_dollar(decoded_r: redis.Redis): + await decoded_r.json().set( "doc1", "$", { @@ -686,19 +770,18 @@ async def test_arrpop_dollar(modclient: redis.Redis): }, ) - # # # Test multi - assert await modclient.json().arrpop("doc1", "$..a", 1) == ['"foo"', None, None] + # Test multi + assert await decoded_r.json().arrpop("doc1", "$..a", 1) == ['"foo"', None, None] - assert await modclient.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": ["hello", "world"]}, "nested2": {"a": 31}} - ] + res = [{"a": [], "nested1": {"a": ["hello", "world"]}, "nested2": {"a": 31}}] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): - await modclient.json().arrpop("non_existing_doc", "..a") + await decoded_r.json().arrpop("non_existing_doc", "..a") # # Test legacy - await modclient.json().set( + await decoded_r.json().set( "doc1", "$", { @@ -708,20 +791,19 @@ async def test_arrpop_dollar(modclient: redis.Redis): }, ) # Test multi (all paths are updated, but return result of last path) - await modclient.json().arrpop("doc1", "..a", "1") is None - assert await modclient.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": ["hello", "world"]}, "nested2": {"a": 31}} - ] + await decoded_r.json().arrpop("doc1", "..a", "1") is None + res = [{"a": [], "nested1": {"a": ["hello", "world"]}, "nested2": {"a": 31}}] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) # # Test missing key with pytest.raises(exceptions.ResponseError): - await modclient.json().arrpop("non_existing_doc", "..a") + await decoded_r.json().arrpop("non_existing_doc", "..a") @pytest.mark.redismod -async def test_arrtrim_dollar(modclient: redis.Redis): +async def test_arrtrim_dollar(decoded_r: redis.Redis): - await modclient.json().set( + await decoded_r.json().set( "doc1", "$", { @@ -731,27 +813,24 @@ async def test_arrtrim_dollar(modclient: redis.Redis): }, ) # Test multi - assert await modclient.json().arrtrim("doc1", "$..a", "1", -1) == [0, 2, None] - assert await modclient.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": [None, "world"]}, "nested2": {"a": 31}} - ] + assert await decoded_r.json().arrtrim("doc1", "$..a", "1", -1) == [0, 2, None] + res = [{"a": [], "nested1": {"a": [None, "world"]}, "nested2": {"a": 31}}] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) - assert await 
modclient.json().arrtrim("doc1", "$..a", "1", "1") == [0, 1, None] - assert await modclient.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": ["world"]}, "nested2": {"a": 31}} - ] + assert await decoded_r.json().arrtrim("doc1", "$..a", "1", "1") == [0, 1, None] + res = [{"a": [], "nested1": {"a": ["world"]}, "nested2": {"a": 31}}] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) # Test single - assert await modclient.json().arrtrim("doc1", "$.nested1.a", 1, 0) == [0] - assert await modclient.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": []}, "nested2": {"a": 31}} - ] + assert await decoded_r.json().arrtrim("doc1", "$.nested1.a", 1, 0) == [0] + res = [{"a": [], "nested1": {"a": []}, "nested2": {"a": 31}}] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): - await modclient.json().arrtrim("non_existing_doc", "..a", "0", 1) + await decoded_r.json().arrtrim("non_existing_doc", "..a", "0", 1) # Test legacy - await modclient.json().set( + await decoded_r.json().set( "doc1", "$", { @@ -762,22 +841,21 @@ async def test_arrtrim_dollar(modclient: redis.Redis): ) # Test multi (all paths are updated, but return result of last path) - assert await modclient.json().arrtrim("doc1", "..a", "1", "-1") == 2 + assert await decoded_r.json().arrtrim("doc1", "..a", "1", "-1") == 2 # Test single - assert await modclient.json().arrtrim("doc1", ".nested1.a", "1", "1") == 1 - assert await modclient.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": ["world"]}, "nested2": {"a": 31}} - ] + assert await decoded_r.json().arrtrim("doc1", ".nested1.a", "1", "1") == 1 + res = [{"a": [], "nested1": {"a": ["world"]}, "nested2": {"a": 31}}] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): - await modclient.json().arrtrim("non_existing_doc", "..a", 1, 1) + await decoded_r.json().arrtrim("non_existing_doc", "..a", 1, 1) @pytest.mark.redismod -async def test_objkeys_dollar(modclient: redis.Redis): - await modclient.json().set( +async def test_objkeys_dollar(decoded_r: redis.Redis): + await decoded_r.json().set( "doc1", "$", { @@ -788,26 +866,26 @@ async def test_objkeys_dollar(modclient: redis.Redis): ) # Test single - assert await modclient.json().objkeys("doc1", "$.nested1.a") == [["foo", "bar"]] + assert await decoded_r.json().objkeys("doc1", "$.nested1.a") == [["foo", "bar"]] # Test legacy - assert await modclient.json().objkeys("doc1", ".*.a") == ["foo", "bar"] + assert await decoded_r.json().objkeys("doc1", ".*.a") == ["foo", "bar"] # Test single - assert await modclient.json().objkeys("doc1", ".nested2.a") == ["baz"] + assert await decoded_r.json().objkeys("doc1", ".nested2.a") == ["baz"] # Test missing key - assert await modclient.json().objkeys("non_existing_doc", "..a") is None + assert await decoded_r.json().objkeys("non_existing_doc", "..a") is None # Test non existing doc with pytest.raises(exceptions.ResponseError): - assert await modclient.json().objkeys("non_existing_doc", "$..a") == [] + assert await decoded_r.json().objkeys("non_existing_doc", "$..a") == [] - assert await modclient.json().objkeys("doc1", "$..nowhere") == [] + assert await decoded_r.json().objkeys("doc1", "$..nowhere") == [] @pytest.mark.redismod -async def test_objlen_dollar(modclient: redis.Redis): - await modclient.json().set( +async def test_objlen_dollar(decoded_r: 
redis.Redis): + await decoded_r.json().set( "doc1", "$", { @@ -817,28 +895,28 @@ async def test_objlen_dollar(modclient: redis.Redis): }, ) # Test multi - assert await modclient.json().objlen("doc1", "$..a") == [None, 2, 1] + assert await decoded_r.json().objlen("doc1", "$..a") == [None, 2, 1] # Test single - assert await modclient.json().objlen("doc1", "$.nested1.a") == [2] + assert await decoded_r.json().objlen("doc1", "$.nested1.a") == [2] # Test missing key, and path with pytest.raises(exceptions.ResponseError): - await modclient.json().objlen("non_existing_doc", "$..a") + await decoded_r.json().objlen("non_existing_doc", "$..a") - assert await modclient.json().objlen("doc1", "$.nowhere") == [] + assert await decoded_r.json().objlen("doc1", "$.nowhere") == [] # Test legacy - assert await modclient.json().objlen("doc1", ".*.a") == 2 + assert await decoded_r.json().objlen("doc1", ".*.a") == 2 # Test single - assert await modclient.json().objlen("doc1", ".nested2.a") == 1 + assert await decoded_r.json().objlen("doc1", ".nested2.a") == 1 # Test missing key - assert await modclient.json().objlen("non_existing_doc", "..a") is None + assert await decoded_r.json().objlen("non_existing_doc", "..a") is None # Test missing path # with pytest.raises(exceptions.ResponseError): - await modclient.json().objlen("doc1", ".nowhere") + await decoded_r.json().objlen("doc1", ".nowhere") @pytest.mark.redismod @@ -862,23 +940,28 @@ def load_types_data(nested_key_name): @pytest.mark.redismod -async def test_type_dollar(modclient: redis.Redis): +async def test_type_dollar(decoded_r: redis.Redis): jdata, jtypes = load_types_data("a") - await modclient.json().set("doc1", "$", jdata) + await decoded_r.json().set("doc1", "$", jdata) # Test multi - assert await modclient.json().type("doc1", "$..a") == jtypes + assert_resp_response( + decoded_r, await decoded_r.json().type("doc1", "$..a"), jtypes, [jtypes] + ) # Test single - assert await modclient.json().type("doc1", "$.nested2.a") == [jtypes[1]] + res = await decoded_r.json().type("doc1", "$.nested2.a") + assert_resp_response(decoded_r, res, [jtypes[1]], [[jtypes[1]]]) # Test missing key - assert await modclient.json().type("non_existing_doc", "..a") is None + assert_resp_response( + decoded_r, await decoded_r.json().type("non_existing_doc", "..a"), None, [None] + ) @pytest.mark.redismod -async def test_clear_dollar(modclient: redis.Redis): +async def test_clear_dollar(decoded_r: redis.Redis): - await modclient.json().set( + await decoded_r.json().set( "doc1", "$", { @@ -890,14 +973,15 @@ async def test_clear_dollar(modclient: redis.Redis): ) # Test multi - assert await modclient.json().clear("doc1", "$..a") == 3 + assert await decoded_r.json().clear("doc1", "$..a") == 3 - assert await modclient.json().get("doc1", "$") == [ + res = [ {"nested1": {"a": {}}, "a": [], "nested2": {"a": "claro"}, "nested3": {"a": {}}} ] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) # Test single - await modclient.json().set( + await decoded_r.json().set( "doc1", "$", { @@ -907,8 +991,8 @@ async def test_clear_dollar(modclient: redis.Redis): "nested3": {"a": {"baz": 50}}, }, ) - assert await modclient.json().clear("doc1", "$.nested1.a") == 1 - assert await modclient.json().get("doc1", "$") == [ + assert await decoded_r.json().clear("doc1", "$.nested1.a") == 1 + res = [ { "nested1": {"a": {}}, "a": ["foo"], @@ -916,19 +1000,22 @@ async def test_clear_dollar(modclient: redis.Redis): "nested3": {"a": {"baz": 50}}, } ] + 
assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) # Test missing path (async defaults to root) - assert await modclient.json().clear("doc1") == 1 - assert await modclient.json().get("doc1", "$") == [{}] + assert await decoded_r.json().clear("doc1") == 1 + assert_resp_response( + decoded_r, await decoded_r.json().get("doc1", "$"), [{}], [[{}]] + ) # Test missing key with pytest.raises(exceptions.ResponseError): - await modclient.json().clear("non_existing_doc", "$..a") + await decoded_r.json().clear("non_existing_doc", "$..a") @pytest.mark.redismod -async def test_toggle_dollar(modclient: redis.Redis): - await modclient.json().set( +async def test_toggle_dollar(decoded_r: redis.Redis): + await decoded_r.json().set( "doc1", "$", { @@ -939,8 +1026,8 @@ async def test_toggle_dollar(modclient: redis.Redis): }, ) # Test multi - assert await modclient.json().toggle("doc1", "$..a") == [None, 1, None, 0] - assert await modclient.json().get("doc1", "$") == [ + assert await decoded_r.json().toggle("doc1", "$..a") == [None, 1, None, 0] + res = [ { "a": ["foo"], "nested1": {"a": True}, @@ -948,7 +1035,8 @@ async def test_toggle_dollar(modclient: redis.Redis): "nested3": {"a": False}, } ] + assert_resp_response(decoded_r, await decoded_r.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): - await modclient.json().toggle("non_existing_doc", "$..a") + await decoded_r.json().toggle("non_existing_doc", "$..a") diff --git a/tests/test_asyncio/test_lock.py b/tests/test_asyncio/test_lock.py index d78f74164d..75484a2791 100644 --- a/tests/test_asyncio/test_lock.py +++ b/tests/test_asyncio/test_lock.py @@ -2,7 +2,6 @@ import pytest import pytest_asyncio - from redis.asyncio.lock import Lock from redis.exceptions import LockError, LockNotOwnedError diff --git a/tests/test_asyncio/test_monitor.py b/tests/test_asyncio/test_monitor.py index 3551579ec0..73ee3cf811 100644 --- a/tests/test_asyncio/test_monitor.py +++ b/tests/test_asyncio/test_monitor.py @@ -1,5 +1,4 @@ import pytest - from tests.conftest import skip_if_redis_enterprise, skip_ifnot_redis_enterprise from .conftest import wait_for_command diff --git a/tests/test_asyncio/test_pipeline.py b/tests/test_asyncio/test_pipeline.py index 3df57eb90f..edd2f6d147 100644 --- a/tests/test_asyncio/test_pipeline.py +++ b/tests/test_asyncio/test_pipeline.py @@ -1,5 +1,4 @@ import pytest - import redis from tests.conftest import skip_if_server_version_lt @@ -21,7 +20,6 @@ async def test_pipeline(self, r): .zadd("z", {"z1": 1}) .zadd("z", {"z2": 4}) .zincrby("z", 1, "z1") - .zrange("z", 0, 5, withscores=True) ) assert await pipe.execute() == [ True, @@ -29,7 +27,6 @@ async def test_pipeline(self, r): True, True, 2.0, - [(b"z1", 2.0), (b"z2", 4)], ] async def test_pipeline_memoryview(self, r): diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index c2a9130e83..858576584f 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -5,14 +5,20 @@ from typing import Optional from unittest.mock import patch -import async_timeout +# the functionality is available in 3.11.x but has a major issue before +# 3.11.3. 
See https://github.com/redis/redis-py/issues/2633 +if sys.version_info >= (3, 11, 3): + from asyncio import timeout as async_timeout +else: + from async_timeout import timeout as async_timeout + import pytest import pytest_asyncio - import redis.asyncio as redis from redis.exceptions import ConnectionError from redis.typing import EncodableT -from tests.conftest import skip_if_server_version_lt +from redis.utils import HIREDIS_AVAILABLE +from tests.conftest import get_protocol_version, skip_if_server_version_lt from .compat import create_task, mock @@ -21,7 +27,7 @@ def with_timeout(t): def wrapper(corofunc): @functools.wraps(corofunc) async def run(*args, **kwargs): - async with async_timeout.timeout(t): + async with async_timeout(t): return await corofunc(*args, **kwargs) return run @@ -416,6 +422,24 @@ async def test_get_message_without_subscribe(self, r: redis.Redis, pubsub): assert expect in info.exconly() +@pytest.mark.onlynoncluster +class TestPubSubRESP3Handler: + def my_handler(self, message): + self.message = ["my handler", message] + + @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") + async def test_push_handler(self, r): + if get_protocol_version(r) in [2, "2", None]: + return + p = r.pubsub(push_handler_func=self.my_handler) + await p.subscribe("foo") + assert await wait_for_message(p) is None + assert self.message == ["my handler", [b"subscribe", b"foo", 1]] + assert await r.publish("foo", "test message") == 1 + assert await wait_for_message(p) is None + assert self.message == ["my handler", [b"message", b"foo", b"test message"]] + + @pytest.mark.onlynoncluster class TestPubSubAutoDecoding: """These tests only validate that we get unicode values back""" @@ -648,22 +672,19 @@ async def test_reconnect_listen(self, r: redis.Redis, pubsub): async def loop(): # must make sure the task exits - async with async_timeout.timeout(2): + async with async_timeout(2): nonlocal interrupt await pubsub.subscribe("foo") while True: - # print("loop") try: try: await pubsub.connect() await loop_step() - # print("succ") except redis.ConnectionError: await asyncio.sleep(0.1) except asyncio.CancelledError: # we use a cancel to interrupt the "listen" # when we perform a disconnect - # print("cancel", interrupt) if interrupt: interrupt = False else: @@ -677,7 +698,7 @@ async def loop_step(): task = asyncio.get_running_loop().create_task(loop()) # get the initial connect message - async with async_timeout.timeout(1): + async with async_timeout(1): message = await messages.get() assert message == { "channel": b"foo", @@ -776,7 +797,7 @@ def callback(message): if n == 1: break await asyncio.sleep(0.1) - async with async_timeout.timeout(0.1): + async with async_timeout(0.1): message = await messages.get() task.cancel() # we expect a cancelled error, not the Runtime error @@ -839,7 +860,7 @@ async def test_reconnect_socket_error(self, r: redis.Redis, method): Test that a socket error will cause reconnect """ try: - async with async_timeout.timeout(self.timeout): + async with async_timeout(self.timeout): await self.mysetup(r, method) # now, disconnect the connection, and wait for it to be re-established async with self.cond: @@ -868,7 +889,7 @@ async def test_reconnect_disconnect(self, r: redis.Redis, method): Test that a manual disconnect() will cause reconnect """ try: - async with async_timeout.timeout(self.timeout): + async with async_timeout(self.timeout): await self.mysetup(r, method) # now, disconnect the connection, and wait for it to be re-established async with self.cond: @@ 
-896,7 +917,6 @@ async def loop(self): try: if self.state == 4: break - # print("state a ", self.state) got_msg = await self.get_message() assert got_msg if self.state in (1, 2): @@ -914,7 +934,6 @@ async def loop(self): async def loop_step_get_message(self): # get a single message via get_message message = await self.pubsub.get_message(timeout=0.1) - # print(message) if message is not None: await self.messages.put(message) return True @@ -923,7 +942,7 @@ async def loop_step_get_message(self): async def loop_step_listen(self): # get a single message via listen() try: - async with async_timeout.timeout(0.1): + async with async_timeout(0.1): async for message in self.pubsub.listen(): await self.messages.put(message) return True @@ -947,7 +966,7 @@ async def test_outer_timeout(self, r: redis.Redis): assert pubsub.connection.is_connected async def get_msg_or_timeout(timeout=0.1): - async with async_timeout.timeout(timeout): + async with async_timeout(timeout): # blocking method to return messages while True: response = await pubsub.parse_response(block=True) @@ -967,6 +986,9 @@ async def get_msg_or_timeout(timeout=0.1): # the timeout on the read should not cause disconnect assert pubsub.connection.is_connected + @pytest.mark.skipif( + sys.version_info < (3, 8), reason="requires python 3.8 or higher" + ) async def test_base_exception(self, r: redis.Redis): """ Manually trigger a BaseException inside the parser's .read_response method @@ -991,13 +1013,15 @@ async def get_msg(): assert msg is not None # timeout waiting for another message which never arrives assert pubsub.connection.is_connected - with patch("redis.asyncio.connection.PythonParser.read_response") as mock1: + with patch("redis._parsers._AsyncRESP2Parser.read_response") as mock1, patch( + "redis._parsers._AsyncHiredisParser.read_response" + ) as mock2, patch("redis._parsers._AsyncRESP3Parser.read_response") as mock3: mock1.side_effect = BaseException("boom") - with patch("redis.asyncio.connection.HiredisParser.read_response") as mock2: - mock2.side_effect = BaseException("boom") + mock2.side_effect = BaseException("boom") + mock3.side_effect = BaseException("boom") - with pytest.raises(BaseException): - await get_msg() + with pytest.raises(BaseException): + await get_msg() # the timeout on the read should not cause disconnect assert pubsub.connection.is_connected diff --git a/tests/test_asyncio/test_retry.py b/tests/test_asyncio/test_retry.py index 86e6ddfa0d..2912ca786c 100644 --- a/tests/test_asyncio/test_retry.py +++ b/tests/test_asyncio/test_retry.py @@ -1,5 +1,4 @@ import pytest - from redis.asyncio import Redis from redis.asyncio.connection import Connection, UnixDomainSocketConnection from redis.asyncio.retry import Retry diff --git a/tests/test_asyncio/test_scripting.py b/tests/test_asyncio/test_scripting.py index 3776d12cb7..8375ecd787 100644 --- a/tests/test_asyncio/test_scripting.py +++ b/tests/test_asyncio/test_scripting.py @@ -1,6 +1,5 @@ import pytest import pytest_asyncio - from redis import exceptions from tests.conftest import skip_if_server_version_lt diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index 8707cdf61b..149b26d958 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -5,7 +5,6 @@ from io import TextIOWrapper import pytest - import redis.asyncio as redis import redis.commands.search import redis.commands.search.aggregation as aggregations @@ -16,7 +15,12 @@ from redis.commands.search.query import GeoFilter, NumericFilter, Query from 
redis.commands.search.result import Result from redis.commands.search.suggestion import Suggestion -from tests.conftest import skip_if_redis_enterprise, skip_ifmodversion_lt +from tests.conftest import ( + assert_resp_response, + is_resp2_connection, + skip_if_redis_enterprise, + skip_ifmodversion_lt, +) WILL_PLAY_TEXT = os.path.abspath( os.path.join(os.path.dirname(__file__), "testdata", "will_play_text.csv.bz2") @@ -32,12 +36,16 @@ async def waitForIndex(env, idx, timeout=None): while True: res = await env.execute_command("FT.INFO", idx) try: - res.index("indexing") + if int(res[res.index("indexing") + 1]) == 0: + break except ValueError: break - - if int(res[res.index("indexing") + 1]) == 0: - break + except AttributeError: + try: + if int(res["indexing"]) == 0: + break + except ValueError: + break time.sleep(delay) if timeout is not None: @@ -46,23 +54,23 @@ async def waitForIndex(env, idx, timeout=None): break -def getClient(modclient: redis.Redis): +def getClient(decoded_r: redis.Redis): """ Gets a client client attached to an index name which is ready to be created """ - return modclient + return decoded_r -async def createIndex(modclient, num_docs=100, definition=None): +async def createIndex(decoded_r, num_docs=100, definition=None): try: - await modclient.create_index( + await decoded_r.create_index( (TextField("play", weight=5.0), TextField("txt"), NumericField("chapter")), definition=definition, ) except redis.ResponseError: - await modclient.dropindex(delete_documents=True) - return createIndex(modclient, num_docs=num_docs, definition=definition) + await decoded_r.dropindex(delete_documents=True) + return createIndex(decoded_r, num_docs=num_docs, definition=definition) chapters = {} bzfp = TextIOWrapper(bz2.BZ2File(WILL_PLAY_TEXT), encoding="utf8") @@ -80,7 +88,7 @@ async def createIndex(modclient, num_docs=100, definition=None): if len(chapters) == num_docs: break - indexer = modclient.batch_indexer(chunk_size=50) + indexer = decoded_r.batch_indexer(chunk_size=50) assert isinstance(indexer, AsyncSearch.BatchIndexer) assert 50 == indexer.chunk_size @@ -90,12 +98,12 @@ async def createIndex(modclient, num_docs=100, definition=None): @pytest.mark.redismod -async def test_client(modclient: redis.Redis): +async def test_client(decoded_r: redis.Redis): num_docs = 500 - await createIndex(modclient.ft(), num_docs=num_docs) - await waitForIndex(modclient, "idx") + await createIndex(decoded_r.ft(), num_docs=num_docs) + await waitForIndex(decoded_r, "idx") # verify info - info = await modclient.ft().info() + info = await decoded_r.ft().info() for k in [ "index_name", "index_options", @@ -115,145 +123,268 @@ async def test_client(modclient: redis.Redis): ]: assert k in info - assert modclient.ft().index_name == info["index_name"] + assert decoded_r.ft().index_name == info["index_name"] assert num_docs == int(info["num_docs"]) - res = await modclient.ft().search("henry iv") - assert isinstance(res, Result) - assert 225 == res.total - assert 10 == len(res.docs) - assert res.duration > 0 - - for doc in res.docs: - assert doc.id - assert doc.play == "Henry IV" + res = await decoded_r.ft().search("henry iv") + if is_resp2_connection(decoded_r): + assert isinstance(res, Result) + assert 225 == res.total + assert 10 == len(res.docs) + assert res.duration > 0 + + for doc in res.docs: + assert doc.id + assert doc.play == "Henry IV" + assert len(doc.txt) > 0 + + # test no content + res = await decoded_r.ft().search(Query("king").no_content()) + assert 194 == res.total + assert 10 == len(res.docs) 
+ for doc in res.docs: + assert "txt" not in doc.__dict__ + assert "play" not in doc.__dict__ + + # test verbatim vs no verbatim + total = (await decoded_r.ft().search(Query("kings").no_content())).total + vtotal = ( + await decoded_r.ft().search(Query("kings").no_content().verbatim()) + ).total + assert total > vtotal + + # test in fields + txt_total = ( + await decoded_r.ft().search(Query("henry").no_content().limit_fields("txt")) + ).total + play_total = ( + await decoded_r.ft().search( + Query("henry").no_content().limit_fields("play") + ) + ).total + both_total = ( + await ( + decoded_r.ft().search( + Query("henry").no_content().limit_fields("play", "txt") + ) + ) + ).total + assert 129 == txt_total + assert 494 == play_total + assert 494 == both_total + + # test load_document + doc = await decoded_r.ft().load_document("henry vi part 3:62") + assert doc is not None + assert "henry vi part 3:62" == doc.id + assert doc.play == "Henry VI Part 3" assert len(doc.txt) > 0 - # test no content - res = await modclient.ft().search(Query("king").no_content()) - assert 194 == res.total - assert 10 == len(res.docs) - for doc in res.docs: - assert "txt" not in doc.__dict__ - assert "play" not in doc.__dict__ - - # test verbatim vs no verbatim - total = (await modclient.ft().search(Query("kings").no_content())).total - vtotal = (await modclient.ft().search(Query("kings").no_content().verbatim())).total - assert total > vtotal - - # test in fields - txt_total = ( - await modclient.ft().search(Query("henry").no_content().limit_fields("txt")) - ).total - play_total = ( - await modclient.ft().search(Query("henry").no_content().limit_fields("play")) - ).total - both_total = ( - await ( - modclient.ft().search( + # test in-keys + ids = [x.id for x in (await decoded_r.ft().search(Query("henry"))).docs] + assert 10 == len(ids) + subset = ids[:5] + docs = await decoded_r.ft().search(Query("henry").limit_ids(*subset)) + assert len(subset) == docs.total + ids = [x.id for x in docs.docs] + assert set(ids) == set(subset) + + # test slop and in order + assert 193 == (await decoded_r.ft().search(Query("henry king"))).total + assert ( + 3 + == ( + await decoded_r.ft().search(Query("henry king").slop(0).in_order()) + ).total + ) + assert ( + 52 + == ( + await decoded_r.ft().search(Query("king henry").slop(0).in_order()) + ).total + ) + assert 53 == (await decoded_r.ft().search(Query("henry king").slop(0))).total + assert 167 == (await decoded_r.ft().search(Query("henry king").slop(100))).total + + # test delete document + await decoded_r.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) + res = await decoded_r.ft().search(Query("death of a salesman")) + assert 1 == res.total + + assert 1 == await decoded_r.ft().delete_document("doc-5ghs2") + res = await decoded_r.ft().search(Query("death of a salesman")) + assert 0 == res.total + assert 0 == await decoded_r.ft().delete_document("doc-5ghs2") + + await decoded_r.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) + res = await decoded_r.ft().search(Query("death of a salesman")) + assert 1 == res.total + await decoded_r.ft().delete_document("doc-5ghs2") + else: + assert isinstance(res, dict) + assert 225 == res["total_results"] + assert 10 == len(res["results"]) + + for doc in res["results"]: + assert doc["id"] + assert doc["extra_attributes"]["play"] == "Henry IV" + assert len(doc["extra_attributes"]["txt"]) > 0 + + # test no content + res = await decoded_r.ft().search(Query("king").no_content()) + assert 194 == res["total_results"] + assert 10 == 
len(res["results"]) + for doc in res["results"]: + assert "extra_attributes" not in doc.keys() + + # test verbatim vs no verbatim + total = (await decoded_r.ft().search(Query("kings").no_content()))[ + "total_results" + ] + vtotal = (await decoded_r.ft().search(Query("kings").no_content().verbatim()))[ + "total_results" + ] + assert total > vtotal + + # test in fields + txt_total = ( + await decoded_r.ft().search(Query("henry").no_content().limit_fields("txt")) + )["total_results"] + play_total = ( + await decoded_r.ft().search( + Query("henry").no_content().limit_fields("play") + ) + )["total_results"] + both_total = ( + await decoded_r.ft().search( Query("henry").no_content().limit_fields("play", "txt") ) + )["total_results"] + assert 129 == txt_total + assert 494 == play_total + assert 494 == both_total + + # test load_document + doc = await decoded_r.ft().load_document("henry vi part 3:62") + assert doc is not None + assert "henry vi part 3:62" == doc.id + assert doc.play == "Henry VI Part 3" + assert len(doc.txt) > 0 + + # test in-keys + ids = [ + x["id"] for x in (await decoded_r.ft().search(Query("henry")))["results"] + ] + assert 10 == len(ids) + subset = ids[:5] + docs = await decoded_r.ft().search(Query("henry").limit_ids(*subset)) + assert len(subset) == docs["total_results"] + ids = [x["id"] for x in docs["results"]] + assert set(ids) == set(subset) + + # test slop and in order + assert ( + 193 == (await decoded_r.ft().search(Query("henry king")))["total_results"] + ) + assert ( + 3 + == (await decoded_r.ft().search(Query("henry king").slop(0).in_order()))[ + "total_results" + ] + ) + assert ( + 52 + == (await decoded_r.ft().search(Query("king henry").slop(0).in_order()))[ + "total_results" + ] + ) + assert ( + 53 + == (await decoded_r.ft().search(Query("henry king").slop(0)))[ + "total_results" + ] + ) + assert ( + 167 + == (await decoded_r.ft().search(Query("henry king").slop(100)))[ + "total_results" + ] ) - ).total - assert 129 == txt_total - assert 494 == play_total - assert 494 == both_total - - # test load_document - doc = await modclient.ft().load_document("henry vi part 3:62") - assert doc is not None - assert "henry vi part 3:62" == doc.id - assert doc.play == "Henry VI Part 3" - assert len(doc.txt) > 0 - - # test in-keys - ids = [x.id for x in (await modclient.ft().search(Query("henry"))).docs] - assert 10 == len(ids) - subset = ids[:5] - docs = await modclient.ft().search(Query("henry").limit_ids(*subset)) - assert len(subset) == docs.total - ids = [x.id for x in docs.docs] - assert set(ids) == set(subset) - - # test slop and in order - assert 193 == (await modclient.ft().search(Query("henry king"))).total - assert ( - 3 == (await modclient.ft().search(Query("henry king").slop(0).in_order())).total - ) - assert ( - 52 - == (await modclient.ft().search(Query("king henry").slop(0).in_order())).total - ) - assert 53 == (await modclient.ft().search(Query("henry king").slop(0))).total - assert 167 == (await modclient.ft().search(Query("henry king").slop(100))).total - # test delete document - await modclient.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) - res = await modclient.ft().search(Query("death of a salesman")) - assert 1 == res.total + # test delete document + await decoded_r.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) + res = await decoded_r.ft().search(Query("death of a salesman")) + assert 1 == res["total_results"] - assert 1 == await modclient.ft().delete_document("doc-5ghs2") - res = await modclient.ft().search(Query("death of a 
salesman")) - assert 0 == res.total - assert 0 == await modclient.ft().delete_document("doc-5ghs2") + assert 1 == await decoded_r.ft().delete_document("doc-5ghs2") + res = await decoded_r.ft().search(Query("death of a salesman")) + assert 0 == res["total_results"] + assert 0 == await decoded_r.ft().delete_document("doc-5ghs2") - await modclient.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) - res = await modclient.ft().search(Query("death of a salesman")) - assert 1 == res.total - await modclient.ft().delete_document("doc-5ghs2") + await decoded_r.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) + res = await decoded_r.ft().search(Query("death of a salesman")) + assert 1 == res["total_results"] + await decoded_r.ft().delete_document("doc-5ghs2") @pytest.mark.redismod @pytest.mark.onlynoncluster -async def test_scores(modclient: redis.Redis): - await modclient.ft().create_index((TextField("txt"),)) +async def test_scores(decoded_r: redis.Redis): + await decoded_r.ft().create_index((TextField("txt"),)) - await modclient.hset("doc1", mapping={"txt": "foo baz"}) - await modclient.hset("doc2", mapping={"txt": "foo bar"}) + await decoded_r.hset("doc1", mapping={"txt": "foo baz"}) + await decoded_r.hset("doc2", mapping={"txt": "foo bar"}) q = Query("foo ~bar").with_scores() - res = await modclient.ft().search(q) - assert 2 == res.total - assert "doc2" == res.docs[0].id - assert 3.0 == res.docs[0].score - assert "doc1" == res.docs[1].id - # todo: enable once new RS version is tagged - # self.assertEqual(0.2, res.docs[1].score) + res = await decoded_r.ft().search(q) + if is_resp2_connection(decoded_r): + assert 2 == res.total + assert "doc2" == res.docs[0].id + assert 3.0 == res.docs[0].score + assert "doc1" == res.docs[1].id + else: + assert 2 == res["total_results"] + assert "doc2" == res["results"][0]["id"] + assert 3.0 == res["results"][0]["score"] + assert "doc1" == res["results"][1]["id"] @pytest.mark.redismod -async def test_stopwords(modclient: redis.Redis): +async def test_stopwords(decoded_r: redis.Redis): stopwords = ["foo", "bar", "baz"] - await modclient.ft().create_index((TextField("txt"),), stopwords=stopwords) - await modclient.hset("doc1", mapping={"txt": "foo bar"}) - await modclient.hset("doc2", mapping={"txt": "hello world"}) - await waitForIndex(modclient, "idx") + await decoded_r.ft().create_index((TextField("txt"),), stopwords=stopwords) + await decoded_r.hset("doc1", mapping={"txt": "foo bar"}) + await decoded_r.hset("doc2", mapping={"txt": "hello world"}) + await waitForIndex(decoded_r, "idx") q1 = Query("foo bar").no_content() q2 = Query("foo bar hello world").no_content() - res1, res2 = await modclient.ft().search(q1), await modclient.ft().search(q2) - assert 0 == res1.total - assert 1 == res2.total + res1, res2 = await decoded_r.ft().search(q1), await decoded_r.ft().search(q2) + if is_resp2_connection(decoded_r): + assert 0 == res1.total + assert 1 == res2.total + else: + assert 0 == res1["total_results"] + assert 1 == res2["total_results"] @pytest.mark.redismod -async def test_filters(modclient: redis.Redis): +async def test_filters(decoded_r: redis.Redis): await ( - modclient.ft().create_index( + decoded_r.ft().create_index( (TextField("txt"), NumericField("num"), GeoField("loc")) ) ) await ( - modclient.hset( + decoded_r.hset( "doc1", mapping={"txt": "foo bar", "num": 3.141, "loc": "-0.441,51.458"} ) ) await ( - modclient.hset("doc2", mapping={"txt": "foo baz", "num": 2, "loc": "-0.1,51.2"}) + decoded_r.hset("doc2", mapping={"txt": "foo baz", 
"num": 2, "loc": "-0.1,51.2"}) ) - await waitForIndex(modclient, "idx") + await waitForIndex(decoded_r, "idx") # Test numerical filter q1 = Query("foo").add_filter(NumericFilter("num", 0, 2)).no_content() q2 = ( @@ -261,64 +392,90 @@ async def test_filters(modclient: redis.Redis): .add_filter(NumericFilter("num", 2, NumericFilter.INF, minExclusive=True)) .no_content() ) - res1, res2 = await modclient.ft().search(q1), await modclient.ft().search(q2) - - assert 1 == res1.total - assert 1 == res2.total - assert "doc2" == res1.docs[0].id - assert "doc1" == res2.docs[0].id + res1, res2 = await decoded_r.ft().search(q1), await decoded_r.ft().search(q2) + + if is_resp2_connection(decoded_r): + assert 1 == res1.total + assert 1 == res2.total + assert "doc2" == res1.docs[0].id + assert "doc1" == res2.docs[0].id + else: + assert 1 == res1["total_results"] + assert 1 == res2["total_results"] + assert "doc2" == res1["results"][0]["id"] + assert "doc1" == res2["results"][0]["id"] # Test geo filter q1 = Query("foo").add_filter(GeoFilter("loc", -0.44, 51.45, 10)).no_content() q2 = Query("foo").add_filter(GeoFilter("loc", -0.44, 51.45, 100)).no_content() - res1, res2 = await modclient.ft().search(q1), await modclient.ft().search(q2) + res1, res2 = await decoded_r.ft().search(q1), await decoded_r.ft().search(q2) + + if is_resp2_connection(decoded_r): + assert 1 == res1.total + assert 2 == res2.total + assert "doc1" == res1.docs[0].id - assert 1 == res1.total - assert 2 == res2.total - assert "doc1" == res1.docs[0].id + # Sort results, after RDB reload order may change + res = [res2.docs[0].id, res2.docs[1].id] + res.sort() + assert ["doc1", "doc2"] == res + else: + assert 1 == res1["total_results"] + assert 2 == res2["total_results"] + assert "doc1" == res1["results"][0]["id"] - # Sort results, after RDB reload order may change - res = [res2.docs[0].id, res2.docs[1].id] - res.sort() - assert ["doc1", "doc2"] == res + # Sort results, after RDB reload order may change + res = [res2["results"][0]["id"], res2["results"][1]["id"]] + res.sort() + assert ["doc1", "doc2"] == res @pytest.mark.redismod -async def test_sort_by(modclient: redis.Redis): +async def test_sort_by(decoded_r: redis.Redis): await ( - modclient.ft().create_index( + decoded_r.ft().create_index( (TextField("txt"), NumericField("num", sortable=True)) ) ) - await modclient.hset("doc1", mapping={"txt": "foo bar", "num": 1}) - await modclient.hset("doc2", mapping={"txt": "foo baz", "num": 2}) - await modclient.hset("doc3", mapping={"txt": "foo qux", "num": 3}) + await decoded_r.hset("doc1", mapping={"txt": "foo bar", "num": 1}) + await decoded_r.hset("doc2", mapping={"txt": "foo baz", "num": 2}) + await decoded_r.hset("doc3", mapping={"txt": "foo qux", "num": 3}) # Test sort q1 = Query("foo").sort_by("num", asc=True).no_content() q2 = Query("foo").sort_by("num", asc=False).no_content() - res1, res2 = await modclient.ft().search(q1), await modclient.ft().search(q2) - - assert 3 == res1.total - assert "doc1" == res1.docs[0].id - assert "doc2" == res1.docs[1].id - assert "doc3" == res1.docs[2].id - assert 3 == res2.total - assert "doc1" == res2.docs[2].id - assert "doc2" == res2.docs[1].id - assert "doc3" == res2.docs[0].id + res1, res2 = await decoded_r.ft().search(q1), await decoded_r.ft().search(q2) + + if is_resp2_connection(decoded_r): + assert 3 == res1.total + assert "doc1" == res1.docs[0].id + assert "doc2" == res1.docs[1].id + assert "doc3" == res1.docs[2].id + assert 3 == res2.total + assert "doc1" == res2.docs[2].id + assert "doc2" == 
res2.docs[1].id + assert "doc3" == res2.docs[0].id + else: + assert 3 == res1["total_results"] + assert "doc1" == res1["results"][0]["id"] + assert "doc2" == res1["results"][1]["id"] + assert "doc3" == res1["results"][2]["id"] + assert 3 == res2["total_results"] + assert "doc1" == res2["results"][2]["id"] + assert "doc2" == res2["results"][1]["id"] + assert "doc3" == res2["results"][0]["id"] @pytest.mark.redismod @skip_ifmodversion_lt("2.0.0", "search") -async def test_drop_index(modclient: redis.Redis): +async def test_drop_index(decoded_r: redis.Redis): """ Ensure the index gets dropped but data remains by default """ for x in range(20): for keep_docs in [[True, {}], [False, {"name": "haveit"}]]: idx = "HaveIt" - index = getClient(modclient) + index = getClient(decoded_r) await index.hset("index:haveit", mapping={"name": "haveit"}) idef = IndexDefinition(prefix=["index:"]) await index.ft(idx).create_index((TextField("name"),), definition=idef) @@ -329,14 +486,14 @@ async def test_drop_index(modclient: redis.Redis): @pytest.mark.redismod -async def test_example(modclient: redis.Redis): +async def test_example(decoded_r: redis.Redis): # Creating the index definition and schema await ( - modclient.ft().create_index((TextField("title", weight=5.0), TextField("body"))) + decoded_r.ft().create_index((TextField("title", weight=5.0), TextField("body"))) ) # Indexing a document - await modclient.hset( + await decoded_r.hset( "doc1", mapping={ "title": "RediSearch", @@ -347,12 +504,12 @@ async def test_example(modclient: redis.Redis): # Searching with complex parameters: q = Query("search engine").verbatim().no_content().paging(0, 5) - res = await modclient.ft().search(q) + res = await decoded_r.ft().search(q) assert res is not None @pytest.mark.redismod -async def test_auto_complete(modclient: redis.Redis): +async def test_auto_complete(decoded_r: redis.Redis): n = 0 with open(TITLES_CSV) as f: cr = csv.reader(f) @@ -360,10 +517,10 @@ async def test_auto_complete(modclient: redis.Redis): for row in cr: n += 1 term, score = row[0], float(row[1]) - assert n == await modclient.ft().sugadd("ac", Suggestion(term, score=score)) + assert n == await decoded_r.ft().sugadd("ac", Suggestion(term, score=score)) - assert n == await modclient.ft().suglen("ac") - ret = await modclient.ft().sugget("ac", "bad", with_scores=True) + assert n == await decoded_r.ft().suglen("ac") + ret = await decoded_r.ft().sugget("ac", "bad", with_scores=True) assert 2 == len(ret) assert "badger" == ret[0].string assert isinstance(ret[0].score, float) @@ -372,29 +529,29 @@ assert isinstance(ret[1].score, float) assert 1.0 != ret[1].score - ret = await modclient.ft().sugget("ac", "bad", fuzzy=True, num=10) + ret = await decoded_r.ft().sugget("ac", "bad", fuzzy=True, num=10) assert 10 == len(ret) assert 1.0 == ret[0].score strs = {x.string for x in ret} for sug in strs: - assert 1 == await modclient.ft().sugdel("ac", sug) + assert 1 == await decoded_r.ft().sugdel("ac", sug) # make sure a second delete returns 0 for sug in strs: - assert 0 == await modclient.ft().sugdel("ac", sug) + assert 0 == await decoded_r.ft().sugdel("ac", sug) # make sure they were actually deleted - ret2 = await modclient.ft().sugget("ac", "bad", fuzzy=True, num=10) + ret2 = await decoded_r.ft().sugget("ac", "bad", fuzzy=True, num=10) for sug in ret2: assert sug.string not in strs # Test with payload - await modclient.ft().sugadd("ac", Suggestion("pay1", payload="pl1")) - await modclient.ft().sugadd("ac",
Suggestion("pay2", payload="pl2")) - await modclient.ft().sugadd("ac", Suggestion("pay3", payload="pl3")) + await decoded_r.ft().sugadd("ac", Suggestion("pay1", payload="pl1")) + await decoded_r.ft().sugadd("ac", Suggestion("pay2", payload="pl2")) + await decoded_r.ft().sugadd("ac", Suggestion("pay3", payload="pl3")) sugs = await ( - modclient.ft().sugget("ac", "pay", with_payloads=True, with_scores=True) + decoded_r.ft().sugget("ac", "pay", with_payloads=True, with_scores=True) ) assert 3 == len(sugs) for sug in sugs: @@ -403,8 +560,8 @@ async def test_auto_complete(modclient: redis.Redis): @pytest.mark.redismod -async def test_no_index(modclient: redis.Redis): - await modclient.ft().create_index( +async def test_no_index(decoded_r: redis.Redis): + await decoded_r.ft().create_index( ( TextField("field"), TextField("text", no_index=True, sortable=True), @@ -414,37 +571,60 @@ async def test_no_index(modclient: redis.Redis): ) ) - await modclient.hset( + await decoded_r.hset( "doc1", mapping={"field": "aaa", "text": "1", "numeric": "1", "geo": "1,1", "tag": "1"}, ) - await modclient.hset( + await decoded_r.hset( "doc2", mapping={"field": "aab", "text": "2", "numeric": "2", "geo": "2,2", "tag": "2"}, ) - await waitForIndex(modclient, "idx") + await waitForIndex(decoded_r, "idx") + + if is_resp2_connection(decoded_r): + res = await decoded_r.ft().search(Query("@text:aa*")) + assert 0 == res.total - res = await modclient.ft().search(Query("@text:aa*")) - assert 0 == res.total + res = await decoded_r.ft().search(Query("@field:aa*")) + assert 2 == res.total - res = await modclient.ft().search(Query("@field:aa*")) - assert 2 == res.total + res = await decoded_r.ft().search(Query("*").sort_by("text", asc=False)) + assert 2 == res.total + assert "doc2" == res.docs[0].id - res = await modclient.ft().search(Query("*").sort_by("text", asc=False)) - assert 2 == res.total - assert "doc2" == res.docs[0].id + res = await decoded_r.ft().search(Query("*").sort_by("text", asc=True)) + assert "doc1" == res.docs[0].id - res = await modclient.ft().search(Query("*").sort_by("text", asc=True)) - assert "doc1" == res.docs[0].id + res = await decoded_r.ft().search(Query("*").sort_by("numeric", asc=True)) + assert "doc1" == res.docs[0].id - res = await modclient.ft().search(Query("*").sort_by("numeric", asc=True)) - assert "doc1" == res.docs[0].id + res = await decoded_r.ft().search(Query("*").sort_by("geo", asc=True)) + assert "doc1" == res.docs[0].id - res = await modclient.ft().search(Query("*").sort_by("geo", asc=True)) - assert "doc1" == res.docs[0].id + res = await decoded_r.ft().search(Query("*").sort_by("tag", asc=True)) + assert "doc1" == res.docs[0].id + else: + res = await decoded_r.ft().search(Query("@text:aa*")) + assert 0 == res["total_results"] - res = await modclient.ft().search(Query("*").sort_by("tag", asc=True)) - assert "doc1" == res.docs[0].id + res = await decoded_r.ft().search(Query("@field:aa*")) + assert 2 == res["total_results"] + + res = await decoded_r.ft().search(Query("*").sort_by("text", asc=False)) + assert 2 == res["total_results"] + assert "doc2" == res["results"][0]["id"] + + res = await decoded_r.ft().search(Query("*").sort_by("text", asc=True)) + assert "doc1" == res["results"][0]["id"] + + res = await decoded_r.ft().search(Query("*").sort_by("numeric", asc=True)) + assert "doc1" == res["results"][0]["id"] + + res = await decoded_r.ft().search(Query("*").sort_by("geo", asc=True)) + assert "doc1" == res["results"][0]["id"] + + res = await 
decoded_r.ft().search(Query("*").sort_by("tag", asc=True)) + assert "doc1" == res["results"][0]["id"] # Ensure exception is raised for non-indexable, non-sortable fields with pytest.raises(Exception): @@ -458,51 +638,68 @@ async def test_no_index(modclient: redis.Redis): @pytest.mark.redismod -async def test_explain(modclient: redis.Redis): +async def test_explain(decoded_r: redis.Redis): await ( - modclient.ft().create_index((TextField("f1"), TextField("f2"), TextField("f3"))) + decoded_r.ft().create_index((TextField("f1"), TextField("f2"), TextField("f3"))) ) - res = await modclient.ft().explain("@f3:f3_val @f2:f2_val @f1:f1_val") + res = await decoded_r.ft().explain("@f3:f3_val @f2:f2_val @f1:f1_val") assert res @pytest.mark.redismod -async def test_explaincli(modclient: redis.Redis): +async def test_explaincli(decoded_r: redis.Redis): with pytest.raises(NotImplementedError): - await modclient.ft().explain_cli("foo") + await decoded_r.ft().explain_cli("foo") @pytest.mark.redismod -async def test_summarize(modclient: redis.Redis): - await createIndex(modclient.ft()) - await waitForIndex(modclient, "idx") +async def test_summarize(decoded_r: redis.Redis): + await createIndex(decoded_r.ft()) + await waitForIndex(decoded_r, "idx") q = Query("king henry").paging(0, 1) q.highlight(fields=("play", "txt"), tags=("", "")) q.summarize("txt") - doc = sorted((await modclient.ft().search(q)).docs)[0] - assert "Henry IV" == doc.play - assert ( - "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa - == doc.txt - ) + if is_resp2_connection(decoded_r): + doc = sorted((await decoded_r.ft().search(q)).docs)[0] + assert "Henry IV" == doc.play + assert ( + "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa + == doc.txt + ) - q = Query("king henry").paging(0, 1).summarize().highlight() + q = Query("king henry").paging(0, 1).summarize().highlight() - doc = sorted((await modclient.ft().search(q)).docs)[0] - assert "Henry ... " == doc.play - assert ( - "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa - == doc.txt - ) + doc = sorted((await decoded_r.ft().search(q)).docs)[0] + assert "Henry ... " == doc.play + assert ( + "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa + == doc.txt + ) + else: + doc = sorted((await decoded_r.ft().search(q))["results"])[0] + assert "Henry IV" == doc["extra_attributes"]["play"] + assert ( + "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa + == doc["extra_attributes"]["txt"] + ) + + q = Query("king henry").paging(0, 1).summarize().highlight() + + doc = sorted((await decoded_r.ft().search(q))["results"])[0] + assert "Henry ... " == doc["extra_attributes"]["play"] + assert ( + "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... 
" # noqa + == doc["extra_attributes"]["txt"] + ) @pytest.mark.redismod @skip_ifmodversion_lt("2.0.0", "search") -async def test_alias(modclient: redis.Redis): - index1 = getClient(modclient) - index2 = getClient(modclient) +async def test_alias(decoded_r: redis.Redis): + index1 = getClient(decoded_r) + index2 = getClient(decoded_r) def1 = IndexDefinition(prefix=["index1:"]) def2 = IndexDefinition(prefix=["index2:"]) @@ -515,25 +712,46 @@ async def test_alias(modclient: redis.Redis): await index1.hset("index1:lonestar", mapping={"name": "lonestar"}) await index2.hset("index2:yogurt", mapping={"name": "yogurt"}) - res = (await ftindex1.search("*")).docs[0] - assert "index1:lonestar" == res.id + if is_resp2_connection(decoded_r): + res = (await ftindex1.search("*")).docs[0] + assert "index1:lonestar" == res.id - # create alias and check for results - await ftindex1.aliasadd("spaceballs") - alias_client = getClient(modclient).ft("spaceballs") - res = (await alias_client.search("*")).docs[0] - assert "index1:lonestar" == res.id + # create alias and check for results + await ftindex1.aliasadd("spaceballs") + alias_client = getClient(decoded_r).ft("spaceballs") + res = (await alias_client.search("*")).docs[0] + assert "index1:lonestar" == res.id - # Throw an exception when trying to add an alias that already exists - with pytest.raises(Exception): - await ftindex2.aliasadd("spaceballs") + # Throw an exception when trying to add an alias that already exists + with pytest.raises(Exception): + await ftindex2.aliasadd("spaceballs") + + # update alias and ensure new results + await ftindex2.aliasupdate("spaceballs") + alias_client2 = getClient(decoded_r).ft("spaceballs") + + res = (await alias_client2.search("*")).docs[0] + assert "index2:yogurt" == res.id + else: + res = (await ftindex1.search("*"))["results"][0] + assert "index1:lonestar" == res["id"] - # update alias and ensure new results - await ftindex2.aliasupdate("spaceballs") - alias_client2 = getClient(modclient).ft("spaceballs") + # create alias and check for results + await ftindex1.aliasadd("spaceballs") + alias_client = getClient(await decoded_r).ft("spaceballs") + res = (await alias_client.search("*"))["results"][0] + assert "index1:lonestar" == res["id"] - res = (await alias_client2.search("*")).docs[0] - assert "index2:yogurt" == res.id + # Throw an exception when trying to add an alias that already exists + with pytest.raises(Exception): + await ftindex2.aliasadd("spaceballs") + + # update alias and ensure new results + await ftindex2.aliasupdate("spaceballs") + alias_client2 = getClient(await decoded_r).ft("spaceballs") + + res = (await alias_client2.search("*"))["results"][0] + assert "index2:yogurt" == res["id"] await ftindex2.aliasdel("spaceballs") with pytest.raises(Exception): @@ -541,34 +759,51 @@ async def test_alias(modclient: redis.Redis): @pytest.mark.redismod -async def test_alias_basic(modclient: redis.Redis): +@pytest.mark.xfail(strict=False) +async def test_alias_basic(decoded_r: redis.Redis): # Creating a client with one index - client = getClient(modclient) + client = getClient(decoded_r) await client.flushdb() - index1 = getClient(modclient).ft("testAlias") + index1 = getClient(decoded_r).ft("testAlias") await index1.create_index((TextField("txt"),)) await index1.client.hset("doc1", mapping={"txt": "text goes here"}) - index2 = getClient(modclient).ft("testAlias2") + index2 = getClient(decoded_r).ft("testAlias2") await index2.create_index((TextField("txt"),)) await index2.client.hset("doc2", mapping={"txt": "text 
goes here"}) # add the actual alias and check await index1.aliasadd("myalias") - alias_client = getClient(modclient).ft("myalias") - res = sorted((await alias_client.search("*")).docs, key=lambda x: x.id) - assert "doc1" == res[0].id - - # Throw an exception when trying to add an alias that already exists - with pytest.raises(Exception): - await index2.aliasadd("myalias") - - # update the alias and ensure we get doc2 - await index2.aliasupdate("myalias") - alias_client2 = getClient(modclient).ft("myalias") - res = sorted((await alias_client2.search("*")).docs, key=lambda x: x.id) - assert "doc1" == res[0].id + alias_client = getClient(decoded_r).ft("myalias") + if is_resp2_connection(decoded_r): + res = sorted((await alias_client.search("*")).docs, key=lambda x: x.id) + assert "doc1" == res[0].id + + # Throw an exception when trying to add an alias that already exists + with pytest.raises(Exception): + await index2.aliasadd("myalias") + + # update the alias and ensure we get doc2 + await index2.aliasupdate("myalias") + alias_client2 = getClient(decoded_r).ft("myalias") + res = sorted((await alias_client2.search("*")).docs, key=lambda x: x.id) + assert "doc1" == res[0].id + else: + res = sorted((await alias_client.search("*"))["results"], key=lambda x: x["id"]) + assert "doc1" == res[0]["id"] + + # Throw an exception when trying to add an alias that already exists + with pytest.raises(Exception): + await index2.aliasadd("myalias") + + # update the alias and ensure we get doc2 + await index2.aliasupdate("myalias") + alias_client2 = getClient(client).ft("myalias") + res = sorted( + (await alias_client2.search("*"))["results"], key=lambda x: x["id"] + ) + assert "doc1" == res[0]["id"] # delete the alias and expect an error if we try to query again await index2.aliasdel("myalias") @@ -577,56 +812,79 @@ async def test_alias_basic(modclient: redis.Redis): @pytest.mark.redismod -async def test_tags(modclient: redis.Redis): - await modclient.ft().create_index((TextField("txt"), TagField("tags"))) +async def test_tags(decoded_r: redis.Redis): + await decoded_r.ft().create_index((TextField("txt"), TagField("tags"))) tags = "foo,foo bar,hello;world" tags2 = "soba,ramen" - await modclient.hset("doc1", mapping={"txt": "fooz barz", "tags": tags}) - await modclient.hset("doc2", mapping={"txt": "noodles", "tags": tags2}) - await waitForIndex(modclient, "idx") + await decoded_r.hset("doc1", mapping={"txt": "fooz barz", "tags": tags}) + await decoded_r.hset("doc2", mapping={"txt": "noodles", "tags": tags2}) + await waitForIndex(decoded_r, "idx") q = Query("@tags:{foo}") - res = await modclient.ft().search(q) - assert 1 == res.total + if is_resp2_connection(decoded_r): + res = await decoded_r.ft().search(q) + assert 1 == res.total + + q = Query("@tags:{foo bar}") + res = await decoded_r.ft().search(q) + assert 1 == res.total + + q = Query("@tags:{foo\\ bar}") + res = await decoded_r.ft().search(q) + assert 1 == res.total - q = Query("@tags:{foo bar}") - res = await modclient.ft().search(q) - assert 1 == res.total + q = Query("@tags:{hello\\;world}") + res = await decoded_r.ft().search(q) + assert 1 == res.total - q = Query("@tags:{foo\\ bar}") - res = await modclient.ft().search(q) - assert 1 == res.total + q2 = await decoded_r.ft().tagvals("tags") + assert (tags.split(",") + tags2.split(",")).sort() == q2.sort() + else: + res = await decoded_r.ft().search(q) + assert 1 == res["total_results"] - q = Query("@tags:{hello\\;world}") - res = await modclient.ft().search(q) - assert 1 == res.total + q = 
Query("@tags:{foo bar}") + res = await decoded_r.ft().search(q) + assert 1 == res["total_results"] - q2 = await modclient.ft().tagvals("tags") - assert (tags.split(",") + tags2.split(",")).sort() == q2.sort() + q = Query("@tags:{foo\\ bar}") + res = await decoded_r.ft().search(q) + assert 1 == res["total_results"] + + q = Query("@tags:{hello\\;world}") + res = await decoded_r.ft().search(q) + assert 1 == res["total_results"] + + q2 = await decoded_r.ft().tagvals("tags") + assert set(tags.split(",") + tags2.split(",")) == q2 @pytest.mark.redismod -async def test_textfield_sortable_nostem(modclient: redis.Redis): +async def test_textfield_sortable_nostem(decoded_r: redis.Redis): # Creating the index definition with sortable and no_stem - await modclient.ft().create_index((TextField("txt", sortable=True, no_stem=True),)) + await decoded_r.ft().create_index((TextField("txt", sortable=True, no_stem=True),)) # Now get the index info to confirm its contents - response = await modclient.ft().info() - assert "SORTABLE" in response["attributes"][0] - assert "NOSTEM" in response["attributes"][0] + response = await decoded_r.ft().info() + if is_resp2_connection(decoded_r): + assert "SORTABLE" in response["attributes"][0] + assert "NOSTEM" in response["attributes"][0] + else: + assert "SORTABLE" in response["attributes"][0]["flags"] + assert "NOSTEM" in response["attributes"][0]["flags"] @pytest.mark.redismod -async def test_alter_schema_add(modclient: redis.Redis): +async def test_alter_schema_add(decoded_r: redis.Redis): # Creating the index definition and schema - await modclient.ft().create_index(TextField("title")) + await decoded_r.ft().create_index(TextField("title")) # Using alter to add a field - await modclient.ft().alter_schema_add(TextField("body")) + await decoded_r.ft().alter_schema_add(TextField("body")) # Indexing a document - await modclient.hset( + await decoded_r.hset( "doc1", mapping={"title": "MyTitle", "body": "Some content only in the body"} ) @@ -634,167 +892,236 @@ async def test_alter_schema_add(modclient: redis.Redis): q = Query("only in the body") # Ensure we find the result searching on the added body field - res = await modclient.ft().search(q) - assert 1 == res.total + res = await decoded_r.ft().search(q) + if is_resp2_connection(decoded_r): + assert 1 == res.total + else: + assert 1 == res["total_results"] @pytest.mark.redismod -async def test_spell_check(modclient: redis.Redis): - await modclient.ft().create_index((TextField("f1"), TextField("f2"))) +async def test_spell_check(decoded_r: redis.Redis): + await decoded_r.ft().create_index((TextField("f1"), TextField("f2"))) await ( - modclient.hset( + decoded_r.hset( "doc1", mapping={"f1": "some valid content", "f2": "this is sample text"} ) ) - await modclient.hset("doc2", mapping={"f1": "very important", "f2": "lorem ipsum"}) - await waitForIndex(modclient, "idx") - - # test spellcheck - res = await modclient.ft().spellcheck("impornant") - assert "important" == res["impornant"][0]["suggestion"] - - res = await modclient.ft().spellcheck("contnt") - assert "content" == res["contnt"][0]["suggestion"] - - # test spellcheck with Levenshtein distance - res = await modclient.ft().spellcheck("vlis") - assert res == {} - res = await modclient.ft().spellcheck("vlis", distance=2) - assert "valid" == res["vlis"][0]["suggestion"] - - # test spellcheck include - await modclient.ft().dict_add("dict", "lore", "lorem", "lorm") - res = await modclient.ft().spellcheck("lorm", include="dict") - assert len(res["lorm"]) == 3 - assert ( - 
res["lorm"][0]["suggestion"], - res["lorm"][1]["suggestion"], - res["lorm"][2]["suggestion"], - ) == ("lorem", "lore", "lorm") - assert (res["lorm"][0]["score"], res["lorm"][1]["score"]) == ("0.5", "0") - - # test spellcheck exclude - res = await modclient.ft().spellcheck("lorm", exclude="dict") - assert res == {} + await decoded_r.hset("doc2", mapping={"f1": "very important", "f2": "lorem ipsum"}) + await waitForIndex(decoded_r, "idx") + + if is_resp2_connection(decoded_r): + # test spellcheck + res = await decoded_r.ft().spellcheck("impornant") + assert "important" == res["impornant"][0]["suggestion"] + + res = await decoded_r.ft().spellcheck("contnt") + assert "content" == res["contnt"][0]["suggestion"] + + # test spellcheck with Levenshtein distance + res = await decoded_r.ft().spellcheck("vlis") + assert res == {} + res = await decoded_r.ft().spellcheck("vlis", distance=2) + assert "valid" == res["vlis"][0]["suggestion"] + + # test spellcheck include + await decoded_r.ft().dict_add("dict", "lore", "lorem", "lorm") + res = await decoded_r.ft().spellcheck("lorm", include="dict") + assert len(res["lorm"]) == 3 + assert ( + res["lorm"][0]["suggestion"], + res["lorm"][1]["suggestion"], + res["lorm"][2]["suggestion"], + ) == ("lorem", "lore", "lorm") + assert (res["lorm"][0]["score"], res["lorm"][1]["score"]) == ("0.5", "0") + + # test spellcheck exclude + res = await decoded_r.ft().spellcheck("lorm", exclude="dict") + assert res == {} + else: + # test spellcheck + res = await decoded_r.ft().spellcheck("impornant") + assert "important" in res["results"]["impornant"][0].keys() + + res = await decoded_r.ft().spellcheck("contnt") + assert "content" in res["results"]["contnt"][0].keys() + + # test spellcheck with Levenshtein distance + res = await decoded_r.ft().spellcheck("vlis") + assert res == {"results": {"vlis": []}} + res = await decoded_r.ft().spellcheck("vlis", distance=2) + assert "valid" in res["results"]["vlis"][0].keys() + + # test spellcheck include + await decoded_r.ft().dict_add("dict", "lore", "lorem", "lorm") + res = await decoded_r.ft().spellcheck("lorm", include="dict") + assert len(res["results"]["lorm"]) == 3 + assert "lorem" in res["results"]["lorm"][0].keys() + assert "lore" in res["results"]["lorm"][1].keys() + assert "lorm" in res["results"]["lorm"][2].keys() + assert ( + res["results"]["lorm"][0]["lorem"], + res["results"]["lorm"][1]["lore"], + ) == (0.5, 0) + + # test spellcheck exclude + res = await decoded_r.ft().spellcheck("lorm", exclude="dict") + assert res == {"results": {}} @pytest.mark.redismod -async def test_dict_operations(modclient: redis.Redis): - await modclient.ft().create_index((TextField("f1"), TextField("f2"))) +async def test_dict_operations(decoded_r: redis.Redis): + await decoded_r.ft().create_index((TextField("f1"), TextField("f2"))) # Add three items - res = await modclient.ft().dict_add("custom_dict", "item1", "item2", "item3") + res = await decoded_r.ft().dict_add("custom_dict", "item1", "item2", "item3") assert 3 == res # Remove one item - res = await modclient.ft().dict_del("custom_dict", "item2") + res = await decoded_r.ft().dict_del("custom_dict", "item2") assert 1 == res # Dump dict and inspect content - res = await modclient.ft().dict_dump("custom_dict") - assert ["item1", "item3"] == res + res = await decoded_r.ft().dict_dump("custom_dict") + assert_resp_response(decoded_r, res, ["item1", "item3"], {"item1", "item3"}) # Remove rest of the items before reload - await modclient.ft().dict_del("custom_dict", *res) + await 
decoded_r.ft().dict_del("custom_dict", *res) @pytest.mark.redismod -async def test_phonetic_matcher(modclient: redis.Redis): - await modclient.ft().create_index((TextField("name"),)) - await modclient.hset("doc1", mapping={"name": "Jon"}) - await modclient.hset("doc2", mapping={"name": "John"}) - - res = await modclient.ft().search(Query("Jon")) - assert 1 == len(res.docs) - assert "Jon" == res.docs[0].name +async def test_phonetic_matcher(decoded_r: redis.Redis): + await decoded_r.ft().create_index((TextField("name"),)) + await decoded_r.hset("doc1", mapping={"name": "Jon"}) + await decoded_r.hset("doc2", mapping={"name": "John"}) + + res = await decoded_r.ft().search(Query("Jon")) + if is_resp2_connection(decoded_r): + assert 1 == len(res.docs) + assert "Jon" == res.docs[0].name + else: + assert 1 == res["total_results"] + assert "Jon" == res["results"][0]["extra_attributes"]["name"] # Drop and create index with phonetic matcher - await modclient.flushdb() - - await modclient.ft().create_index((TextField("name", phonetic_matcher="dm:en"),)) - await modclient.hset("doc1", mapping={"name": "Jon"}) - await modclient.hset("doc2", mapping={"name": "John"}) - - res = await modclient.ft().search(Query("Jon")) - assert 2 == len(res.docs) - assert ["John", "Jon"] == sorted(d.name for d in res.docs) + await decoded_r.flushdb() + + await decoded_r.ft().create_index((TextField("name", phonetic_matcher="dm:en"),)) + await decoded_r.hset("doc1", mapping={"name": "Jon"}) + await decoded_r.hset("doc2", mapping={"name": "John"}) + + res = await decoded_r.ft().search(Query("Jon")) + if is_resp2_connection(decoded_r): + assert 2 == len(res.docs) + assert ["John", "Jon"] == sorted(d.name for d in res.docs) + else: + assert 2 == res["total_results"] + assert ["John", "Jon"] == sorted( + d["extra_attributes"]["name"] for d in res["results"] + ) @pytest.mark.redismod @pytest.mark.onlynoncluster -async def test_scorer(modclient: redis.Redis): - await modclient.ft().create_index((TextField("description"),)) +async def test_scorer(decoded_r: redis.Redis): + await decoded_r.ft().create_index((TextField("description"),)) - await modclient.hset( + await decoded_r.hset( "doc1", mapping={"description": "The quick brown fox jumps over the lazy dog"} ) - await modclient.hset( + await decoded_r.hset( "doc2", mapping={ "description": "Quick alice was beginning to get very tired of sitting by her quick sister on the bank, and of having nothing to do." 
# noqa }, ) - # default scorer is TFIDF - res = await modclient.ft().search(Query("quick").with_scores()) - assert 1.0 == res.docs[0].score - res = await modclient.ft().search(Query("quick").scorer("TFIDF").with_scores()) - assert 1.0 == res.docs[0].score - res = await ( - modclient.ft().search(Query("quick").scorer("TFIDF.DOCNORM").with_scores()) - ) - assert 0.1111111111111111 == res.docs[0].score - res = await modclient.ft().search(Query("quick").scorer("BM25").with_scores()) - assert 0.17699114465425977 == res.docs[0].score - res = await modclient.ft().search(Query("quick").scorer("DISMAX").with_scores()) - assert 2.0 == res.docs[0].score - res = await modclient.ft().search(Query("quick").scorer("DOCSCORE").with_scores()) - assert 1.0 == res.docs[0].score - res = await modclient.ft().search(Query("quick").scorer("HAMMING").with_scores()) - assert 0.0 == res.docs[0].score + if is_resp2_connection(decoded_r): + # default scorer is TFIDF + res = await decoded_r.ft().search(Query("quick").with_scores()) + assert 1.0 == res.docs[0].score + res = await decoded_r.ft().search(Query("quick").scorer("TFIDF").with_scores()) + assert 1.0 == res.docs[0].score + res = await ( + decoded_r.ft().search(Query("quick").scorer("TFIDF.DOCNORM").with_scores()) + ) + assert 0.1111111111111111 == res.docs[0].score + res = await decoded_r.ft().search(Query("quick").scorer("BM25").with_scores()) + assert 0.17699114465425977 == res.docs[0].score + res = await decoded_r.ft().search(Query("quick").scorer("DISMAX").with_scores()) + assert 2.0 == res.docs[0].score + res = await decoded_r.ft().search( + Query("quick").scorer("DOCSCORE").with_scores() + ) + assert 1.0 == res.docs[0].score + res = await decoded_r.ft().search( + Query("quick").scorer("HAMMING").with_scores() + ) + assert 0.0 == res.docs[0].score + else: + res = await decoded_r.ft().search(Query("quick").with_scores()) + assert 1.0 == res["results"][0]["score"] + res = await decoded_r.ft().search(Query("quick").scorer("TFIDF").with_scores()) + assert 1.0 == res["results"][0]["score"] + res = await decoded_r.ft().search( + Query("quick").scorer("TFIDF.DOCNORM").with_scores() + ) + assert 0.1111111111111111 == res["results"][0]["score"] + res = await decoded_r.ft().search(Query("quick").scorer("BM25").with_scores()) + assert 0.17699114465425977 == res["results"][0]["score"] + res = await decoded_r.ft().search(Query("quick").scorer("DISMAX").with_scores()) + assert 2.0 == res["results"][0]["score"] + res = await decoded_r.ft().search( + Query("quick").scorer("DOCSCORE").with_scores() + ) + assert 1.0 == res["results"][0]["score"] + res = await decoded_r.ft().search( + Query("quick").scorer("HAMMING").with_scores() + ) + assert 0.0 == res["results"][0]["score"] @pytest.mark.redismod -async def test_get(modclient: redis.Redis): - await modclient.ft().create_index((TextField("f1"), TextField("f2"))) +async def test_get(decoded_r: redis.Redis): + await decoded_r.ft().create_index((TextField("f1"), TextField("f2"))) - assert [None] == await modclient.ft().get("doc1") - assert [None, None] == await modclient.ft().get("doc2", "doc1") + assert [None] == await decoded_r.ft().get("doc1") + assert [None, None] == await decoded_r.ft().get("doc2", "doc1") - await modclient.hset( + await decoded_r.hset( "doc1", mapping={"f1": "some valid content dd1", "f2": "this is sample text f1"} ) - await modclient.hset( + await decoded_r.hset( "doc2", mapping={"f1": "some valid content dd2", "f2": "this is sample text f2"} ) assert [ ["f1", "some valid content dd2", "f2", "this is 
sample text f2"] - ] == await modclient.ft().get("doc2") + ] == await decoded_r.ft().get("doc2") assert [ ["f1", "some valid content dd1", "f2", "this is sample text f1"], ["f1", "some valid content dd2", "f2", "this is sample text f2"], - ] == await modclient.ft().get("doc1", "doc2") + ] == await decoded_r.ft().get("doc1", "doc2") @pytest.mark.redismod @pytest.mark.onlynoncluster @skip_ifmodversion_lt("2.2.0", "search") -async def test_config(modclient: redis.Redis): - assert await modclient.ft().config_set("TIMEOUT", "100") +async def test_config(decoded_r: redis.Redis): + assert await decoded_r.ft().config_set("TIMEOUT", "100") with pytest.raises(redis.ResponseError): - await modclient.ft().config_set("TIMEOUT", "null") - res = await modclient.ft().config_get("*") + await decoded_r.ft().config_set("TIMEOUT", "null") + res = await decoded_r.ft().config_get("*") assert "100" == res["TIMEOUT"] - res = await modclient.ft().config_get("TIMEOUT") + res = await decoded_r.ft().config_get("TIMEOUT") assert "100" == res["TIMEOUT"] @pytest.mark.redismod @pytest.mark.onlynoncluster -async def test_aggregations_groupby(modclient: redis.Redis): +async def test_aggregations_groupby(decoded_r: redis.Redis): # Creating the index definition and schema - await modclient.ft().create_index( + await decoded_r.ft().create_index( ( NumericField("random_num"), TextField("title"), @@ -804,7 +1131,7 @@ async def test_aggregations_groupby(modclient: redis.Redis): ) # Indexing a document - await modclient.hset( + await decoded_r.hset( "search", mapping={ "title": "RediSearch", @@ -813,7 +1140,7 @@ async def test_aggregations_groupby(modclient: redis.Redis): "random_num": 10, }, ) - await modclient.hset( + await decoded_r.hset( "ai", mapping={ "title": "RedisAI", @@ -822,7 +1149,7 @@ async def test_aggregations_groupby(modclient: redis.Redis): "random_num": 3, }, ) - await modclient.hset( + await decoded_r.hset( "json", mapping={ "title": "RedisJson", @@ -833,207 +1160,408 @@ async def test_aggregations_groupby(modclient: redis.Redis): ) for dialect in [1, 2]: - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.count()) - .dialect(dialect) - ) + if is_resp2_connection(decoded_r): + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.count()) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "3" + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "3" - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.count_distinct("@title")) - .dialect(dialect) - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.count_distinct("@title")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "3" + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "3" - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.count_distinctish("@title")) - .dialect(dialect) - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.count_distinctish("@title")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "3" + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "3" - req = ( - aggregations.AggregateRequest("redis") - 
.group_by("@parent", reducers.sum("@random_num")) - .dialect(dialect) - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.sum("@random_num")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "21" # 10+8+3 + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "21" # 10+8+3 - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.min("@random_num")) - .dialect(dialect) - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.min("@random_num")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "3" # min(10,8,3) + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "3" # min(10,8,3) - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.max("@random_num")) - .dialect(dialect) - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.max("@random_num")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "10" # max(10,8,3) + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "10" # max(10,8,3) - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.avg("@random_num")) - .dialect(dialect) - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.avg("@random_num")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "7" # (10+3+8)/3 + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "7" # (10+3+8)/3 - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.stddev("random_num")) - .dialect(dialect) - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.stddev("random_num")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "3.60555127546" + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "3.60555127546" - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.quantile("@random_num", 0.5)) - .dialect(dialect) - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.quantile("@random_num", 0.5)) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[3] == "8" # median of 3,8,10 + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[3] == "8" # median of 3,8,10 - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.tolist("@title")) - .dialect(dialect) - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.tolist("@title")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert set(res[3]) == {"RediSearch", "RedisAI", "RedisJson"} + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert set(res[3]) == {"RediSearch", "RedisAI", "RedisJson"} - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", 
reducers.first_value("@title").alias("first")) - .dialect(dialect) - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.first_value("@title").alias("first")) + .dialect(dialect) + ) - res = (await modclient.ft().aggregate(req)).rows[0] - assert res == ["parent", "redis", "first", "RediSearch"] + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res == ["parent", "redis", "first", "RediSearch"] - req = ( - aggregations.AggregateRequest("redis") - .group_by("@parent", reducers.random_sample("@title", 2).alias("random")) - .dialect(dialect) - ) + req = ( + aggregations.AggregateRequest("redis") + .group_by( + "@parent", reducers.random_sample("@title", 2).alias("random") + ) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req)).rows[0] + assert res[1] == "redis" + assert res[2] == "random" + assert len(res[3]) == 2 + assert res[3][0] in ["RediSearch", "RedisAI", "RedisJson"] + else: + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.count()) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliascount"] == "3" - res = (await modclient.ft().aggregate(req)).rows[0] - assert res[1] == "redis" - assert res[2] == "random" - assert len(res[3]) == 2 - assert res[3][0] in ["RediSearch", "RedisAI", "RedisJson"] + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.count_distinct("@title")) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert ( + res["extra_attributes"]["__generated_aliascount_distincttitle"] == "3" + ) + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.count_distinctish("@title")) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert ( + res["extra_attributes"]["__generated_aliascount_distinctishtitle"] + == "3" + ) + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.sum("@random_num")) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliassumrandom_num"] == "21" + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.min("@random_num")) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliasminrandom_num"] == "3" + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.max("@random_num")) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliasmaxrandom_num"] == "10" + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.avg("@random_num")) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliasavgrandom_num"] == "7" + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.stddev("random_num")) + .dialect(dialect) + ) + + res = 
(await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert ( + res["extra_attributes"]["__generated_aliasstddevrandom_num"] + == "3.60555127546" + ) + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.quantile("@random_num", 0.5)) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert ( + res["extra_attributes"]["__generated_aliasquantilerandom_num,0.5"] + == "8" + ) + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.tolist("@title")) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert set(res["extra_attributes"]["__generated_aliastolisttitle"]) == { + "RediSearch", + "RedisAI", + "RedisJson", + } + + req = ( + aggregations.AggregateRequest("redis") + .group_by("@parent", reducers.first_value("@title").alias("first")) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"] == {"parent": "redis", "first": "RediSearch"} + + req = ( + aggregations.AggregateRequest("redis") + .group_by( + "@parent", reducers.random_sample("@title", 2).alias("random") + ) + .dialect(dialect) + ) + + res = (await decoded_r.ft().aggregate(req))["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert "random" in res["extra_attributes"].keys() + assert len(res["extra_attributes"]["random"]) == 2 + assert res["extra_attributes"]["random"][0] in [ + "RediSearch", + "RedisAI", + "RedisJson", + ] @pytest.mark.redismod -async def test_aggregations_sort_by_and_limit(modclient: redis.Redis): - await modclient.ft().create_index((TextField("t1"), TextField("t2"))) +async def test_aggregations_sort_by_and_limit(decoded_r: redis.Redis): + await decoded_r.ft().create_index((TextField("t1"), TextField("t2"))) - await modclient.ft().client.hset("doc1", mapping={"t1": "a", "t2": "b"}) - await modclient.ft().client.hset("doc2", mapping={"t1": "b", "t2": "a"}) + await decoded_r.ft().client.hset("doc1", mapping={"t1": "a", "t2": "b"}) + await decoded_r.ft().client.hset("doc2", mapping={"t1": "b", "t2": "a"}) - # test sort_by using SortDirection - req = aggregations.AggregateRequest("*").sort_by( - aggregations.Asc("@t2"), aggregations.Desc("@t1") - ) - res = await modclient.ft().aggregate(req) - assert res.rows[0] == ["t2", "a", "t1", "b"] - assert res.rows[1] == ["t2", "b", "t1", "a"] + if is_resp2_connection(decoded_r): + # test sort_by using SortDirection + req = aggregations.AggregateRequest("*").sort_by( + aggregations.Asc("@t2"), aggregations.Desc("@t1") + ) + res = await decoded_r.ft().aggregate(req) + assert res.rows[0] == ["t2", "a", "t1", "b"] + assert res.rows[1] == ["t2", "b", "t1", "a"] + + # test sort_by without SortDirection + req = aggregations.AggregateRequest("*").sort_by("@t1") + res = await decoded_r.ft().aggregate(req) + assert res.rows[0] == ["t1", "a"] + assert res.rows[1] == ["t1", "b"] + + # test sort_by with max + req = aggregations.AggregateRequest("*").sort_by("@t1", max=1) + res = await decoded_r.ft().aggregate(req) + assert len(res.rows) == 1 + + # test limit + req = aggregations.AggregateRequest("*").sort_by("@t1").limit(1, 1) + res = await decoded_r.ft().aggregate(req) + assert len(res.rows) == 1 + assert res.rows[0] == ["t1", "b"] + else: + # test sort_by using SortDirection + req = 
aggregations.AggregateRequest("*").sort_by( + aggregations.Asc("@t2"), aggregations.Desc("@t1") + ) + res = (await decoded_r.ft().aggregate(req))["results"] + assert res[0]["extra_attributes"] == {"t2": "a", "t1": "b"} + assert res[1]["extra_attributes"] == {"t2": "b", "t1": "a"} - # test sort_by without SortDirection - req = aggregations.AggregateRequest("*").sort_by("@t1") - res = await modclient.ft().aggregate(req) - assert res.rows[0] == ["t1", "a"] - assert res.rows[1] == ["t1", "b"] + # test sort_by without SortDirection + req = aggregations.AggregateRequest("*").sort_by("@t1") + res = (await decoded_r.ft().aggregate(req))["results"] + assert res[0]["extra_attributes"] == {"t1": "a"} + assert res[1]["extra_attributes"] == {"t1": "b"} - # test sort_by with max - req = aggregations.AggregateRequest("*").sort_by("@t1", max=1) - res = await modclient.ft().aggregate(req) - assert len(res.rows) == 1 + # test sort_by with max + req = aggregations.AggregateRequest("*").sort_by("@t1", max=1) + res = await decoded_r.ft().aggregate(req) + assert len(res["results"]) == 1 - # test limit - req = aggregations.AggregateRequest("*").sort_by("@t1").limit(1, 1) - res = await modclient.ft().aggregate(req) - assert len(res.rows) == 1 - assert res.rows[0] == ["t1", "b"] + # test limit + req = aggregations.AggregateRequest("*").sort_by("@t1").limit(1, 1) + res = await decoded_r.ft().aggregate(req) + assert len(res["results"]) == 1 + assert res["results"][0]["extra_attributes"] == {"t1": "b"} @pytest.mark.redismod @pytest.mark.experimental -async def test_withsuffixtrie(modclient: redis.Redis): +async def test_withsuffixtrie(decoded_r: redis.Redis): # create index - assert await modclient.ft().create_index((TextField("txt"),)) - await waitForIndex(modclient, getattr(modclient.ft(), "index_name", "idx")) - info = await modclient.ft().info() - assert "WITHSUFFIXTRIE" not in info["attributes"][0] - assert await modclient.ft().dropindex("idx") - - # create withsuffixtrie index (text field) - assert await modclient.ft().create_index((TextField("t", withsuffixtrie=True))) - await waitForIndex(modclient, getattr(modclient.ft(), "index_name", "idx")) - info = await modclient.ft().info() - assert "WITHSUFFIXTRIE" in info["attributes"][0] - assert await modclient.ft().dropindex("idx") - - # create withsuffixtrie index (tag field) - assert await modclient.ft().create_index((TagField("t", withsuffixtrie=True))) - await waitForIndex(modclient, getattr(modclient.ft(), "index_name", "idx")) - info = await modclient.ft().info() - assert "WITHSUFFIXTRIE" in info["attributes"][0] + assert await decoded_r.ft().create_index((TextField("txt"),)) + await waitForIndex(decoded_r, getattr(decoded_r.ft(), "index_name", "idx")) + if is_resp2_connection(decoded_r): + info = await decoded_r.ft().info() + assert "WITHSUFFIXTRIE" not in info["attributes"][0] + assert await decoded_r.ft().dropindex("idx") + + # create withsuffixtrie index (text field) + assert await decoded_r.ft().create_index((TextField("t", withsuffixtrie=True))) + await waitForIndex(decoded_r, getattr(decoded_r.ft(), "index_name", "idx")) + info = await decoded_r.ft().info() + assert "WITHSUFFIXTRIE" in info["attributes"][0] + assert await decoded_r.ft().dropindex("idx") + + # create withsuffixtrie index (tag field) + assert await decoded_r.ft().create_index((TagField("t", withsuffixtrie=True))) + await waitForIndex(decoded_r, getattr(decoded_r.ft(), "index_name", "idx")) + info = await decoded_r.ft().info() + assert "WITHSUFFIXTRIE" in info["attributes"][0] + else: + 
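# In RESP3 mode FT.INFO reports each attribute as a dict and collects option + # markers such as SORTABLE, NOSTEM and WITHSUFFIXTRIE in its "flags" list, + # which is why this branch checks info["attributes"][0]["flags"] below. +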
info = await decoded_r.ft().info() + assert "WITHSUFFIXTRIE" not in info["attributes"][0]["flags"] + assert await decoded_r.ft().dropindex("idx") + + # create withsuffixtrie index (text fields) + assert await decoded_r.ft().create_index((TextField("t", withsuffixtrie=True))) + await waitForIndex(decoded_r, getattr(decoded_r.ft(), "index_name", "idx")) + info = await decoded_r.ft().info() + assert "WITHSUFFIXTRIE" in info["attributes"][0]["flags"] + assert await decoded_r.ft().dropindex("idx") + + # create withsuffixtrie index (tag field) + assert await decoded_r.ft().create_index((TagField("t", withsuffixtrie=True))) + await waitForIndex(decoded_r, getattr(decoded_r.ft(), "index_name", "idx")) + info = await decoded_r.ft().info() + assert "WITHSUFFIXTRIE" in info["attributes"][0]["flags"] @pytest.mark.redismod @skip_if_redis_enterprise() -async def test_search_commands_in_pipeline(modclient: redis.Redis): - p = await modclient.ft().pipeline() +async def test_search_commands_in_pipeline(decoded_r: redis.Redis): + p = await decoded_r.ft().pipeline() p.create_index((TextField("txt"),)) p.hset("doc1", mapping={"txt": "foo bar"}) p.hset("doc2", mapping={"txt": "foo bar"}) q = Query("foo bar").with_payloads() await p.search(q) res = await p.execute() - assert res[:3] == ["OK", True, True] - assert 2 == res[3][0] - assert "doc1" == res[3][1] - assert "doc2" == res[3][4] - assert res[3][5] is None - assert res[3][3] == res[3][6] == ["txt", "foo bar"] + if is_resp2_connection(decoded_r): + assert res[:3] == ["OK", True, True] + assert 2 == res[3][0] + assert "doc1" == res[3][1] + assert "doc2" == res[3][4] + assert res[3][5] is None + assert res[3][3] == res[3][6] == ["txt", "foo bar"] + else: + assert res[:3] == ["OK", True, True] + assert 2 == res[3]["total_results"] + assert "doc1" == res[3]["results"][0]["id"] + assert "doc2" == res[3]["results"][1]["id"] + assert res[3]["results"][0]["payload"] is None + assert ( + res[3]["results"][0]["extra_attributes"] + == res[3]["results"][1]["extra_attributes"] + == {"txt": "foo bar"} + ) @pytest.mark.redismod -async def test_query_timeout(modclient: redis.Redis): +async def test_query_timeout(decoded_r: redis.Redis): q1 = Query("foo").timeout(5000) assert q1.get_args() == ["foo", "TIMEOUT", 5000, "LIMIT", 0, 10] q2 = Query("foo").timeout("not_a_number") with pytest.raises(redis.ResponseError): - await modclient.ft().search(q2) + await decoded_r.ft().search(q2) diff --git a/tests/test_asyncio/test_sentinel.py b/tests/test_asyncio/test_sentinel.py index 5a0533ba05..2091f2cb87 100644 --- a/tests/test_asyncio/test_sentinel.py +++ b/tests/test_asyncio/test_sentinel.py @@ -2,7 +2,6 @@ import pytest import pytest_asyncio - import redis.asyncio.sentinel from redis import exceptions from redis.asyncio.sentinel import ( @@ -15,7 +14,7 @@ @pytest_asyncio.fixture(scope="module") def master_ip(master_host): - yield socket.gethostbyname(master_host) + yield socket.gethostbyname(master_host[0]) class SentinelTestClient: diff --git a/tests/test_asyncio/test_sentinel_managed_connection.py b/tests/test_asyncio/test_sentinel_managed_connection.py index a6e9f37a63..e784690c77 100644 --- a/tests/test_asyncio/test_sentinel_managed_connection.py +++ b/tests/test_asyncio/test_sentinel_managed_connection.py @@ -1,7 +1,6 @@ import socket import pytest - from redis.asyncio.retry import Retry from redis.asyncio.sentinel import SentinelManagedConnection from redis.backoff import NoBackoff diff --git a/tests/test_asyncio/test_timeseries.py b/tests/test_asyncio/test_timeseries.py index 
a7109938f2..48ffdfd889 100644 --- a/tests/test_asyncio/test_timeseries.py +++ b/tests/test_asyncio/test_timeseries.py @@ -2,230 +2,253 @@ from time import sleep import pytest - import redis.asyncio as redis -from tests.conftest import skip_ifmodversion_lt +from tests.conftest import ( + assert_resp_response, + is_resp2_connection, + skip_ifmodversion_lt, +) @pytest.mark.redismod -async def test_create(modclient: redis.Redis): - assert await modclient.ts().create(1) - assert await modclient.ts().create(2, retention_msecs=5) - assert await modclient.ts().create(3, labels={"Redis": "Labs"}) - assert await modclient.ts().create(4, retention_msecs=20, labels={"Time": "Series"}) - info = await modclient.ts().info(4) - assert 20 == info.retention_msecs - assert "Series" == info.labels["Time"] +async def test_create(decoded_r: redis.Redis): + assert await decoded_r.ts().create(1) + assert await decoded_r.ts().create(2, retention_msecs=5) + assert await decoded_r.ts().create(3, labels={"Redis": "Labs"}) + assert await decoded_r.ts().create(4, retention_msecs=20, labels={"Time": "Series"}) + info = await decoded_r.ts().info(4) + assert_resp_response( + decoded_r, 20, info.get("retention_msecs"), info.get("retentionTime") + ) + assert "Series" == info["labels"]["Time"] # Test for a chunk size of 128 Bytes - assert await modclient.ts().create("time-serie-1", chunk_size=128) - info = await modclient.ts().info("time-serie-1") - assert 128, info.chunk_size + assert await decoded_r.ts().create("time-serie-1", chunk_size=128) + info = await decoded_r.ts().info("time-serie-1") + assert_resp_response(decoded_r, 128, info.get("chunk_size"), info.get("chunkSize")) @pytest.mark.redismod @skip_ifmodversion_lt("1.4.0", "timeseries") -async def test_create_duplicate_policy(modclient: redis.Redis): +async def test_create_duplicate_policy(decoded_r: redis.Redis): # Test for duplicate policy for duplicate_policy in ["block", "last", "first", "min", "max"]: ts_name = f"time-serie-ooo-{duplicate_policy}" - assert await modclient.ts().create(ts_name, duplicate_policy=duplicate_policy) - info = await modclient.ts().info(ts_name) - assert duplicate_policy == info.duplicate_policy + assert await decoded_r.ts().create(ts_name, duplicate_policy=duplicate_policy) + info = await decoded_r.ts().info(ts_name) + assert_resp_response( + decoded_r, + duplicate_policy, + info.get("duplicate_policy"), + info.get("duplicatePolicy"), + ) @pytest.mark.redismod -async def test_alter(modclient: redis.Redis): - assert await modclient.ts().create(1) - res = await modclient.ts().info(1) - assert 0 == res.retention_msecs - assert await modclient.ts().alter(1, retention_msecs=10) - res = await modclient.ts().info(1) - assert {} == res.labels - res = await modclient.ts().info(1) - assert 10 == res.retention_msecs - assert await modclient.ts().alter(1, labels={"Time": "Series"}) - res = await modclient.ts().info(1) - assert "Series" == res.labels["Time"] - res = await modclient.ts().info(1) - assert 10 == res.retention_msecs +async def test_alter(decoded_r: redis.Redis): + assert await decoded_r.ts().create(1) + res = await decoded_r.ts().info(1) + assert_resp_response( + decoded_r, 0, res.get("retention_msecs"), res.get("retentionTime") + ) + assert await decoded_r.ts().alter(1, retention_msecs=10) + res = await decoded_r.ts().info(1) + assert {} == (await decoded_r.ts().info(1))["labels"] + info = await decoded_r.ts().info(1) + assert_resp_response( + decoded_r, 10, info.get("retention_msecs"), info.get("retentionTime") + ) + assert await 
decoded_r.ts().alter(1, labels={"Time": "Series"}) + res = await decoded_r.ts().info(1) + assert "Series" == (await decoded_r.ts().info(1))["labels"]["Time"] + info = await decoded_r.ts().info(1) + assert_resp_response( + decoded_r, 10, info.get("retention_msecs"), info.get("retentionTime") + ) @pytest.mark.redismod @skip_ifmodversion_lt("1.4.0", "timeseries") -async def test_alter_diplicate_policy(modclient: redis.Redis): - assert await modclient.ts().create(1) - info = await modclient.ts().info(1) - assert info.duplicate_policy is None - assert await modclient.ts().alter(1, duplicate_policy="min") - info = await modclient.ts().info(1) - assert "min" == info.duplicate_policy +async def test_alter_diplicate_policy(decoded_r: redis.Redis): + assert await decoded_r.ts().create(1) + info = await decoded_r.ts().info(1) + assert_resp_response( + decoded_r, None, info.get("duplicate_policy"), info.get("duplicatePolicy") + ) + assert await decoded_r.ts().alter(1, duplicate_policy="min") + info = await decoded_r.ts().info(1) + assert_resp_response( + decoded_r, "min", info.get("duplicate_policy"), info.get("duplicatePolicy") + ) @pytest.mark.redismod -async def test_add(modclient: redis.Redis): - assert 1 == await modclient.ts().add(1, 1, 1) - assert 2 == await modclient.ts().add(2, 2, 3, retention_msecs=10) - assert 3 == await modclient.ts().add(3, 3, 2, labels={"Redis": "Labs"}) - assert 4 == await modclient.ts().add( +async def test_add(decoded_r: redis.Redis): + assert 1 == await decoded_r.ts().add(1, 1, 1) + assert 2 == await decoded_r.ts().add(2, 2, 3, retention_msecs=10) + assert 3 == await decoded_r.ts().add(3, 3, 2, labels={"Redis": "Labs"}) + assert 4 == await decoded_r.ts().add( 4, 4, 2, retention_msecs=10, labels={"Redis": "Labs", "Time": "Series"} ) - res = await modclient.ts().add(5, "*", 1) + res = await decoded_r.ts().add(5, "*", 1) assert abs(time.time() - round(float(res) / 1000)) < 1.0 - info = await modclient.ts().info(4) - assert 10 == info.retention_msecs - assert "Labs" == info.labels["Redis"] + info = await decoded_r.ts().info(4) + assert_resp_response( + decoded_r, 10, info.get("retention_msecs"), info.get("retentionTime") + ) + assert "Labs" == info["labels"]["Redis"] # Test for a chunk size of 128 Bytes on TS.ADD - assert await modclient.ts().add("time-serie-1", 1, 10.0, chunk_size=128) - info = await modclient.ts().info("time-serie-1") - assert 128 == info.chunk_size + assert await decoded_r.ts().add("time-serie-1", 1, 10.0, chunk_size=128) + info = await decoded_r.ts().info("time-serie-1") + assert_resp_response(decoded_r, 128, info.get("chunk_size"), info.get("chunkSize")) @pytest.mark.redismod @skip_ifmodversion_lt("1.4.0", "timeseries") -async def test_add_duplicate_policy(modclient: redis.Redis): +async def test_add_duplicate_policy(r: redis.Redis): # Test for duplicate policy BLOCK - assert 1 == await modclient.ts().add("time-serie-add-ooo-block", 1, 5.0) + assert 1 == await r.ts().add("time-serie-add-ooo-block", 1, 5.0) with pytest.raises(Exception): - await modclient.ts().add( - "time-serie-add-ooo-block", 1, 5.0, duplicate_policy="block" - ) + await r.ts().add("time-serie-add-ooo-block", 1, 5.0, duplicate_policy="block") # Test for duplicate policy LAST - assert 1 == await modclient.ts().add("time-serie-add-ooo-last", 1, 5.0) - assert 1 == await modclient.ts().add( + assert 1 == await r.ts().add("time-serie-add-ooo-last", 1, 5.0) + assert 1 == await r.ts().add( "time-serie-add-ooo-last", 1, 10.0, duplicate_policy="last" ) - res = await 
modclient.ts().get("time-serie-add-ooo-last") + res = await r.ts().get("time-serie-add-ooo-last") assert 10.0 == res[1] # Test for duplicate policy FIRST - assert 1 == await modclient.ts().add("time-serie-add-ooo-first", 1, 5.0) - assert 1 == await modclient.ts().add( + assert 1 == await r.ts().add("time-serie-add-ooo-first", 1, 5.0) + assert 1 == await r.ts().add( "time-serie-add-ooo-first", 1, 10.0, duplicate_policy="first" ) - res = await modclient.ts().get("time-serie-add-ooo-first") + res = await r.ts().get("time-serie-add-ooo-first") assert 5.0 == res[1] # Test for duplicate policy MAX - assert 1 == await modclient.ts().add("time-serie-add-ooo-max", 1, 5.0) - assert 1 == await modclient.ts().add( + assert 1 == await r.ts().add("time-serie-add-ooo-max", 1, 5.0) + assert 1 == await r.ts().add( "time-serie-add-ooo-max", 1, 10.0, duplicate_policy="max" ) - res = await modclient.ts().get("time-serie-add-ooo-max") + res = await r.ts().get("time-serie-add-ooo-max") assert 10.0 == res[1] # Test for duplicate policy MIN - assert 1 == await modclient.ts().add("time-serie-add-ooo-min", 1, 5.0) - assert 1 == await modclient.ts().add( + assert 1 == await r.ts().add("time-serie-add-ooo-min", 1, 5.0) + assert 1 == await r.ts().add( "time-serie-add-ooo-min", 1, 10.0, duplicate_policy="min" ) - res = await modclient.ts().get("time-serie-add-ooo-min") + res = await r.ts().get("time-serie-add-ooo-min") assert 5.0 == res[1] @pytest.mark.redismod -async def test_madd(modclient: redis.Redis): - await modclient.ts().create("a") - assert [1, 2, 3] == await modclient.ts().madd( +async def test_madd(decoded_r: redis.Redis): + await decoded_r.ts().create("a") + assert [1, 2, 3] == await decoded_r.ts().madd( [("a", 1, 5), ("a", 2, 10), ("a", 3, 15)] ) @pytest.mark.redismod -async def test_incrby_decrby(modclient: redis.Redis): +async def test_incrby_decrby(decoded_r: redis.Redis): for _ in range(100): - assert await modclient.ts().incrby(1, 1) + assert await decoded_r.ts().incrby(1, 1) sleep(0.001) - assert 100 == (await modclient.ts().get(1))[1] + assert 100 == (await decoded_r.ts().get(1))[1] for _ in range(100): - assert await modclient.ts().decrby(1, 1) + assert await decoded_r.ts().decrby(1, 1) sleep(0.001) - assert 0 == (await modclient.ts().get(1))[1] + assert 0 == (await decoded_r.ts().get(1))[1] - assert await modclient.ts().incrby(2, 1.5, timestamp=5) - assert (5, 1.5) == await modclient.ts().get(2) - assert await modclient.ts().incrby(2, 2.25, timestamp=7) - assert (7, 3.75) == await modclient.ts().get(2) - assert await modclient.ts().decrby(2, 1.5, timestamp=15) - assert (15, 2.25) == await modclient.ts().get(2) + assert await decoded_r.ts().incrby(2, 1.5, timestamp=5) + assert_resp_response(decoded_r, await decoded_r.ts().get(2), (5, 1.5), [5, 1.5]) + assert await decoded_r.ts().incrby(2, 2.25, timestamp=7) + assert_resp_response(decoded_r, await decoded_r.ts().get(2), (7, 3.75), [7, 3.75]) + assert await decoded_r.ts().decrby(2, 1.5, timestamp=15) + assert_resp_response(decoded_r, await decoded_r.ts().get(2), (15, 2.25), [15, 2.25]) # Test for a chunk size of 128 Bytes on TS.INCRBY - assert await modclient.ts().incrby("time-serie-1", 10, chunk_size=128) - info = await modclient.ts().info("time-serie-1") - assert 128 == info.chunk_size + assert await decoded_r.ts().incrby("time-serie-1", 10, chunk_size=128) + info = await decoded_r.ts().info("time-serie-1") + assert_resp_response(decoded_r, 128, info.get("chunk_size"), info.get("chunkSize")) # Test for a chunk size of 128 Bytes on TS.DECRBY - 
assert await modclient.ts().decrby("time-serie-2", 10, chunk_size=128) - info = await modclient.ts().info("time-serie-2") - assert 128 == info.chunk_size + assert await decoded_r.ts().decrby("time-serie-2", 10, chunk_size=128) + info = await decoded_r.ts().info("time-serie-2") + assert_resp_response(decoded_r, 128, info.get("chunk_size"), info.get("chunkSize")) @pytest.mark.redismod -async def test_create_and_delete_rule(modclient: redis.Redis): +async def test_create_and_delete_rule(decoded_r: redis.Redis): # test rule creation time = 100 - await modclient.ts().create(1) - await modclient.ts().create(2) - await modclient.ts().createrule(1, 2, "avg", 100) + await decoded_r.ts().create(1) + await decoded_r.ts().create(2) + await decoded_r.ts().createrule(1, 2, "avg", 100) for i in range(50): - await modclient.ts().add(1, time + i * 2, 1) - await modclient.ts().add(1, time + i * 2 + 1, 2) - await modclient.ts().add(1, time * 2, 1.5) - assert round((await modclient.ts().get(2))[1], 5) == 1.5 - info = await modclient.ts().info(1) - assert info.rules[0][1] == 100 + await decoded_r.ts().add(1, time + i * 2, 1) + await decoded_r.ts().add(1, time + i * 2 + 1, 2) + await decoded_r.ts().add(1, time * 2, 1.5) + assert round((await decoded_r.ts().get(2))[1], 5) == 1.5 + info = await decoded_r.ts().info(1) + if is_resp2_connection(decoded_r): + assert info.rules[0][1] == 100 + else: + assert info["rules"]["2"][0] == 100 # test rule deletion - await modclient.ts().deleterule(1, 2) - info = await modclient.ts().info(1) - assert not info.rules + await decoded_r.ts().deleterule(1, 2) + info = await decoded_r.ts().info(1) + assert not info["rules"] @pytest.mark.redismod @skip_ifmodversion_lt("99.99.99", "timeseries") -async def test_del_range(modclient: redis.Redis): +async def test_del_range(decoded_r: redis.Redis): try: - await modclient.ts().delete("test", 0, 100) + await decoded_r.ts().delete("test", 0, 100) except Exception as e: assert e.__str__() != "" for i in range(100): - await modclient.ts().add(1, i, i % 7) - assert 22 == await modclient.ts().delete(1, 0, 21) - assert [] == await modclient.ts().range(1, 0, 21) - assert [(22, 1.0)] == await modclient.ts().range(1, 22, 22) + await decoded_r.ts().add(1, i, i % 7) + assert 22 == await decoded_r.ts().delete(1, 0, 21) + assert [] == await decoded_r.ts().range(1, 0, 21) + assert_resp_response( + decoded_r, await decoded_r.ts().range(1, 22, 22), [(22, 1.0)], [[22, 1.0]] + ) @pytest.mark.redismod -async def test_range(modclient: redis.Redis): +async def test_range(r: redis.Redis): for i in range(100): - await modclient.ts().add(1, i, i % 7) - assert 100 == len(await modclient.ts().range(1, 0, 200)) + await r.ts().add(1, i, i % 7) + assert 100 == len(await r.ts().range(1, 0, 200)) for i in range(100): - await modclient.ts().add(1, i + 200, i % 7) - assert 200 == len(await modclient.ts().range(1, 0, 500)) + await r.ts().add(1, i + 200, i % 7) + assert 200 == len(await r.ts().range(1, 0, 500)) # last sample isn't returned assert 20 == len( - await modclient.ts().range( - 1, 0, 500, aggregation_type="avg", bucket_size_msec=10 - ) + await r.ts().range(1, 0, 500, aggregation_type="avg", bucket_size_msec=10) ) - assert 10 == len(await modclient.ts().range(1, 0, 500, count=10)) + assert 10 == len(await r.ts().range(1, 0, 500, count=10)) @pytest.mark.redismod @skip_ifmodversion_lt("99.99.99", "timeseries") -async def test_range_advanced(modclient: redis.Redis): +async def test_range_advanced(decoded_r: redis.Redis): for i in range(100): - await 
modclient.ts().add(1, i, i % 7) - await modclient.ts().add(1, i + 200, i % 7) + await decoded_r.ts().add(1, i, i % 7) + await decoded_r.ts().add(1, i + 200, i % 7) assert 2 == len( - await modclient.ts().range( + await decoded_r.ts().range( 1, 0, 500, @@ -234,35 +257,38 @@ async def test_range_advanced(modclient: redis.Redis): filter_by_max_value=2, ) ) - assert [(0, 10.0), (10, 1.0)] == await modclient.ts().range( + res = await decoded_r.ts().range( 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align="+" ) - assert [(0, 5.0), (5, 6.0)] == await modclient.ts().range( + assert_resp_response(decoded_r, res, [(0, 10.0), (10, 1.0)], [[0, 10.0], [10, 1.0]]) + res = await decoded_r.ts().range( 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align=5 ) - assert [(0, 2.55), (10, 3.0)] == await modclient.ts().range( + assert_resp_response(decoded_r, res, [(0, 5.0), (5, 6.0)], [[0, 5.0], [5, 6.0]]) + res = await decoded_r.ts().range( 1, 0, 10, aggregation_type="twa", bucket_size_msec=10 ) + assert_resp_response(decoded_r, res, [(0, 2.55), (10, 3.0)], [[0, 2.55], [10, 3.0]]) @pytest.mark.redismod @skip_ifmodversion_lt("99.99.99", "timeseries") -async def test_rev_range(modclient: redis.Redis): +async def test_rev_range(decoded_r: redis.Redis): for i in range(100): - await modclient.ts().add(1, i, i % 7) - assert 100 == len(await modclient.ts().range(1, 0, 200)) + await decoded_r.ts().add(1, i, i % 7) + assert 100 == len(await decoded_r.ts().range(1, 0, 200)) for i in range(100): - await modclient.ts().add(1, i + 200, i % 7) - assert 200 == len(await modclient.ts().range(1, 0, 500)) + await decoded_r.ts().add(1, i + 200, i % 7) + assert 200 == len(await decoded_r.ts().range(1, 0, 500)) # first sample isn't returned assert 20 == len( - await modclient.ts().revrange( + await decoded_r.ts().revrange( 1, 0, 500, aggregation_type="avg", bucket_size_msec=10 ) ) - assert 10 == len(await modclient.ts().revrange(1, 0, 500, count=10)) + assert 10 == len(await decoded_r.ts().revrange(1, 0, 500, count=10)) assert 2 == len( - await modclient.ts().revrange( + await decoded_r.ts().revrange( 1, 0, 500, @@ -271,287 +297,471 @@ async def test_rev_range(modclient: redis.Redis): filter_by_max_value=2, ) ) - assert [(10, 1.0), (0, 10.0)] == await modclient.ts().revrange( - 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align="+" + assert_resp_response( + decoded_r, + await decoded_r.ts().revrange( + 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align="+" + ), + [(10, 1.0), (0, 10.0)], + [[10, 1.0], [0, 10.0]], ) - assert [(1, 10.0), (0, 1.0)] == await modclient.ts().revrange( - 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align=1 + assert_resp_response( + decoded_r, + await decoded_r.ts().revrange( + 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align=1 + ), + [(1, 10.0), (0, 1.0)], + [[1, 10.0], [0, 1.0]], ) @pytest.mark.redismod @pytest.mark.onlynoncluster -async def testMultiRange(modclient: redis.Redis): - await modclient.ts().create(1, labels={"Test": "This", "team": "ny"}) - await modclient.ts().create( +async def test_multi_range(decoded_r: redis.Redis): + await decoded_r.ts().create(1, labels={"Test": "This", "team": "ny"}) + await decoded_r.ts().create( 2, labels={"Test": "This", "Taste": "That", "team": "sf"} ) for i in range(100): - await modclient.ts().add(1, i, i % 7) - await modclient.ts().add(2, i, i % 11) + await decoded_r.ts().add(1, i, i % 7) + await decoded_r.ts().add(2, i, i % 11) - res = await modclient.ts().mrange(0, 200, 
filters=["Test=This"]) + res = await decoded_r.ts().mrange(0, 200, filters=["Test=This"]) assert 2 == len(res) - assert 100 == len(res[0]["1"][1]) + if is_resp2_connection(decoded_r): + assert 100 == len(res[0]["1"][1]) - res = await modclient.ts().mrange(0, 200, filters=["Test=This"], count=10) - assert 10 == len(res[0]["1"][1]) + res = await decoded_r.ts().mrange(0, 200, filters=["Test=This"], count=10) + assert 10 == len(res[0]["1"][1]) - for i in range(100): - await modclient.ts().add(1, i + 200, i % 7) - res = await modclient.ts().mrange( - 0, 500, filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 - ) - assert 2 == len(res) - assert 20 == len(res[0]["1"][1]) + for i in range(100): + await decoded_r.ts().add(1, i + 200, i % 7) + res = await decoded_r.ts().mrange( + 0, 500, filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 + ) + assert 2 == len(res) + assert 20 == len(res[0]["1"][1]) + + # test withlabels + assert {} == res[0]["1"][0] + res = await decoded_r.ts().mrange( + 0, 200, filters=["Test=This"], with_labels=True + ) + assert {"Test": "This", "team": "ny"} == res[0]["1"][0] + else: + assert 100 == len(res["1"][2]) + + res = await decoded_r.ts().mrange(0, 200, filters=["Test=This"], count=10) + assert 10 == len(res["1"][2]) - # test withlabels - assert {} == res[0]["1"][0] - res = await modclient.ts().mrange(0, 200, filters=["Test=This"], with_labels=True) - assert {"Test": "This", "team": "ny"} == res[0]["1"][0] + for i in range(100): + await decoded_r.ts().add(1, i + 200, i % 7) + res = await decoded_r.ts().mrange( + 0, 500, filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 + ) + assert 2 == len(res) + assert 20 == len(res["1"][2]) + + # test withlabels + assert {} == res["1"][0] + res = await decoded_r.ts().mrange( + 0, 200, filters=["Test=This"], with_labels=True + ) + assert {"Test": "This", "team": "ny"} == res["1"][0] @pytest.mark.redismod @pytest.mark.onlynoncluster @skip_ifmodversion_lt("99.99.99", "timeseries") -async def test_multi_range_advanced(modclient: redis.Redis): - await modclient.ts().create(1, labels={"Test": "This", "team": "ny"}) - await modclient.ts().create( +async def test_multi_range_advanced(decoded_r: redis.Redis): + await decoded_r.ts().create(1, labels={"Test": "This", "team": "ny"}) + await decoded_r.ts().create( 2, labels={"Test": "This", "Taste": "That", "team": "sf"} ) for i in range(100): - await modclient.ts().add(1, i, i % 7) - await modclient.ts().add(2, i, i % 11) + await decoded_r.ts().add(1, i, i % 7) + await decoded_r.ts().add(2, i, i % 11) # test with selected labels - res = await modclient.ts().mrange( + res = await decoded_r.ts().mrange( 0, 200, filters=["Test=This"], select_labels=["team"] ) - assert {"team": "ny"} == res[0]["1"][0] - assert {"team": "sf"} == res[1]["2"][0] + if is_resp2_connection(decoded_r): + assert {"team": "ny"} == res[0]["1"][0] + assert {"team": "sf"} == res[1]["2"][0] - # test with filterby - res = await modclient.ts().mrange( - 0, - 200, - filters=["Test=This"], - filter_by_ts=[i for i in range(10, 20)], - filter_by_min_value=1, - filter_by_max_value=2, - ) - assert [(15, 1.0), (16, 2.0)] == res[0]["1"][1] + # test with filterby + res = await decoded_r.ts().mrange( + 0, + 200, + filters=["Test=This"], + filter_by_ts=[i for i in range(10, 20)], + filter_by_min_value=1, + filter_by_max_value=2, + ) + assert [(15, 1.0), (16, 2.0)] == res[0]["1"][1] - # test groupby - res = await modclient.ts().mrange( - 0, 3, filters=["Test=This"], groupby="Test", reduce="sum" - 
) - assert [(0, 0.0), (1, 2.0), (2, 4.0), (3, 6.0)] == res[0]["Test=This"][1] - res = await modclient.ts().mrange( - 0, 3, filters=["Test=This"], groupby="Test", reduce="max" - ) - assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["Test=This"][1] - res = await modclient.ts().mrange( - 0, 3, filters=["Test=This"], groupby="team", reduce="min" - ) - assert 2 == len(res) - assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["team=ny"][1] - assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[1]["team=sf"][1] - - # test align - res = await modclient.ts().mrange( - 0, - 10, - filters=["team=ny"], - aggregation_type="count", - bucket_size_msec=10, - align="-", - ) - assert [(0, 10.0), (10, 1.0)] == res[0]["1"][1] - res = await modclient.ts().mrange( - 0, - 10, - filters=["team=ny"], - aggregation_type="count", - bucket_size_msec=10, - align=5, - ) - assert [(0, 5.0), (5, 6.0)] == res[0]["1"][1] + # test groupby + res = await decoded_r.ts().mrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="sum" + ) + assert [(0, 0.0), (1, 2.0), (2, 4.0), (3, 6.0)] == res[0]["Test=This"][1] + res = await decoded_r.ts().mrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="max" + ) + assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["Test=This"][1] + res = await decoded_r.ts().mrange( + 0, 3, filters=["Test=This"], groupby="team", reduce="min" + ) + assert 2 == len(res) + assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["team=ny"][1] + assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[1]["team=sf"][1] + + # test align + res = await decoded_r.ts().mrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align="-", + ) + assert [(0, 10.0), (10, 1.0)] == res[0]["1"][1] + res = await decoded_r.ts().mrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align=5, + ) + assert [(0, 5.0), (5, 6.0)] == res[0]["1"][1] + else: + assert {"team": "ny"} == res["1"][0] + assert {"team": "sf"} == res["2"][0] + + # test with filterby + res = await decoded_r.ts().mrange( + 0, + 200, + filters=["Test=This"], + filter_by_ts=[i for i in range(10, 20)], + filter_by_min_value=1, + filter_by_max_value=2, + ) + assert [[15, 1.0], [16, 2.0]] == res["1"][2] + + # test groupby + res = await decoded_r.ts().mrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="sum" + ) + assert [[0, 0.0], [1, 2.0], [2, 4.0], [3, 6.0]] == res["Test=This"][3] + res = await decoded_r.ts().mrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="max" + ) + assert [[0, 0.0], [1, 1.0], [2, 2.0], [3, 3.0]] == res["Test=This"][3] + res = await decoded_r.ts().mrange( + 0, 3, filters=["Test=This"], groupby="team", reduce="min" + ) + assert 2 == len(res) + assert [[0, 0.0], [1, 1.0], [2, 2.0], [3, 3.0]] == res["team=ny"][3] + assert [[0, 0.0], [1, 1.0], [2, 2.0], [3, 3.0]] == res["team=sf"][3] + + # test align + res = await decoded_r.ts().mrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align="-", + ) + assert [[0, 10.0], [10, 1.0]] == res["1"][2] + res = await decoded_r.ts().mrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align=5, + ) + assert [[0, 5.0], [5, 6.0]] == res["1"][2] @pytest.mark.redismod @pytest.mark.onlynoncluster @skip_ifmodversion_lt("99.99.99", "timeseries") -async def test_multi_reverse_range(modclient: redis.Redis): - await modclient.ts().create(1, labels={"Test": "This", "team": "ny"}) - await 
modclient.ts().create( +async def test_multi_reverse_range(decoded_r: redis.Redis): + await decoded_r.ts().create(1, labels={"Test": "This", "team": "ny"}) + await decoded_r.ts().create( 2, labels={"Test": "This", "Taste": "That", "team": "sf"} ) for i in range(100): - await modclient.ts().add(1, i, i % 7) - await modclient.ts().add(2, i, i % 11) + await decoded_r.ts().add(1, i, i % 7) + await decoded_r.ts().add(2, i, i % 11) - res = await modclient.ts().mrange(0, 200, filters=["Test=This"]) + res = await decoded_r.ts().mrange(0, 200, filters=["Test=This"]) assert 2 == len(res) - assert 100 == len(res[0]["1"][1]) + if is_resp2_connection(decoded_r): + assert 100 == len(res[0]["1"][1]) - res = await modclient.ts().mrange(0, 200, filters=["Test=This"], count=10) - assert 10 == len(res[0]["1"][1]) + res = await decoded_r.ts().mrange(0, 200, filters=["Test=This"], count=10) + assert 10 == len(res[0]["1"][1]) - for i in range(100): - await modclient.ts().add(1, i + 200, i % 7) - res = await modclient.ts().mrevrange( - 0, 500, filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 - ) - assert 2 == len(res) - assert 20 == len(res[0]["1"][1]) - assert {} == res[0]["1"][0] + for i in range(100): + await decoded_r.ts().add(1, i + 200, i % 7) + res = await decoded_r.ts().mrevrange( + 0, 500, filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 + ) + assert 2 == len(res) + assert 20 == len(res[0]["1"][1]) + assert {} == res[0]["1"][0] - # test withlabels - res = await modclient.ts().mrevrange( - 0, 200, filters=["Test=This"], with_labels=True - ) - assert {"Test": "This", "team": "ny"} == res[0]["1"][0] + # test withlabels + res = await decoded_r.ts().mrevrange( + 0, 200, filters=["Test=This"], with_labels=True + ) + assert {"Test": "This", "team": "ny"} == res[0]["1"][0] - # test with selected labels - res = await modclient.ts().mrevrange( - 0, 200, filters=["Test=This"], select_labels=["team"] - ) - assert {"team": "ny"} == res[0]["1"][0] - assert {"team": "sf"} == res[1]["2"][0] + # test with selected labels + res = await decoded_r.ts().mrevrange( + 0, 200, filters=["Test=This"], select_labels=["team"] + ) + assert {"team": "ny"} == res[0]["1"][0] + assert {"team": "sf"} == res[1]["2"][0] - # test filterby - res = await modclient.ts().mrevrange( - 0, - 200, - filters=["Test=This"], - filter_by_ts=[i for i in range(10, 20)], - filter_by_min_value=1, - filter_by_max_value=2, - ) - assert [(16, 2.0), (15, 1.0)] == res[0]["1"][1] + # test filterby + res = await decoded_r.ts().mrevrange( + 0, + 200, + filters=["Test=This"], + filter_by_ts=[i for i in range(10, 20)], + filter_by_min_value=1, + filter_by_max_value=2, + ) + assert [(16, 2.0), (15, 1.0)] == res[0]["1"][1] - # test groupby - res = await modclient.ts().mrevrange( - 0, 3, filters=["Test=This"], groupby="Test", reduce="sum" - ) - assert [(3, 6.0), (2, 4.0), (1, 2.0), (0, 0.0)] == res[0]["Test=This"][1] - res = await modclient.ts().mrevrange( - 0, 3, filters=["Test=This"], groupby="Test", reduce="max" - ) - assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[0]["Test=This"][1] - res = await modclient.ts().mrevrange( - 0, 3, filters=["Test=This"], groupby="team", reduce="min" - ) - assert 2 == len(res) - assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[0]["team=ny"][1] - assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[1]["team=sf"][1] - - # test align - res = await modclient.ts().mrevrange( - 0, - 10, - filters=["team=ny"], - aggregation_type="count", - bucket_size_msec=10, - align="-", - ) - assert 
[(10, 1.0), (0, 10.0)] == res[0]["1"][1] - res = await modclient.ts().mrevrange( - 0, - 10, - filters=["team=ny"], - aggregation_type="count", - bucket_size_msec=10, - align=1, - ) - assert [(1, 10.0), (0, 1.0)] == res[0]["1"][1] + # test groupby + res = await decoded_r.ts().mrevrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="sum" + ) + assert [(3, 6.0), (2, 4.0), (1, 2.0), (0, 0.0)] == res[0]["Test=This"][1] + res = await decoded_r.ts().mrevrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="max" + ) + assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[0]["Test=This"][1] + res = await decoded_r.ts().mrevrange( + 0, 3, filters=["Test=This"], groupby="team", reduce="min" + ) + assert 2 == len(res) + assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[0]["team=ny"][1] + assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[1]["team=sf"][1] + + # test align + res = await decoded_r.ts().mrevrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align="-", + ) + assert [(10, 1.0), (0, 10.0)] == res[0]["1"][1] + res = await decoded_r.ts().mrevrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align=1, + ) + assert [(1, 10.0), (0, 1.0)] == res[0]["1"][1] + else: + assert 100 == len(res["1"][2]) + + res = await decoded_r.ts().mrange(0, 200, filters=["Test=This"], count=10) + assert 10 == len(res["1"][2]) + + for i in range(100): + await decoded_r.ts().add(1, i + 200, i % 7) + res = await decoded_r.ts().mrevrange( + 0, 500, filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 + ) + assert 2 == len(res) + assert 20 == len(res["1"][2]) + assert {} == res["1"][0] + + # test withlabels + res = await decoded_r.ts().mrevrange( + 0, 200, filters=["Test=This"], with_labels=True + ) + assert {"Test": "This", "team": "ny"} == res["1"][0] + + # test with selected labels + res = await decoded_r.ts().mrevrange( + 0, 200, filters=["Test=This"], select_labels=["team"] + ) + assert {"team": "ny"} == res["1"][0] + assert {"team": "sf"} == res["2"][0] + + # test filterby + res = await decoded_r.ts().mrevrange( + 0, + 200, + filters=["Test=This"], + filter_by_ts=[i for i in range(10, 20)], + filter_by_min_value=1, + filter_by_max_value=2, + ) + assert [[16, 2.0], [15, 1.0]] == res["1"][2] + + # test groupby + res = await decoded_r.ts().mrevrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="sum" + ) + assert [[3, 6.0], [2, 4.0], [1, 2.0], [0, 0.0]] == res["Test=This"][3] + res = await decoded_r.ts().mrevrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="max" + ) + assert [[3, 3.0], [2, 2.0], [1, 1.0], [0, 0.0]] == res["Test=This"][3] + res = await decoded_r.ts().mrevrange( + 0, 3, filters=["Test=This"], groupby="team", reduce="min" + ) + assert 2 == len(res) + assert [[3, 3.0], [2, 2.0], [1, 1.0], [0, 0.0]] == res["team=ny"][3] + assert [[3, 3.0], [2, 2.0], [1, 1.0], [0, 0.0]] == res["team=sf"][3] + + # test align + res = await decoded_r.ts().mrevrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align="-", + ) + assert [[10, 1.0], [0, 10.0]] == res["1"][2] + res = await decoded_r.ts().mrevrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align=1, + ) + assert [[1, 10.0], [0, 1.0]] == res["1"][2] @pytest.mark.redismod -async def test_get(modclient: redis.Redis): +async def test_get(decoded_r: redis.Redis): name = "test" - await modclient.ts().create(name) - assert await 
modclient.ts().get(name) is None - await modclient.ts().add(name, 2, 3) - assert 2 == (await modclient.ts().get(name))[0] - await modclient.ts().add(name, 3, 4) - assert 4 == (await modclient.ts().get(name))[1] + await decoded_r.ts().create(name) + assert not await decoded_r.ts().get(name) + await decoded_r.ts().add(name, 2, 3) + assert 2 == (await decoded_r.ts().get(name))[0] + await decoded_r.ts().add(name, 3, 4) + assert 4 == (await decoded_r.ts().get(name))[1] @pytest.mark.redismod @pytest.mark.onlynoncluster -async def test_mget(modclient: redis.Redis): - await modclient.ts().create(1, labels={"Test": "This"}) - await modclient.ts().create(2, labels={"Test": "This", "Taste": "That"}) - act_res = await modclient.ts().mget(["Test=This"]) +async def test_mget(decoded_r: redis.Redis): + await decoded_r.ts().create(1, labels={"Test": "This"}) + await decoded_r.ts().create(2, labels={"Test": "This", "Taste": "That"}) + act_res = await decoded_r.ts().mget(["Test=This"]) exp_res = [{"1": [{}, None, None]}, {"2": [{}, None, None]}] - assert act_res == exp_res - await modclient.ts().add(1, "*", 15) - await modclient.ts().add(2, "*", 25) - res = await modclient.ts().mget(["Test=This"]) - assert 15 == res[0]["1"][2] - assert 25 == res[1]["2"][2] - res = await modclient.ts().mget(["Taste=That"]) - assert 25 == res[0]["2"][2] + exp_res_resp3 = {"1": [{}, []], "2": [{}, []]} + assert_resp_response(decoded_r, act_res, exp_res, exp_res_resp3) + await decoded_r.ts().add(1, "*", 15) + await decoded_r.ts().add(2, "*", 25) + res = await decoded_r.ts().mget(["Test=This"]) + if is_resp2_connection(decoded_r): + assert 15 == res[0]["1"][2] + assert 25 == res[1]["2"][2] + else: + assert 15 == res["1"][1][1] + assert 25 == res["2"][1][1] + res = await decoded_r.ts().mget(["Taste=That"]) + if is_resp2_connection(decoded_r): + assert 25 == res[0]["2"][2] + else: + assert 25 == res["2"][1][1] # test with_labels - assert {} == res[0]["2"][0] - res = await modclient.ts().mget(["Taste=That"], with_labels=True) - assert {"Taste": "That", "Test": "This"} == res[0]["2"][0] + if is_resp2_connection(decoded_r): + assert {} == res[0]["2"][0] + else: + assert {} == res["2"][0] + res = await decoded_r.ts().mget(["Taste=That"], with_labels=True) + if is_resp2_connection(decoded_r): + assert {"Taste": "That", "Test": "This"} == res[0]["2"][0] + else: + assert {"Taste": "That", "Test": "This"} == res["2"][0] @pytest.mark.redismod -async def test_info(modclient: redis.Redis): - await modclient.ts().create( +async def test_info(decoded_r: redis.Redis): + await decoded_r.ts().create( 1, retention_msecs=5, labels={"currentLabel": "currentData"} ) - info = await modclient.ts().info(1) - assert 5 == info.retention_msecs - assert info.labels["currentLabel"] == "currentData" + info = await decoded_r.ts().info(1) + assert_resp_response( + decoded_r, 5, info.get("retention_msecs"), info.get("retentionTime") + ) + assert info["labels"]["currentLabel"] == "currentData" @pytest.mark.redismod @skip_ifmodversion_lt("1.4.0", "timeseries") -async def testInfoDuplicatePolicy(modclient: redis.Redis): - await modclient.ts().create( +async def testInfoDuplicatePolicy(decoded_r: redis.Redis): + await decoded_r.ts().create( 1, retention_msecs=5, labels={"currentLabel": "currentData"} ) - info = await modclient.ts().info(1) - assert info.duplicate_policy is None + info = await decoded_r.ts().info(1) + assert_resp_response( + decoded_r, None, info.get("duplicate_policy"), info.get("duplicatePolicy") + ) - await modclient.ts().create("time-serie-2", 
duplicate_policy="min") - info = await modclient.ts().info("time-serie-2") - assert "min" == info.duplicate_policy + await decoded_r.ts().create("time-serie-2", duplicate_policy="min") + info = await decoded_r.ts().info("time-serie-2") + assert_resp_response( + decoded_r, "min", info.get("duplicate_policy"), info.get("duplicatePolicy") + ) @pytest.mark.redismod @pytest.mark.onlynoncluster -async def test_query_index(modclient: redis.Redis): - await modclient.ts().create(1, labels={"Test": "This"}) - await modclient.ts().create(2, labels={"Test": "This", "Taste": "That"}) - assert 2 == len(await modclient.ts().queryindex(["Test=This"])) - assert 1 == len(await modclient.ts().queryindex(["Taste=That"])) - assert [2] == await modclient.ts().queryindex(["Taste=That"]) +async def test_query_index(decoded_r: redis.Redis): + await decoded_r.ts().create(1, labels={"Test": "This"}) + await decoded_r.ts().create(2, labels={"Test": "This", "Taste": "That"}) + assert 2 == len(await decoded_r.ts().queryindex(["Test=This"])) + assert 1 == len(await decoded_r.ts().queryindex(["Taste=That"])) + assert_resp_response( + decoded_r, await decoded_r.ts().queryindex(["Taste=That"]), [2], {"2"} + ) # @pytest.mark.redismod -# async def test_pipeline(modclient: redis.Redis): -# pipeline = await modclient.ts().pipeline() +# async def test_pipeline(r: redis.Redis): +# pipeline = await r.ts().pipeline() # pipeline.create("with_pipeline") # for i in range(100): # pipeline.add("with_pipeline", i, 1.1 * i) # pipeline.execute() -# info = await modclient.ts().info("with_pipeline") +# info = await r.ts().info("with_pipeline") # assert info.lastTimeStamp == 99 # assert info.total_samples == 100 -# assert await modclient.ts().get("with_pipeline")[1] == 99 * 1.1 +# assert await r.ts().get("with_pipeline")[1] == 99 * 1.1 @pytest.mark.redismod -async def test_uncompressed(modclient: redis.Redis): - await modclient.ts().create("compressed") - await modclient.ts().create("uncompressed", uncompressed=True) - compressed_info = await modclient.ts().info("compressed") - uncompressed_info = await modclient.ts().info("uncompressed") - assert compressed_info.memory_usage != uncompressed_info.memory_usage +async def test_uncompressed(decoded_r: redis.Redis): + await decoded_r.ts().create("compressed") + await decoded_r.ts().create("uncompressed", uncompressed=True) + compressed_info = await decoded_r.ts().info("compressed") + uncompressed_info = await decoded_r.ts().info("uncompressed") + if is_resp2_connection(decoded_r): + assert compressed_info.memory_usage != uncompressed_info.memory_usage + else: + assert compressed_info["memoryUsage"] != uncompressed_info["memoryUsage"] diff --git a/tests/test_bloom.py b/tests/test_bloom.py index 30d3219404..464a946f54 100644 --- a/tests/test_bloom.py +++ b/tests/test_bloom.py @@ -1,12 +1,11 @@ from math import inf import pytest - import redis.commands.bf from redis.exceptions import ModuleError, RedisError from redis.utils import HIREDIS_AVAILABLE -from .conftest import skip_ifmodversion_lt +from .conftest import assert_resp_response, is_resp2_connection, skip_ifmodversion_lt def intlist(obj): @@ -14,18 +13,17 @@ def intlist(obj): @pytest.fixture -def client(modclient): - assert isinstance(modclient.bf(), redis.commands.bf.BFBloom) - assert isinstance(modclient.cf(), redis.commands.bf.CFBloom) - assert isinstance(modclient.cms(), redis.commands.bf.CMSBloom) - assert isinstance(modclient.tdigest(), redis.commands.bf.TDigestBloom) - assert isinstance(modclient.topk(), redis.commands.bf.TOPKBloom) 
+def client(decoded_r): + assert isinstance(decoded_r.bf(), redis.commands.bf.BFBloom) + assert isinstance(decoded_r.cf(), redis.commands.bf.CFBloom) + assert isinstance(decoded_r.cms(), redis.commands.bf.CMSBloom) + assert isinstance(decoded_r.tdigest(), redis.commands.bf.TDigestBloom) + assert isinstance(decoded_r.topk(), redis.commands.bf.TOPKBloom) - modclient.flushdb() - return modclient + decoded_r.flushdb() + return decoded_r -@pytest.mark.redismod def test_create(client): """Test CREATE/RESERVE calls""" assert client.bf().create("bloom", 0.01, 1000) @@ -40,7 +38,6 @@ def test_create(client): assert client.topk().reserve("topk", 5, 100, 5, 0.9) -@pytest.mark.redismod def test_bf_reserve(client): """Testing BF.RESERVE""" assert client.bf().reserve("bloom", 0.01, 1000) @@ -55,14 +52,11 @@ def test_bf_reserve(client): assert client.topk().reserve("topk", 5, 100, 5, 0.9) -@pytest.mark.redismod @pytest.mark.experimental def test_tdigest_create(client): assert client.tdigest().create("tDigest", 100) -# region Test Bloom Filter -@pytest.mark.redismod def test_bf_add(client): assert client.bf().create("bloom", 0.01, 1000) assert 1 == client.bf().add("bloom", "foo") @@ -75,7 +69,6 @@ def test_bf_add(client): assert [1, 0] == intlist(client.bf().mexists("bloom", "foo", "noexist")) -@pytest.mark.redismod def test_bf_insert(client): assert client.bf().create("bloom", 0.01, 1000) assert [1] == intlist(client.bf().insert("bloom", ["foo"])) @@ -86,12 +79,26 @@ def test_bf_insert(client): assert 0 == client.bf().exists("bloom", "noexist") assert [1, 0] == intlist(client.bf().mexists("bloom", "foo", "noexist")) info = client.bf().info("bloom") - assert 2 == info.insertedNum - assert 1000 == info.capacity - assert 1 == info.filterNum + assert_resp_response( + client, + 2, + info.get("insertedNum"), + info.get("Number of items inserted"), + ) + assert_resp_response( + client, + 1000, + info.get("capacity"), + info.get("Capacity"), + ) + assert_resp_response( + client, + 1, + info.get("filterNum"), + info.get("Number of filters"), + ) -@pytest.mark.redismod def test_bf_scandump_and_loadchunk(client): # Store a filter client.bf().create("myBloom", "0.0001", "1000") @@ -143,17 +150,26 @@ def do_verify(): client.bf().create("myBloom", "0.0001", "10000000") -@pytest.mark.redismod def test_bf_info(client): expansion = 4 # Store a filter client.bf().create("nonscaling", "0.0001", "1000", noScale=True) info = client.bf().info("nonscaling") - assert info.expansionRate is None + assert_resp_response( + client, + None, + info.get("expansionRate"), + info.get("Expansion rate"), + ) client.bf().create("expanding", "0.0001", "1000", expansion=expansion) info = client.bf().info("expanding") - assert info.expansionRate == 4 + assert_resp_response( + client, + 4, + info.get("expansionRate"), + info.get("Expansion rate"), + ) try: # noScale mean no expansion @@ -165,7 +181,6 @@ def test_bf_info(client): assert True -@pytest.mark.redismod def test_bf_card(client): # return 0 if the key does not exist assert client.bf().card("not_exist") == 0 @@ -180,8 +195,6 @@ def test_bf_card(client): client.bf().card("setKey") -# region Test Cuckoo Filter -@pytest.mark.redismod def test_cf_add_and_insert(client): assert client.cf().create("cuckoo", 1000) assert client.cf().add("cuckoo", "filter") @@ -196,12 +209,17 @@ def test_cf_add_and_insert(client): assert [1] == client.cf().insert("empty1", ["foo"], capacity=1000) assert [1] == client.cf().insertnx("empty2", ["bar"], capacity=1000) info = client.cf().info("captest") - assert 5 
== info.insertedNum - assert 0 == info.deletedNum - assert 1 == info.filterNum + assert_resp_response( + client, 5, info.get("insertedNum"), info.get("Number of items inserted") + ) + assert_resp_response( + client, 0, info.get("deletedNum"), info.get("Number of items deleted") + ) + assert_resp_response( + client, 1, info.get("filterNum"), info.get("Number of filters") + ) -@pytest.mark.redismod def test_cf_exists_and_del(client): assert client.cf().create("cuckoo", 1000) assert client.cf().add("cuckoo", "filter") @@ -214,8 +232,6 @@ def test_cf_exists_and_del(client): assert 0 == client.cf().count("cuckoo", "filter") -# region Test Count-Min Sketch -@pytest.mark.redismod def test_cms(client): assert client.cms().initbydim("dim", 1000, 5) assert client.cms().initbyprob("prob", 0.01, 0.01) @@ -225,12 +241,12 @@ def test_cms(client): assert [10, 15] == client.cms().incrby("dim", ["foo", "bar"], [5, 15]) assert [10, 15] == client.cms().query("dim", "foo", "bar") info = client.cms().info("dim") - assert 1000 == info.width - assert 5 == info.depth - assert 25 == info.count + assert info["width"] + assert 1000 == info["width"] + assert 5 == info["depth"] + assert 25 == info["count"] -@pytest.mark.redismod @pytest.mark.onlynoncluster def test_cms_merge(client): assert client.cms().initbydim("A", 1000, 5) @@ -248,11 +264,6 @@ def test_cms_merge(client): assert [16, 15, 21] == client.cms().query("C", "foo", "bar", "baz") -# endregion - - -# region Test Top-K -@pytest.mark.redismod def test_topk(client): # test list with empty buckets assert client.topk().reserve("topk", 3, 50, 4, 0.9) @@ -326,13 +337,12 @@ def test_topk(client): assert ["A", "B", "E"] == client.topk().list("topklist") assert ["A", 4, "B", 3, "E", 3] == client.topk().list("topklist", withcount=True) info = client.topk().info("topklist") - assert 3 == info.k - assert 50 == info.width - assert 3 == info.depth - assert 0.9 == round(float(info.decay), 1) + assert 3 == info["k"] + assert 50 == info["width"] + assert 3 == info["depth"] + assert 0.9 == round(float(info["decay"]), 1) -@pytest.mark.redismod def test_topk_incrby(client): client.flushdb() assert client.topk().reserve("topk", 3, 10, 3, 1) @@ -346,8 +356,6 @@ def test_topk_incrby(client): ) -# region Test T-Digest -@pytest.mark.redismod @pytest.mark.experimental def test_tdigest_reset(client): assert client.tdigest().create("tDigest", 10) @@ -357,12 +365,14 @@ def test_tdigest_reset(client): assert client.tdigest().add("tDigest", list(range(10))) assert client.tdigest().reset("tDigest") - # assert we have 0 unmerged nodes - assert 0 == client.tdigest().info("tDigest").unmerged_nodes + # assert we have 0 unmerged + info = client.tdigest().info("tDigest") + assert_resp_response( + client, 0, info.get("unmerged_nodes"), info.get("Unmerged nodes") + ) -@pytest.mark.redismod -@pytest.mark.experimental +@pytest.mark.onlynoncluster def test_tdigest_merge(client): assert client.tdigest().create("to-tDigest", 10) assert client.tdigest().create("from-tDigest", 10) @@ -373,8 +383,10 @@ def test_tdigest_merge(client): assert client.tdigest().merge("to-tDigest", 1, "from-tDigest") # we should now have 110 weight on to-histogram info = client.tdigest().info("to-tDigest") - total_weight_to = float(info.merged_weight) + float(info.unmerged_weight) - assert 20 == total_weight_to + if is_resp2_connection(client): + assert 20 == float(info["merged_weight"]) + float(info["unmerged_weight"]) + else: + assert 20 == float(info["Merged weight"]) + float(info["Unmerged weight"]) # test override assert 
client.tdigest().create("from-override", 10) assert client.tdigest().create("from-override-2", 10) @@ -387,7 +399,6 @@ def test_tdigest_merge(client): assert 4.0 == client.tdigest().max("to-tDigest") -@pytest.mark.redismod @pytest.mark.experimental def test_tdigest_min_and_max(client): assert client.tdigest().create("tDigest", 100) @@ -398,7 +409,6 @@ def test_tdigest_min_and_max(client): assert 1 == client.tdigest().min("tDigest") -@pytest.mark.redismod @pytest.mark.experimental @skip_ifmodversion_lt("2.4.0", "bf") def test_tdigest_quantile(client): @@ -420,7 +430,6 @@ def test_tdigest_quantile(client): assert [3.0, 5.0] == client.tdigest().quantile("t-digest", 0.5, 0.8) -@pytest.mark.redismod @pytest.mark.experimental def test_tdigest_cdf(client): assert client.tdigest().create("tDigest", 100) @@ -432,7 +441,6 @@ def test_tdigest_cdf(client): assert [0.1, 0.9] == [round(x, 1) for x in res] -@pytest.mark.redismod @pytest.mark.experimental @skip_ifmodversion_lt("2.4.0", "bf") def test_tdigest_trimmed_mean(client): @@ -443,7 +451,6 @@ def test_tdigest_trimmed_mean(client): assert 4.5 == client.tdigest().trimmed_mean("tDigest", 0.4, 0.5) -@pytest.mark.redismod @pytest.mark.experimental def test_tdigest_rank(client): assert client.tdigest().create("t-digest", 500) @@ -454,7 +461,6 @@ def test_tdigest_rank(client): assert [-1, 20, 9] == client.tdigest().rank("t-digest", -20, 20, 9) -@pytest.mark.redismod @pytest.mark.experimental def test_tdigest_revrank(client): assert client.tdigest().create("t-digest", 500) @@ -464,7 +470,6 @@ def test_tdigest_revrank(client): assert [-1, 19, 9] == client.tdigest().revrank("t-digest", 21, 0, 10) -@pytest.mark.redismod @pytest.mark.experimental def test_tdigest_byrank(client): assert client.tdigest().create("t-digest", 500) @@ -476,7 +481,6 @@ def test_tdigest_byrank(client): client.tdigest().byrank("t-digest", -1)[0] -@pytest.mark.redismod @pytest.mark.experimental def test_tdigest_byrevrank(client): assert client.tdigest().create("t-digest", 500) @@ -488,8 +492,7 @@ def test_tdigest_byrevrank(client): client.tdigest().byrevrank("t-digest", -1)[0] -# @pytest.mark.redismod -# def test_pipeline(client): +# # def test_pipeline(client): # pipeline = client.bf().pipeline() # assert not client.bf().execute_command("get pipeline") # diff --git a/tests/test_cluster.py b/tests/test_cluster.py index da6a8e4bf7..ae194db3a2 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -1,12 +1,17 @@ import binascii import datetime +import select +import socket +import socketserver +import threading import warnings +from queue import LifoQueue, Queue from time import sleep from unittest.mock import DEFAULT, Mock, call, patch import pytest - from redis import Redis +from redis._parsers import CommandsParser from redis.backoff import ExponentialBackoff, NoBackoff, default_backoff from redis.cluster import ( PRIMARY, @@ -17,7 +22,6 @@ RedisCluster, get_node_name, ) -from redis.commands import CommandsParser from redis.connection import BlockingConnectionPool, Connection, ConnectionPool from redis.crc import key_slot from redis.exceptions import ( @@ -38,6 +42,8 @@ from .conftest import ( _get_client, + assert_resp_response, + is_resp2_connection, skip_if_redis_enterprise, skip_if_server_version_lt, skip_unless_arch_bits, @@ -52,6 +58,73 @@ ] +class ProxyRequestHandler(socketserver.BaseRequestHandler): + def recv(self, sock): + """A recv with a timeout""" + r = select.select([sock], [], [], 0.01) + if not r[0]: + return None + return sock.recv(1000) + + def 
handle(self): + self.server.proxy.n_connections += 1 + conn = socket.create_connection(self.server.proxy.redis_addr) + stop = False + + def from_server(): + # read from server and pass to client + while not stop: + data = self.recv(conn) + if data is None: + continue + if not data: + self.request.shutdown(socket.SHUT_WR) + return + self.request.sendall(data) + + thread = threading.Thread(target=from_server) + thread.start() + try: + while True: + # read from client and send to server + data = self.request.recv(1000) + if not data: + return + conn.sendall(data) + finally: + conn.shutdown(socket.SHUT_WR) + stop = True # for safety + thread.join() + conn.close() + + +class NodeProxy: + """A class to proxy a node connection to a different port""" + + def __init__(self, addr, redis_addr): + self.addr = addr + self.redis_addr = redis_addr + self.server = socketserver.ThreadingTCPServer(self.addr, ProxyRequestHandler) + self.server.proxy = self + self.server.socket_reuse_address = True + self.thread = None + self.n_connections = 0 + + def start(self): + # test that we can connect to redis + s = socket.create_connection(self.redis_addr, timeout=2) + s.close() + # Start a thread with the server -- that thread will then start one + # more thread for each request + self.thread = threading.Thread(target=self.server.serve_forever) + # Exit the server thread when the main thread terminates + self.thread.daemon = True + self.thread.start() + + def close(self): + self.server.shutdown() + + @pytest.fixture() def slowlog(request, r): """ @@ -703,6 +776,8 @@ def test_not_require_full_coverage_cluster_down_error(self, r): assert all(r.cluster_delslots(missing_slot)) with pytest.raises(ClusterDownError): r.exists("foo") + except ResponseError as e: + assert "CLUSTERDOWN" in str(e) finally: try: # Add back the missing slot @@ -822,6 +897,51 @@ def raise_connection_error(): assert "myself" not in nodes.get(curr_default_node.name).get("flags") assert r.get_default_node() != curr_default_node + def test_address_remap(self, request, master_host): + """Test that we can create a rediscluster object with + a host-port remapper and map connections through proxy objects + """ + + # we remap the first n nodes + offset = 1000 + n = 6 + hostname, master_port = master_host + ports = [master_port + i for i in range(n)] + + def address_remap(address): + # remap first three nodes to our local proxy + # old = host, port + host, port = address + if int(port) in ports: + host, port = "127.0.0.1", int(port) + offset + # print(f"{old} {host, port}") + return host, port + + # create the proxies + proxies = [ + NodeProxy(("127.0.0.1", port + offset), (hostname, port)) for port in ports + ] + for p in proxies: + p.start() + try: + # create cluster: + r = _get_client( + RedisCluster, request, flushdb=False, address_remap=address_remap + ) + try: + assert r.ping() is True + assert r.set("byte_string", b"giraffe") + assert r.get("byte_string") == b"giraffe" + finally: + r.close() + finally: + for p in proxies: + p.close() + + # verify that the proxies were indeed used + n_used = sum((1 if p.n_connections else 0) for p in proxies) + assert n_used > 1 + @pytest.mark.onlycluster class TestClusterRedisCommands: @@ -883,7 +1003,7 @@ def test_client_setname(self, r): node = r.get_random_node() r.client_setname("redis_py_test", target_nodes=node) client_name = r.client_getname(target_nodes=node) - assert client_name == "redis_py_test" + assert_resp_response(r, client_name, "redis_py_test", b"redis_py_test") def test_exists(self, r): d = {"a": 
b"1", "b": b"2", "c": b"3", "d": b"4"} @@ -1040,11 +1160,25 @@ def test_cluster_shards(self, r): b"health", ] for x in cluster_shards: - assert list(x.keys()) == ["slots", "nodes"] - for node in x["nodes"]: + assert_resp_response( + r, list(x.keys()), ["slots", "nodes"], [b"slots", b"nodes"] + ) + try: + x["nodes"] + key = "nodes" + except KeyError: + key = b"nodes" + for node in x[key]: for attribute in node.keys(): assert attribute in attributes + @skip_if_server_version_lt("7.2.0") + @skip_if_redis_enterprise() + def test_cluster_myshardid(self, r): + myshardid = r.cluster_myshardid() + assert isinstance(myshardid, str) + assert len(myshardid) > 0 + @skip_if_redis_enterprise() def test_cluster_addslots(self, r): node = r.get_random_node() @@ -1291,11 +1425,18 @@ def test_cluster_replicas(self, r): def test_cluster_links(self, r): node = r.get_random_node() res = r.cluster_links(node) - links_to = sum(x.count("to") for x in res) - links_for = sum(x.count("from") for x in res) - assert links_to == links_for - for i in range(0, len(res) - 1, 2): - assert res[i][3] == res[i + 1][3] + if is_resp2_connection(r): + links_to = sum(x.count(b"to") for x in res) + links_for = sum(x.count(b"from") for x in res) + assert links_to == links_for + for i in range(0, len(res) - 1, 2): + assert res[i][3] == res[i + 1][3] + else: + links_to = len(list(filter(lambda x: x[b"direction"] == b"to", res))) + links_for = len(list(filter(lambda x: x[b"direction"] == b"from", res))) + assert links_to == links_for + for i in range(0, len(res) - 1, 2): + assert res[i][b"node"] == res[i + 1][b"node"] def test_cluster_flshslots_not_implemented(self, r): with pytest.raises(NotImplementedError): @@ -1471,7 +1612,7 @@ def test_client_trackinginfo(self, r): node = r.get_primaries()[0] res = r.client_trackinginfo(target_nodes=node) assert len(res) > 2 - assert "prefixes" in res + assert "prefixes" in res or b"prefixes" in res @skip_if_server_version_lt("2.9.50") def test_client_pause(self, r): @@ -1633,24 +1774,68 @@ def test_cluster_renamenx(self, r): def test_cluster_blpop(self, r): r.rpush("{foo}a", "1", "2") r.rpush("{foo}b", "3", "4") - assert r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"3") - assert r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"4") - assert r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"1") - assert r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"2") + assert_resp_response( + r, + r.blpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"3"), + [b"{foo}b", b"3"], + ) + assert_resp_response( + r, + r.blpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"4"), + [b"{foo}b", b"4"], + ) + assert_resp_response( + r, + r.blpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"1"), + [b"{foo}a", b"1"], + ) + assert_resp_response( + r, + r.blpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"2"), + [b"{foo}a", b"2"], + ) assert r.blpop(["{foo}b", "{foo}a"], timeout=1) is None r.rpush("{foo}c", "1") - assert r.blpop("{foo}c", timeout=1) == (b"{foo}c", b"1") + assert_resp_response( + r, r.blpop("{foo}c", timeout=1), (b"{foo}c", b"1"), [b"{foo}c", b"1"] + ) def test_cluster_brpop(self, r): r.rpush("{foo}a", "1", "2") r.rpush("{foo}b", "3", "4") - assert r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"4") - assert r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"3") - assert r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"2") - assert r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"1") + assert_resp_response( + r, + 
r.brpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"4"), + [b"{foo}b", b"4"], + ) + assert_resp_response( + r, + r.brpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"3"), + [b"{foo}b", b"3"], + ) + assert_resp_response( + r, + r.brpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"2"), + [b"{foo}a", b"2"], + ) + assert_resp_response( + r, + r.brpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"1"), + [b"{foo}a", b"1"], + ) assert r.brpop(["{foo}b", "{foo}a"], timeout=1) is None r.rpush("{foo}c", "1") - assert r.brpop("{foo}c", timeout=1) == (b"{foo}c", b"1") + assert_resp_response( + r, r.brpop("{foo}c", timeout=1), (b"{foo}c", b"1"), [b"{foo}c", b"1"] + ) def test_cluster_brpoplpush(self, r): r.rpush("{foo}a", "1", "2") @@ -1723,7 +1908,13 @@ def test_cluster_zdiff(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) r.zadd("{foo}b", {"a1": 1, "a2": 2}) assert r.zdiff(["{foo}a", "{foo}b"]) == [b"a3"] - assert r.zdiff(["{foo}a", "{foo}b"], withscores=True) == [b"a3", b"3"] + response = r.zdiff(["{foo}a", "{foo}b"], withscores=True) + assert_resp_response( + r, + response, + [b"a3", b"3"], + [[b"a3", 3.0]], + ) @skip_if_server_version_lt("6.2.0") def test_cluster_zdiffstore(self, r): @@ -1731,7 +1922,8 @@ def test_cluster_zdiffstore(self, r): r.zadd("{foo}b", {"a1": 1, "a2": 2}) assert r.zdiffstore("{foo}out", ["{foo}a", "{foo}b"]) assert r.zrange("{foo}out", 0, -1) == [b"a3"] - assert r.zrange("{foo}out", 0, -1, withscores=True) == [(b"a3", 3.0)] + response = r.zrange("{foo}out", 0, -1, withscores=True) + assert_resp_response(r, response, [(b"a3", 3.0)], [[b"a3", 3.0]]) @skip_if_server_version_lt("6.2.0") def test_cluster_zinter(self, r): @@ -1742,31 +1934,42 @@ def test_cluster_zinter(self, r): # invalid aggregation with pytest.raises(DataError): r.zinter(["{foo}a", "{foo}b", "{foo}c"], aggregate="foo", withscores=True) - # aggregate with SUM - assert r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ - (b"a3", 8), - (b"a1", 9), - ] - # aggregate with MAX - assert r.zinter( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True - ) == [(b"a3", 5), (b"a1", 6)] - # aggregate with MIN - assert r.zinter( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True - ) == [(b"a1", 1), (b"a3", 1)] - # with weights - assert r.zinter({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True) == [ - (b"a3", 20), - (b"a1", 23), - ] + assert_resp_response( + r, + r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True), + [(b"a3", 8), (b"a1", 9)], + [[b"a3", 8], [b"a1", 9]], + ) + assert_resp_response( + r, + r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True, aggregate="MAX"), + [(b"a3", 5), (b"a1", 6)], + [[b"a3", 5], [b"a1", 6]], + ) + assert_resp_response( + r, + r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True, aggregate="MIN"), + [(b"a1", 1), (b"a3", 1)], + [[b"a1", 1], [b"a3", 1]], + ) + assert_resp_response( + r, + r.zinter({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True), + [(b"a3", 20.0), (b"a1", 23.0)], + [[b"a3", 20.0], [b"a1", 23.0]], + ) def test_cluster_zinterstore_sum(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zinterstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"]) == 2 - assert r.zrange("{foo}d", 0, -1, withscores=True) == [(b"a3", 8), (b"a1", 9)] + assert_resp_response( + r, + r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a3", 8), (b"a1", 9)], + [[b"a3", 8.0], [b"a1", 9.0]], + ) def 
test_cluster_zinterstore_max(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) @@ -1776,7 +1979,12 @@ def test_cluster_zinterstore_max(self, r): r.zinterstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX") == 2 ) - assert r.zrange("{foo}d", 0, -1, withscores=True) == [(b"a3", 5), (b"a1", 6)] + assert_resp_response( + r, + r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a3", 5), (b"a1", 6)], + [[b"a3", 5.0], [b"a1", 6.0]], + ) def test_cluster_zinterstore_min(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) @@ -1786,38 +1994,98 @@ def test_cluster_zinterstore_min(self, r): r.zinterstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN") == 2 ) - assert r.zrange("{foo}d", 0, -1, withscores=True) == [(b"a1", 1), (b"a3", 3)] + assert_resp_response( + r, + r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a1", 1), (b"a3", 3)], + [[b"a1", 1.0], [b"a3", 3.0]], + ) def test_cluster_zinterstore_with_weight(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zinterstore("{foo}d", {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}) == 2 - assert r.zrange("{foo}d", 0, -1, withscores=True) == [(b"a3", 20), (b"a1", 23)] + assert_resp_response( + r, + r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a3", 20), (b"a1", 23)], + [[b"a3", 20.0], [b"a1", 23.0]], + ) @skip_if_server_version_lt("4.9.0") def test_cluster_bzpopmax(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 2}) r.zadd("{foo}b", {"b1": 10, "b2": 20}) - assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"b2", 20) - assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"b1", 10) - assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"a2", 2) - assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"a1", 1) + assert_resp_response( + r, + r.bzpopmax(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"b2", 20), + [b"{foo}b", b"b2", 20], + ) + assert_resp_response( + r, + r.bzpopmax(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"b1", 10), + [b"{foo}b", b"b1", 10], + ) + assert_resp_response( + r, + r.bzpopmax(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"a2", 2), + [b"{foo}a", b"a2", 2], + ) + assert_resp_response( + r, + r.bzpopmax(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"a1", 1), + [b"{foo}a", b"a1", 1], + ) assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) is None r.zadd("{foo}c", {"c1": 100}) - assert r.bzpopmax("{foo}c", timeout=1) == (b"{foo}c", b"c1", 100) + assert_resp_response( + r, + r.bzpopmax("{foo}c", timeout=1), + (b"{foo}c", b"c1", 100), + [b"{foo}c", b"c1", 100], + ) @skip_if_server_version_lt("4.9.0") def test_cluster_bzpopmin(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 2}) r.zadd("{foo}b", {"b1": 10, "b2": 20}) - assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"b1", 10) - assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"b2", 20) - assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"a1", 1) - assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"a2", 2) + assert_resp_response( + r, + r.bzpopmin(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"b1", 10), + [b"{foo}b", b"b1", 10], + ) + assert_resp_response( + r, + r.bzpopmin(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"b2", 20), + [b"{foo}b", b"b2", 20], + ) + assert_resp_response( + r, + r.bzpopmin(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"a1", 1), + [b"{foo}a", b"a1", 1], + ) + assert_resp_response( + r, + 
r.bzpopmin(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"a2", 2), + [b"{foo}a", b"a2", 2], + ) assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) is None r.zadd("{foo}c", {"c1": 100}) - assert r.bzpopmin("{foo}c", timeout=1) == (b"{foo}c", b"c1", 100) + assert_resp_response( + r, + r.bzpopmin("{foo}c", timeout=1), + (b"{foo}c", b"c1", 100), + [b"{foo}c", b"c1", 100], + ) @skip_if_server_version_lt("6.2.0") def test_cluster_zrangestore(self, r): @@ -1826,7 +2094,12 @@ def test_cluster_zrangestore(self, r): assert r.zrange("{foo}b", 0, -1) == [b"a1", b"a2"] assert r.zrangestore("{foo}b", "{foo}a", 1, 2) assert r.zrange("{foo}b", 0, -1) == [b"a2", b"a3"] - assert r.zrange("{foo}b", 0, -1, withscores=True) == [(b"a2", 2), (b"a3", 3)] + assert_resp_response( + r, + r.zrange("{foo}b", 0, 1, withscores=True), + [(b"a2", 2), (b"a3", 3)], + [[b"a2", 2.0], [b"a3", 3.0]], + ) # reversed order assert r.zrangestore("{foo}b", "{foo}a", 1, 2, desc=True) assert r.zrange("{foo}b", 0, -1) == [b"a1", b"a2"] @@ -1848,39 +2121,45 @@ def test_cluster_zunion(self, r): r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) # sum assert r.zunion(["{foo}a", "{foo}b", "{foo}c"]) == [b"a2", b"a4", b"a3", b"a1"] - assert r.zunion(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ - (b"a2", 3), - (b"a4", 4), - (b"a3", 8), - (b"a1", 9), - ] + assert_resp_response( + r, + r.zunion(["{foo}a", "{foo}b", "{foo}c"], withscores=True), + [(b"a2", 3), (b"a4", 4), (b"a3", 8), (b"a1", 9)], + [[b"a2", 3.0], [b"a4", 4.0], [b"a3", 8.0], [b"a1", 9.0]], + ) # max - assert r.zunion( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True - ) == [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)] + assert_resp_response( + r, + r.zunion(["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True), + [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)], + [[b"a2", 2.0], [b"a4", 4.0], [b"a3", 5.0], [b"a1", 6.0]], + ) # min - assert r.zunion( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True - ) == [(b"a1", 1), (b"a2", 1), (b"a3", 1), (b"a4", 4)] + assert_resp_response( + r, + r.zunion(["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True), + [(b"a1", 1), (b"a2", 1), (b"a3", 1), (b"a4", 4)], + [[b"a1", 1.0], [b"a2", 1.0], [b"a3", 1.0], [b"a4", 4.0]], + ) # with weight - assert r.zunion({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True) == [ - (b"a2", 5), - (b"a4", 12), - (b"a3", 20), - (b"a1", 23), - ] + assert_resp_response( + r, + r.zunion({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True), + [(b"a2", 5), (b"a4", 12), (b"a3", 20), (b"a1", 23)], + [[b"a2", 5.0], [b"a4", 12.0], [b"a3", 20.0], [b"a1", 23.0]], + ) def test_cluster_zunionstore_sum(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zunionstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"]) == 4 - assert r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a2", 3), - (b"a4", 4), - (b"a3", 8), - (b"a1", 9), - ] + assert_resp_response( + r, + r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a2", 3), (b"a4", 4), (b"a3", 8), (b"a1", 9)], + [[b"a2", 3.0], [b"a4", 4.0], [b"a3", 8.0], [b"a1", 9.0]], + ) def test_cluster_zunionstore_max(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) @@ -1890,12 +2169,12 @@ def test_cluster_zunionstore_max(self, r): r.zunionstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX") == 4 ) - assert r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a2", 2), - (b"a4", 4), - (b"a3", 5), - 
(b"a1", 6), - ] + assert_resp_response( + r, + r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)], + [[b"a2", 2.0], [b"a4", 4.0], [b"a3", 5.0], [b"a1", 6.0]], + ) def test_cluster_zunionstore_min(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) @@ -1905,24 +2184,24 @@ def test_cluster_zunionstore_min(self, r): r.zunionstore("{foo}d", ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN") == 4 ) - assert r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a1", 1), - (b"a2", 2), - (b"a3", 3), - (b"a4", 4), - ] + assert_resp_response( + r, + r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a1", 1), (b"a2", 2), (b"a3", 3), (b"a4", 4)], + [[b"a1", 1.0], [b"a2", 2.0], [b"a3", 3.0], [b"a4", 4.0]], + ) def test_cluster_zunionstore_with_weight(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) r.zadd("{foo}b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("{foo}c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zunionstore("{foo}d", {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}) == 4 - assert r.zrange("{foo}d", 0, -1, withscores=True) == [ - (b"a2", 5), - (b"a4", 12), - (b"a3", 20), - (b"a1", 23), - ] + assert_resp_response( + r, + r.zrange("{foo}d", 0, -1, withscores=True), + [(b"a2", 5), (b"a4", 12), (b"a3", 20), (b"a1", 23)], + [[b"a2", 5.0], [b"a4", 12.0], [b"a3", 20.0], [b"a1", 23.0]], + ) @skip_if_server_version_lt("2.8.9") def test_cluster_pfcount(self, r): @@ -2146,12 +2425,63 @@ def teardown(): user_client.hset("{cache}:0", "hkey", "hval") assert isinstance(r.acl_log(target_nodes=node), list) - assert len(r.acl_log(target_nodes=node)) == 2 + assert len(r.acl_log(target_nodes=node)) == 3 assert len(r.acl_log(count=1, target_nodes=node)) == 1 assert isinstance(r.acl_log(target_nodes=node)[0], dict) assert "client-info" in r.acl_log(count=1, target_nodes=node)[0] assert r.acl_log_reset(target_nodes=node) + def generate_lib_code(self, lib_name): + return f"""#!js api_version=1.0 name={lib_name}\n redis.registerFunction('foo', ()=>{{return 'bar'}})""" # noqa + + def try_delete_libs(self, r, *lib_names): + for lib_name in lib_names: + try: + r.tfunction_delete(lib_name) + except Exception: + pass + + @skip_if_server_version_lt("7.1.140") + def test_tfunction_load_delete(self, r): + r.gears_refresh_cluster() + self.try_delete_libs(r, "lib1") + lib_code = self.generate_lib_code("lib1") + assert r.tfunction_load(lib_code) + assert r.tfunction_delete("lib1") + + @skip_if_server_version_lt("7.1.140") + def test_tfunction_list(self, r): + r.gears_refresh_cluster() + self.try_delete_libs(r, "lib1", "lib2", "lib3") + assert r.tfunction_load(self.generate_lib_code("lib1")) + assert r.tfunction_load(self.generate_lib_code("lib2")) + assert r.tfunction_load(self.generate_lib_code("lib3")) + + # test error thrown when verbose > 4 + with pytest.raises(DataError): + assert r.tfunction_list(verbose=8) + + functions = r.tfunction_list(verbose=1) + assert len(functions) == 3 + + expected_names = [b"lib1", b"lib2", b"lib3"] + actual_names = [functions[0][13], functions[1][13], functions[2][13]] + + assert sorted(expected_names) == sorted(actual_names) + assert r.tfunction_delete("lib1") + assert r.tfunction_delete("lib2") + assert r.tfunction_delete("lib3") + + @skip_if_server_version_lt("7.1.140") + def test_tfcall(self, r): + r.gears_refresh_cluster() + self.try_delete_libs(r, "lib1") + assert r.tfunction_load(self.generate_lib_code("lib1")) + assert r.tfcall("lib1", "foo") == b"bar" + assert r.tfcall_async("lib1", "foo") == b"bar" + + assert r.tfunction_delete("lib1") + 
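Every protocol-sensitive assertion in the tests above goes through `assert_resp_response`, imported from `tests/conftest.py`, which compares one reply against an expected value per protocol. A minimal sketch of what such a helper presumably looks like, assuming the negotiated protocol version is recorded in the client's connection kwargs (that lookup is an assumption, not the actual conftest code):

def is_resp2_connection(r):
    # Assumption: the client keeps the requested protocol in its
    # connection kwargs; RESP2 is the default when nothing is set.
    protocol = r.connection_pool.connection_kwargs.get("protocol")
    return protocol in (2, "2", None)


def assert_resp_response(r, response, resp2_expected, resp3_expected):
    # RESP2 serializes scored replies as (member, score) tuples or raw
    # bytes, while RESP3 returns lists carrying native doubles, so each
    # call site supplies both shapes and the helper picks the one that
    # matches the server's protocol.
    if is_resp2_connection(r):
        assert response == resp2_expected
    else:
        assert response == resp3_expected

Usage mirrors the call sites above: `assert_resp_response(r, r.brpop("c", timeout=1), (b"c", b"1"), [b"c", b"1"])` passes under either protocol.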
@pytest.mark.onlycluster class TestNodesManager: @@ -2511,6 +2841,18 @@ def test_connection_pool_class(self, connection_pool_class): node.redis_connection.connection_pool, connection_pool_class ) + @pytest.mark.parametrize("queue_class", [Queue, LifoQueue]) + def test_allow_custom_queue_class(self, queue_class): + rc = get_mocked_redis_client( + url="redis://my@DNS.com:7000", + cluster_slots=default_cluster_slots, + connection_pool_class=BlockingConnectionPool, + queue_class=queue_class, + ) + + for node in rc.nodes_manager.nodes_cache.values(): + assert node.redis_connection.connection_pool.queue_class == queue_class + @pytest.mark.onlycluster class TestClusterPubSubObject: @@ -2703,6 +3045,25 @@ def test_multi_delete_unsupported(self, r): with pytest.raises(RedisClusterException): pipe.delete("a", "b") + def test_unlink_single(self, r): + """ + Test a single unlink operation + """ + r["a"] = 1 + with r.pipeline(transaction=False) as pipe: + pipe.unlink("a") + assert pipe.execute() == [1] + + def test_multi_unlink_unsupported(self, r): + """ + Test that multi unlink operation is unsupported + """ + with r.pipeline(transaction=False) as pipe: + r["a"] = 1 + r["b"] = 2 + with pytest.raises(RedisClusterException): + pipe.unlink("a", "b") + def test_brpoplpush_disabled(self, r): """ Test that brpoplpush is disabled for ClusterPipeline @@ -2913,7 +3274,12 @@ def test_pipeline_readonly(self, r): with r.pipeline() as readonly_pipe: readonly_pipe.get("foo71").zrange("foo88", 0, 5, withscores=True) - assert readonly_pipe.execute() == [b"a1", [(b"z1", 1.0), (b"z2", 4)]] + assert_resp_response( + r, + readonly_pipe.execute(), + [b"a1", [(b"z1", 1.0), (b"z2", 4)]], + [b"a1", [[b"z1", 1.0], [b"z2", 4.0]]], + ) def test_moved_redirection_on_slave_with_default(self, r): """ diff --git a/tests/test_command_parser.py b/tests/test_command_parser.py index 6c3ede9cdf..e3b44a147f 100644 --- a/tests/test_command_parser.py +++ b/tests/test_command_parser.py @@ -1,8 +1,11 @@ import pytest +from redis._parsers import CommandsParser -from redis.commands import CommandsParser - -from .conftest import skip_if_redis_enterprise, skip_if_server_version_lt +from .conftest import ( + assert_resp_response, + skip_if_redis_enterprise, + skip_if_server_version_lt, +) class TestCommandsParser: @@ -51,13 +54,40 @@ def test_get_moveable_keys(self, r): ] args7 = ["MIGRATE", "192.168.1.34", 6379, "key1", 0, 5000] - assert sorted(commands_parser.get_keys(r, *args1)) == ["key1", "key2"] - assert sorted(commands_parser.get_keys(r, *args2)) == ["mystream", "writers"] - assert sorted(commands_parser.get_keys(r, *args3)) == ["out", "zset1", "zset2"] - assert sorted(commands_parser.get_keys(r, *args4)) == ["Sicily", "out"] - assert sorted(commands_parser.get_keys(r, *args5)) == ["foo"] - assert sorted(commands_parser.get_keys(r, *args6)) == ["key1", "key2", "key3"] - assert sorted(commands_parser.get_keys(r, *args7)) == ["key1"] + assert_resp_response( + r, + sorted(commands_parser.get_keys(r, *args1)), + ["key1", "key2"], + [b"key1", b"key2"], + ) + assert_resp_response( + r, + sorted(commands_parser.get_keys(r, *args2)), + ["mystream", "writers"], + [b"mystream", b"writers"], + ) + assert_resp_response( + r, + sorted(commands_parser.get_keys(r, *args3)), + ["out", "zset1", "zset2"], + [b"out", b"zset1", b"zset2"], + ) + assert_resp_response( + r, + sorted(commands_parser.get_keys(r, *args4)), + ["Sicily", "out"], + [b"Sicily", b"out"], + ) + assert sorted(commands_parser.get_keys(r, *args5)) in [["foo"], [b"foo"]] + 
assert_resp_response( + r, + sorted(commands_parser.get_keys(r, *args6)), + ["key1", "key2", "key3"], + [b"key1", b"key2", b"key3"], + ) + assert_resp_response( + r, sorted(commands_parser.get_keys(r, *args7)), ["key1"], [b"key1"] + ) # A bug in redis<7.0 causes this to fail: https://github.com/redis/redis/issues/9493 @skip_if_server_version_lt("7.0.0") diff --git a/tests/test_commands.py b/tests/test_commands.py index 94249e9419..055aa3bf9f 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -1,18 +1,29 @@ import binascii import datetime import re +import threading import time +from asyncio import CancelledError from string import ascii_letters from unittest import mock +from unittest.mock import patch import pytest - import redis from redis import exceptions -from redis.client import EMPTY_RESPONSE, NEVER_DECODE, parse_info +from redis._parsers.helpers import ( + _RedisCallbacks, + _RedisCallbacksRESP2, + _RedisCallbacksRESP3, + parse_info, +) +from redis.client import EMPTY_RESPONSE, NEVER_DECODE from .conftest import ( _get_client, + assert_resp_response, + assert_resp_response_in, + is_resp2_connection, skip_if_redis_enterprise, skip_if_server_version_gte, skip_if_server_version_lt, @@ -55,17 +66,23 @@ class TestResponseCallbacks: "Tests for the response callback system" def test_response_callbacks(self, r): - assert r.response_callbacks == redis.Redis.RESPONSE_CALLBACKS - assert id(r.response_callbacks) != id(redis.Redis.RESPONSE_CALLBACKS) + callbacks = _RedisCallbacks + if is_resp2_connection(r): + callbacks.update(_RedisCallbacksRESP2) + else: + callbacks.update(_RedisCallbacksRESP3) + assert r.response_callbacks == callbacks + assert id(r.response_callbacks) != id(_RedisCallbacks) r.set_response_callback("GET", lambda x: "static") r["a"] = "foo" assert r["a"] == "static" def test_case_insensitive_command_names(self, r): - assert r.response_callbacks["del"] == r.response_callbacks["DEL"] + assert r.response_callbacks["ping"] == r.response_callbacks["PING"] class TestRedisCommands: + @pytest.mark.onlynoncluster @skip_if_redis_enterprise() def test_auth(self, r, request): # sending an AUTH command before setting a user/password on the @@ -100,7 +117,6 @@ def teardown(): # connection field is not set in Redis Cluster, but that's ok # because the problem discussed above does not apply to Redis Cluster pass - r.auth(temp_pass) r.config_set("requirepass", "") r.acl_deluser(username) @@ -126,13 +142,13 @@ def test_command_on_invalid_key_type(self, r): def test_acl_cat_no_category(self, r): categories = r.acl_cat() assert isinstance(categories, list) - assert "read" in categories + assert "read" in categories or b"read" in categories @skip_if_server_version_lt("6.0.0") def test_acl_cat_with_category(self, r): commands = r.acl_cat("read") assert isinstance(commands, list) - assert "get" in commands + assert "get" in commands or b"get" in commands @skip_if_server_version_lt("7.0.0") @skip_if_redis_enterprise() @@ -146,9 +162,8 @@ def teardown(): r.acl_setuser(username, keys=["*"], commands=["+set"]) assert r.acl_dryrun(username, "set", "key", "value") == b"OK" - assert r.acl_dryrun(username, "get", "key").startswith( - b"This user has no permissions to run the" - ) + no_permissions_message = b"user has no permissions to run the" + assert no_permissions_message in r.acl_dryrun(username, "get", "key") @skip_if_server_version_lt("6.0.0") @skip_if_redis_enterprise() @@ -179,7 +194,7 @@ def teardown(): @skip_if_redis_enterprise() def test_acl_genpass(self, r): password = 
r.acl_genpass() - assert isinstance(password, str) + assert isinstance(password, (str, bytes)) with pytest.raises(exceptions.DataError): r.acl_genpass("value") @@ -187,11 +202,12 @@ def test_acl_genpass(self, r): r.acl_genpass(5555) r.acl_genpass(555) - assert isinstance(password, str) + assert isinstance(password, (str, bytes)) @skip_if_server_version_lt("7.0.0") @skip_if_redis_enterprise() def test_acl_getuser_setuser(self, r, request): + r.flushall() username = "redis-py-user" def teardown(): @@ -226,19 +242,19 @@ def teardown(): enabled=True, reset=True, passwords=["+pass1", "+pass2"], - categories=["+set", "+@hash", "-geo"], + categories=["+set", "+@hash", "-@geo"], commands=["+get", "+mget", "-hset"], keys=["cache:*", "objects:*"], ) acl = r.acl_getuser(username) - assert set(acl["categories"]) == {"-@all", "+@set", "+@hash"} + assert set(acl["categories"]) == {"+@hash", "+@set", "-@all", "-@geo"} assert set(acl["commands"]) == {"+get", "+mget", "-hset"} assert acl["enabled"] is True assert "on" in acl["flags"] assert set(acl["keys"]) == {"~cache:*", "~objects:*"} assert len(acl["passwords"]) == 2 - # test reset=False keeps existing ACL and applies new ACL on top + # # test reset=False keeps existing ACL and applies new ACL on top assert r.acl_setuser( username, enabled=True, @@ -257,14 +273,13 @@ def teardown(): keys=["objects:*"], ) acl = r.acl_getuser(username) - assert set(acl["categories"]) == {"-@all", "+@set", "+@hash"} assert set(acl["commands"]) == {"+get", "+mget"} assert acl["enabled"] is True assert "on" in acl["flags"] assert set(acl["keys"]) == {"~cache:*", "~objects:*"} assert len(acl["passwords"]) == 2 - # test removal of passwords + # # test removal of passwords assert r.acl_setuser( username, enabled=True, reset=True, passwords=["+pass1", "+pass2"] ) @@ -272,7 +287,7 @@ def teardown(): assert r.acl_setuser(username, enabled=True, passwords=["-pass2"]) assert len(r.acl_getuser(username)["passwords"]) == 1 - # Resets and tests that hashed passwords are set properly. + # # Resets and tests that hashed passwords are set properly. 
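The constant assigned just below is the SHA-256 hex digest of the literal password "password": Redis stores ACL passwords as SHA-256 digests, which is why `acl_setuser` can accept the digest itself instead of a plaintext credential. The value is easy to reproduce with the standard library (a standalone sketch; the variable name is illustrative):

import hashlib

# SHA-256 of b"password"; the same 64-character hex string the test
# hands to acl_setuser as a pre-hashed credential.
digest = hashlib.sha256(b"password").hexdigest()
assert digest == (
    "5e884898da28047151d0e56f8dc6292773603d0d6aabbdd62a11ef721d1542d8"
)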
hashed_password = ( "5e884898da28047151d0e56f8dc6292773603d0d6aabbdd62a11ef721d1542d8" ) @@ -296,7 +311,7 @@ def teardown(): ) assert len(r.acl_getuser(username)["passwords"]) == 1 - # test selectors + # # test selectors assert r.acl_setuser( username, enabled=True, @@ -309,16 +324,19 @@ def teardown(): selectors=[("+set", "%W~app*")], ) acl = r.acl_getuser(username) - assert set(acl["categories"]) == {"-@all", "+@set", "+@hash"} + assert set(acl["categories"]) == {"+@hash", "+@set", "-@all", "-@geo"} assert set(acl["commands"]) == {"+get", "+mget", "-hset"} assert acl["enabled"] is True assert "on" in acl["flags"] assert set(acl["keys"]) == {"~cache:*", "~objects:*"} assert len(acl["passwords"]) == 2 assert set(acl["channels"]) == {"&message:*"} - assert acl["selectors"] == [ - ["commands", "-@all +set", "keys", "%W~app*", "channels", ""] - ] + assert_resp_response( + r, + acl["selectors"], + [["commands", "-@all +set", "keys", "%W~app*", "channels", ""]], + [{"commands": "-@all +set", "keys": "%W~app*", "channels": ""}], + ) @skip_if_server_version_lt("6.0.0") def test_acl_help(self, r): @@ -330,6 +348,7 @@ def test_acl_help(self, r): @skip_if_redis_enterprise() def test_acl_list(self, r, request): username = "redis-py-user" + start = r.acl_list() def teardown(): r.acl_deluser(username) @@ -338,7 +357,7 @@ def teardown(): assert r.acl_setuser(username, enabled=False, reset=True) users = r.acl_list() - assert len(users) == 2 + assert len(users) == len(start) + 1 @skip_if_server_version_lt("6.0.0") @skip_if_redis_enterprise() @@ -377,11 +396,16 @@ def teardown(): user_client.hset("cache:0", "hkey", "hval") assert isinstance(r.acl_log(), list) - assert len(r.acl_log()) == 2 + assert len(r.acl_log()) == 3 assert len(r.acl_log(count=1)) == 1 assert isinstance(r.acl_log()[0], dict) - assert "client-info" in r.acl_log(count=1)[0] - assert r.acl_log_reset() + expected = r.acl_log(count=1)[0] + assert_resp_response_in( + r, + "client-info", + expected, + expected.keys(), + ) @skip_if_server_version_lt("6.0.0") @skip_if_redis_enterprise() @@ -431,7 +455,7 @@ def test_acl_users(self, r): @skip_if_server_version_lt("6.0.0") def test_acl_whoami(self, r): username = r.acl_whoami() - assert isinstance(username, str) + assert isinstance(username, (str, bytes)) @pytest.mark.onlynoncluster def test_client_list(self, r): @@ -486,7 +510,7 @@ def test_client_id(self, r): def test_client_trackinginfo(self, r): res = r.client_trackinginfo() assert len(res) > 2 - assert "prefixes" in res + assert "prefixes" in res or b"prefixes" in res @pytest.mark.onlynoncluster @skip_if_server_version_lt("6.0.0") @@ -528,7 +552,27 @@ def test_client_getname(self, r): @skip_if_server_version_lt("2.6.9") def test_client_setname(self, r): assert r.client_setname("redis_py_test") - assert r.client_getname() == "redis_py_test" + assert_resp_response(r, r.client_getname(), "redis_py_test", b"redis_py_test") + + @skip_if_server_version_lt("7.2.0") + def test_client_setinfo(self, r: redis.Redis): + r.ping() + info = r.client_info() + assert info["lib-name"] == "redis-py" + assert info["lib-ver"] == redis.__version__ + assert r.client_setinfo("lib-name", "test") + assert r.client_setinfo("lib-ver", "123") + info = r.client_info() + assert info["lib-name"] == "test" + assert info["lib-ver"] == "123" + r2 = redis.Redis(lib_name="test2", lib_version="1234") + info = r2.client_info() + assert info["lib-name"] == "test2" + assert info["lib-ver"] == "1234" + r3 = redis.Redis(lib_name=None, lib_version=None) + info = r3.client_info() + 
assert info["lib-name"] == "" + assert info["lib-ver"] == "" @@ -693,11 +737,33 @@ def test_client_no_evict(self, r): with pytest.raises(TypeError): r.client_no_evict() + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.2.0") + def test_client_no_touch(self, r): + assert r.client_no_touch("ON") == b"OK" + assert r.client_no_touch("OFF") == b"OK" + with pytest.raises(TypeError): + r.client_no_touch() + + @pytest.mark.onlycluster + @skip_if_server_version_lt("7.2.0") + def test_waitaof(self, r): + # must return a list of 2 elements + assert len(r.waitaof(0, 0, 0)) == 2 + assert len(r.waitaof(1, 0, 0)) == 2 + assert len(r.waitaof(1, 0, 1000)) == 2 + + # value is out of range, value must be between 0 and 1 + with pytest.raises(exceptions.ResponseError): + r.waitaof(2, 0, 0) + with pytest.raises(exceptions.ResponseError): + r.waitaof(-1, 0, 0) + @pytest.mark.onlynoncluster @skip_if_server_version_lt("3.2.0") def test_client_reply(self, r, r_timeout): assert r_timeout.client_reply("ON") == b"OK" - with pytest.raises(exceptions.TimeoutError): + with pytest.raises(exceptions.RedisError): r_timeout.client_reply("OFF") r_timeout.client_reply("SKIP") @@ -809,7 +875,7 @@ def test_lolwut(self, r): @skip_if_server_version_lt("6.2.0") @skip_if_redis_enterprise() def test_reset(self, r): - assert r.reset() == "RESET" + assert_resp_response(r, r.reset(), "RESET", b"RESET") def test_object(self, r): r["a"] = "foo" @@ -861,6 +927,8 @@ def test_slowlog_get(self, r, slowlog): # make sure other attributes are typed correctly assert isinstance(slowlog[0]["start_time"], int) assert isinstance(slowlog[0]["duration"], int) + assert isinstance(slowlog[0]["client_address"], bytes) + assert isinstance(slowlog[0]["client_name"], bytes) # Mock result if we didn't get slowlog complexity info. 
if "complexity" not in slowlog[0]: @@ -1120,8 +1188,12 @@ def test_lcs(self, r): r.mset({"foo": "ohmytext", "bar": "mynewtext"}) assert r.lcs("foo", "bar") == b"mytext" assert r.lcs("foo", "bar", len=True) == 6 - result = [b"matches", [[[4, 7], [5, 8]]], b"len", 6] - assert r.lcs("foo", "bar", idx=True, minmatchlen=3) == result + assert_resp_response( + r, + r.lcs("foo", "bar", idx=True, minmatchlen=3), + [b"matches", [[[4, 7], [5, 8]]], b"len", 6], + {b"matches": [[[4, 7], [5, 8]]], b"len": 6}, + ) with pytest.raises(redis.ResponseError): assert r.lcs("foo", "bar", len=True, idx=True) @@ -1535,7 +1607,7 @@ def test_hrandfield(self, r): assert r.hrandfield("key") is not None assert len(r.hrandfield("key", 2)) == 2 # with values - assert len(r.hrandfield("key", 2, True)) == 4 + assert_resp_response(r, len(r.hrandfield("key", 2, withvalues=True)), 4, 2) # without duplications assert len(r.hrandfield("key", 10)) == 5 # with duplications @@ -1688,17 +1760,26 @@ def test_stralgo_lcs(self, r): assert r.stralgo("LCS", key1, key2, specific_argument="keys") == res # test other labels assert r.stralgo("LCS", value1, value2, len=True) == len(res) - assert r.stralgo("LCS", value1, value2, idx=True) == { - "len": len(res), - "matches": [[(4, 7), (5, 8)], [(2, 3), (0, 1)]], - } - assert r.stralgo("LCS", value1, value2, idx=True, withmatchlen=True) == { - "len": len(res), - "matches": [[4, (4, 7), (5, 8)], [2, (2, 3), (0, 1)]], - } - assert r.stralgo( - "LCS", value1, value2, idx=True, minmatchlen=4, withmatchlen=True - ) == {"len": len(res), "matches": [[4, (4, 7), (5, 8)]]} + assert_resp_response( + r, + r.stralgo("LCS", value1, value2, idx=True), + {"len": len(res), "matches": [[(4, 7), (5, 8)], [(2, 3), (0, 1)]]}, + {"len": len(res), "matches": [[[4, 7], [5, 8]], [[2, 3], [0, 1]]]}, + ) + assert_resp_response( + r, + r.stralgo("LCS", value1, value2, idx=True, withmatchlen=True), + {"len": len(res), "matches": [[4, (4, 7), (5, 8)], [2, (2, 3), (0, 1)]]}, + {"len": len(res), "matches": [[[4, 7], [5, 8], 4], [[2, 3], [0, 1], 2]]}, + ) + assert_resp_response( + r, + r.stralgo( + "LCS", value1, value2, idx=True, withmatchlen=True, minmatchlen=4 + ), + {"len": len(res), "matches": [[4, (4, 7), (5, 8)]]}, + {"len": len(res), "matches": [[[4, 7], [5, 8], 4]]}, + ) @skip_if_server_version_lt("6.0.0") @skip_if_server_version_gte("7.0.0") @@ -1730,6 +1811,57 @@ def test_substr(self, r): assert r.substr("a", 3, 5) == b"345" assert r.substr("a", 3, -2) == b"345678" + def generate_lib_code(self, lib_name): + return f"""#!js api_version=1.0 name={lib_name}\n redis.registerFunction('foo', ()=>{{return 'bar'}})""" # noqa + + def try_delete_libs(self, r, *lib_names): + for lib_name in lib_names: + try: + r.tfunction_delete(lib_name) + except Exception: + pass + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.1.140") + def test_tfunction_load_delete(self, r): + self.try_delete_libs(r, "lib1") + lib_code = self.generate_lib_code("lib1") + assert r.tfunction_load(lib_code) + assert r.tfunction_delete("lib1") + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.1.140") + def test_tfunction_list(self, r): + self.try_delete_libs(r, "lib1", "lib2", "lib3") + assert r.tfunction_load(self.generate_lib_code("lib1")) + assert r.tfunction_load(self.generate_lib_code("lib2")) + assert r.tfunction_load(self.generate_lib_code("lib3")) + + # test error thrown when verbose > 4 + with pytest.raises(redis.exceptions.DataError): + assert r.tfunction_list(verbose=8) + + functions = r.tfunction_list(verbose=1) + 
assert len(functions) == 3 + + expected_names = [b"lib1", b"lib2", b"lib3"] + actual_names = [functions[0][13], functions[1][13], functions[2][13]] + + assert sorted(expected_names) == sorted(actual_names) + assert r.tfunction_delete("lib1") + assert r.tfunction_delete("lib2") + assert r.tfunction_delete("lib3") + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.1.140") + def test_tfcall(self, r): + self.try_delete_libs(r, "lib1") + assert r.tfunction_load(self.generate_lib_code("lib1")) + assert r.tfcall("lib1", "foo") == b"bar" + assert r.tfcall_async("lib1", "foo") == b"bar" + + assert r.tfunction_delete("lib1") + def test_ttl(self, r): r["a"] = "1" assert r.expire("a", 10) @@ -1761,25 +1893,41 @@ def test_type(self, r): def test_blpop(self, r): r.rpush("a", "1", "2") r.rpush("b", "3", "4") - assert r.blpop(["b", "a"], timeout=1) == (b"b", b"3") - assert r.blpop(["b", "a"], timeout=1) == (b"b", b"4") - assert r.blpop(["b", "a"], timeout=1) == (b"a", b"1") - assert r.blpop(["b", "a"], timeout=1) == (b"a", b"2") + assert_resp_response( + r, r.blpop(["b", "a"], timeout=1), (b"b", b"3"), [b"b", b"3"] + ) + assert_resp_response( + r, r.blpop(["b", "a"], timeout=1), (b"b", b"4"), [b"b", b"4"] + ) + assert_resp_response( + r, r.blpop(["b", "a"], timeout=1), (b"a", b"1"), [b"a", b"1"] + ) + assert_resp_response( + r, r.blpop(["b", "a"], timeout=1), (b"a", b"2"), [b"a", b"2"] + ) assert r.blpop(["b", "a"], timeout=1) is None r.rpush("c", "1") - assert r.blpop("c", timeout=1) == (b"c", b"1") + assert_resp_response(r, r.blpop("c", timeout=1), (b"c", b"1"), [b"c", b"1"]) @pytest.mark.onlynoncluster def test_brpop(self, r): r.rpush("a", "1", "2") r.rpush("b", "3", "4") - assert r.brpop(["b", "a"], timeout=1) == (b"b", b"4") - assert r.brpop(["b", "a"], timeout=1) == (b"b", b"3") - assert r.brpop(["b", "a"], timeout=1) == (b"a", b"2") - assert r.brpop(["b", "a"], timeout=1) == (b"a", b"1") + assert_resp_response( + r, r.brpop(["b", "a"], timeout=1), (b"b", b"4"), [b"b", b"4"] + ) + assert_resp_response( + r, r.brpop(["b", "a"], timeout=1), (b"b", b"3"), [b"b", b"3"] + ) + assert_resp_response( + r, r.brpop(["b", "a"], timeout=1), (b"a", b"2"), [b"a", b"2"] + ) + assert_resp_response( + r, r.brpop(["b", "a"], timeout=1), (b"a", b"1"), [b"a", b"1"] + ) assert r.brpop(["b", "a"], timeout=1) is None r.rpush("c", "1") - assert r.brpop("c", timeout=1) == (b"c", b"1") + assert_resp_response(r, r.brpop("c", timeout=1), (b"c", b"1"), [b"c", b"1"]) @pytest.mark.onlynoncluster def test_brpoplpush(self, r): @@ -2147,8 +2295,12 @@ def test_spop_multi_value(self, r): for value in values: assert value in s - - assert r.spop("a", 1) == list(set(s) - set(values)) + assert_resp_response( + r, + r.spop("a", 1), + list(set(s) - set(values)), + set(s) - set(values), + ) def test_srandmember(self, r): s = [b"1", b"2", b"3"] @@ -2199,11 +2351,12 @@ def test_script_debug(self, r): def test_zadd(self, r): mapping = {"a1": 1.0, "a2": 2.0, "a3": 3.0} r.zadd("a", mapping) - assert r.zrange("a", 0, -1, withscores=True) == [ - (b"a1", 1.0), - (b"a2", 2.0), - (b"a3", 3.0), - ] + assert_resp_response( + r, + r.zrange("a", 0, -1, withscores=True), + [(b"a1", 1.0), (b"a2", 2.0), (b"a3", 3.0)], + [[b"a1", 1.0], [b"a2", 2.0], [b"a3", 3.0]], + ) # error cases with pytest.raises(exceptions.DataError): @@ -2220,17 +2373,32 @@ def test_zadd(self, r): def test_zadd_nx(self, r): assert r.zadd("a", {"a1": 1}) == 1 assert r.zadd("a", {"a1": 99, "a2": 2}, nx=True) == 1 - assert r.zrange("a", 0, -1, withscores=True) == [(b"a1", 
1.0), (b"a2", 2.0)] + assert_resp_response( + r, + r.zrange("a", 0, -1, withscores=True), + [(b"a1", 1.0), (b"a2", 2.0)], + [[b"a1", 1.0], [b"a2", 2.0]], + ) def test_zadd_xx(self, r): assert r.zadd("a", {"a1": 1}) == 1 assert r.zadd("a", {"a1": 99, "a2": 2}, xx=True) == 0 - assert r.zrange("a", 0, -1, withscores=True) == [(b"a1", 99.0)] + assert_resp_response( + r, + r.zrange("a", 0, -1, withscores=True), + [(b"a1", 99.0)], + [[b"a1", 99.0]], + ) def test_zadd_ch(self, r): assert r.zadd("a", {"a1": 1}) == 1 assert r.zadd("a", {"a1": 99, "a2": 2}, ch=True) == 2 - assert r.zrange("a", 0, -1, withscores=True) == [(b"a2", 2.0), (b"a1", 99.0)] + assert_resp_response( + r, + r.zrange("a", 0, -1, withscores=True), + [(b"a2", 2.0), (b"a1", 99.0)], + [[b"a2", 2.0], [b"a1", 99.0]], + ) def test_zadd_incr(self, r): assert r.zadd("a", {"a1": 1}) == 1 @@ -2278,7 +2446,12 @@ def test_zdiff(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) r.zadd("b", {"a1": 1, "a2": 2}) assert r.zdiff(["a", "b"]) == [b"a3"] - assert r.zdiff(["a", "b"], withscores=True) == [b"a3", b"3"] + assert_resp_response( + r, + r.zdiff(["a", "b"], withscores=True), + [b"a3", b"3"], + [[b"a3", 3.0]], + ) @pytest.mark.onlynoncluster @skip_if_server_version_lt("6.2.0") @@ -2287,7 +2460,12 @@ def test_zdiffstore(self, r): r.zadd("b", {"a1": 1, "a2": 2}) assert r.zdiffstore("out", ["a", "b"]) assert r.zrange("out", 0, -1) == [b"a3"] - assert r.zrange("out", 0, -1, withscores=True) == [(b"a3", 3.0)] + assert_resp_response( + r, + r.zrange("out", 0, -1, withscores=True), + [(b"a3", 3.0)], + [[b"a3", 3.0]], + ) def test_zincrby(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) @@ -2313,22 +2491,33 @@ def test_zinter(self, r): with pytest.raises(exceptions.DataError): r.zinter(["a", "b", "c"], aggregate="foo", withscores=True) # aggregate with SUM - assert r.zinter(["a", "b", "c"], withscores=True) == [(b"a3", 8), (b"a1", 9)] + assert_resp_response( + r, + r.zinter(["a", "b", "c"], withscores=True), + [(b"a3", 8), (b"a1", 9)], + [[b"a3", 8], [b"a1", 9]], + ) # aggregate with MAX - assert r.zinter(["a", "b", "c"], aggregate="MAX", withscores=True) == [ - (b"a3", 5), - (b"a1", 6), - ] + assert_resp_response( + r, + r.zinter(["a", "b", "c"], aggregate="MAX", withscores=True), + [(b"a3", 5), (b"a1", 6)], + [[b"a3", 5], [b"a1", 6]], + ) # aggregate with MIN - assert r.zinter(["a", "b", "c"], aggregate="MIN", withscores=True) == [ - (b"a1", 1), - (b"a3", 1), - ] + assert_resp_response( + r, + r.zinter(["a", "b", "c"], aggregate="MIN", withscores=True), + [(b"a1", 1), (b"a3", 1)], + [[b"a1", 1], [b"a3", 1]], + ) # with weights - assert r.zinter({"a": 1, "b": 2, "c": 3}, withscores=True) == [ - (b"a3", 20), - (b"a1", 23), - ] + assert_resp_response( + r, + r.zinter({"a": 1, "b": 2, "c": 3}, withscores=True), + [(b"a3", 20), (b"a1", 23)], + [[b"a3", 20], [b"a1", 23]], + ) @pytest.mark.onlynoncluster @skip_if_server_version_lt("7.0.0") @@ -2345,7 +2534,12 @@ def test_zinterstore_sum(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zinterstore("d", ["a", "b", "c"]) == 2 - assert r.zrange("d", 0, -1, withscores=True) == [(b"a3", 8), (b"a1", 9)] + assert_resp_response( + r, + r.zrange("d", 0, -1, withscores=True), + [(b"a3", 8), (b"a1", 9)], + [[b"a3", 8], [b"a1", 9]], + ) @pytest.mark.onlynoncluster def test_zinterstore_max(self, r): @@ -2353,7 +2547,12 @@ def test_zinterstore_max(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zinterstore("d", 
["a", "b", "c"], aggregate="MAX") == 2 - assert r.zrange("d", 0, -1, withscores=True) == [(b"a3", 5), (b"a1", 6)] + assert_resp_response( + r, + r.zrange("d", 0, -1, withscores=True), + [(b"a3", 5), (b"a1", 6)], + [[b"a3", 5], [b"a1", 6]], + ) @pytest.mark.onlynoncluster def test_zinterstore_min(self, r): @@ -2361,7 +2560,12 @@ def test_zinterstore_min(self, r): r.zadd("b", {"a1": 2, "a2": 3, "a3": 5}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zinterstore("d", ["a", "b", "c"], aggregate="MIN") == 2 - assert r.zrange("d", 0, -1, withscores=True) == [(b"a1", 1), (b"a3", 3)] + assert_resp_response( + r, + r.zrange("d", 0, -1, withscores=True), + [(b"a1", 1), (b"a3", 3)], + [[b"a1", 1], [b"a3", 3]], + ) @pytest.mark.onlynoncluster def test_zinterstore_with_weight(self, r): @@ -2369,23 +2573,36 @@ def test_zinterstore_with_weight(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zinterstore("d", {"a": 1, "b": 2, "c": 3}) == 2 - assert r.zrange("d", 0, -1, withscores=True) == [(b"a3", 20), (b"a1", 23)] + assert_resp_response( + r, + r.zrange("d", 0, -1, withscores=True), + [(b"a3", 20), (b"a1", 23)], + [[b"a3", 20], [b"a1", 23]], + ) @skip_if_server_version_lt("4.9.0") def test_zpopmax(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) - assert r.zpopmax("a") == [(b"a3", 3)] - + assert_resp_response(r, r.zpopmax("a"), [(b"a3", 3)], [b"a3", 3.0]) # with count - assert r.zpopmax("a", count=2) == [(b"a2", 2), (b"a1", 1)] + assert_resp_response( + r, + r.zpopmax("a", count=2), + [(b"a2", 2), (b"a1", 1)], + [[b"a2", 2], [b"a1", 1]], + ) @skip_if_server_version_lt("4.9.0") def test_zpopmin(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) - assert r.zpopmin("a") == [(b"a1", 1)] - + assert_resp_response(r, r.zpopmin("a"), [(b"a1", 1)], [b"a1", 1.0]) # with count - assert r.zpopmin("a", count=2) == [(b"a2", 2), (b"a3", 3)] + assert_resp_response( + r, + r.zpopmin("a", count=2), + [(b"a2", 2), (b"a3", 3)], + [[b"a2", 2], [b"a3", 3]], + ) @skip_if_server_version_lt("6.2.0") def test_zrandemember(self, r): @@ -2393,7 +2610,12 @@ def test_zrandemember(self, r): assert r.zrandmember("a") is not None assert len(r.zrandmember("a", 2)) == 2 # with scores - assert len(r.zrandmember("a", 2, True)) == 4 + assert_resp_response( + r, + len(r.zrandmember("a", 2, withscores=True)), + 4, + 2, + ) # without duplications assert len(r.zrandmember("a", 10)) == 5 # with duplications @@ -2404,49 +2626,86 @@ def test_zrandemember(self, r): def test_bzpopmax(self, r): r.zadd("a", {"a1": 1, "a2": 2}) r.zadd("b", {"b1": 10, "b2": 20}) - assert r.bzpopmax(["b", "a"], timeout=1) == (b"b", b"b2", 20) - assert r.bzpopmax(["b", "a"], timeout=1) == (b"b", b"b1", 10) - assert r.bzpopmax(["b", "a"], timeout=1) == (b"a", b"a2", 2) - assert r.bzpopmax(["b", "a"], timeout=1) == (b"a", b"a1", 1) + assert_resp_response( + r, r.bzpopmax(["b", "a"], timeout=1), (b"b", b"b2", 20), [b"b", b"b2", 20] + ) + assert_resp_response( + r, r.bzpopmax(["b", "a"], timeout=1), (b"b", b"b1", 10), [b"b", b"b1", 10] + ) + assert_resp_response( + r, r.bzpopmax(["b", "a"], timeout=1), (b"a", b"a2", 2), [b"a", b"a2", 2] + ) + assert_resp_response( + r, r.bzpopmax(["b", "a"], timeout=1), (b"a", b"a1", 1), [b"a", b"a1", 1] + ) assert r.bzpopmax(["b", "a"], timeout=1) is None r.zadd("c", {"c1": 100}) - assert r.bzpopmax("c", timeout=1) == (b"c", b"c1", 100) + assert_resp_response( + r, r.bzpopmax("c", timeout=1), (b"c", b"c1", 100), [b"c", b"c1", 100] + ) @pytest.mark.onlynoncluster 
@skip_if_server_version_lt("4.9.0") def test_bzpopmin(self, r): r.zadd("a", {"a1": 1, "a2": 2}) r.zadd("b", {"b1": 10, "b2": 20}) - assert r.bzpopmin(["b", "a"], timeout=1) == (b"b", b"b1", 10) - assert r.bzpopmin(["b", "a"], timeout=1) == (b"b", b"b2", 20) - assert r.bzpopmin(["b", "a"], timeout=1) == (b"a", b"a1", 1) - assert r.bzpopmin(["b", "a"], timeout=1) == (b"a", b"a2", 2) + assert_resp_response( + r, r.bzpopmin(["b", "a"], timeout=1), (b"b", b"b1", 10), [b"b", b"b1", 10] + ) + assert_resp_response( + r, r.bzpopmin(["b", "a"], timeout=1), (b"b", b"b2", 20), [b"b", b"b2", 20] + ) + assert_resp_response( + r, r.bzpopmin(["b", "a"], timeout=1), (b"a", b"a1", 1), [b"a", b"a1", 1] + ) + assert_resp_response( + r, r.bzpopmin(["b", "a"], timeout=1), (b"a", b"a2", 2), [b"a", b"a2", 2] + ) assert r.bzpopmin(["b", "a"], timeout=1) is None r.zadd("c", {"c1": 100}) - assert r.bzpopmin("c", timeout=1) == (b"c", b"c1", 100) + assert_resp_response( + r, r.bzpopmin("c", timeout=1), (b"c", b"c1", 100), [b"c", b"c1", 100] + ) @pytest.mark.onlynoncluster @skip_if_server_version_lt("7.0.0") def test_zmpop(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) - res = [b"a", [[b"a1", b"1"], [b"a2", b"2"]]] - assert r.zmpop("2", ["b", "a"], min=True, count=2) == res + assert_resp_response( + r, + r.zmpop("2", ["b", "a"], min=True, count=2), + [b"a", [[b"a1", b"1"], [b"a2", b"2"]]], + [b"a", [[b"a1", 1.0], [b"a2", 2.0]]], + ) with pytest.raises(redis.DataError): r.zmpop("2", ["b", "a"], count=2) r.zadd("b", {"b1": 10, "ab": 9, "b3": 8}) - assert r.zmpop("2", ["b", "a"], max=True) == [b"b", [[b"b1", b"10"]]] + assert_resp_response( + r, + r.zmpop("2", ["b", "a"], max=True), + [b"b", [[b"b1", b"10"]]], + [b"b", [[b"b1", 10.0]]], + ) @pytest.mark.onlynoncluster @skip_if_server_version_lt("7.0.0") def test_bzmpop(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) - res = [b"a", [[b"a1", b"1"], [b"a2", b"2"]]] - assert r.bzmpop(1, "2", ["b", "a"], min=True, count=2) == res + assert_resp_response( + r, + r.bzmpop(1, "2", ["b", "a"], min=True, count=2), + [b"a", [[b"a1", b"1"], [b"a2", b"2"]]], + [b"a", [[b"a1", 1.0], [b"a2", 2.0]]], + ) with pytest.raises(redis.DataError): r.bzmpop(1, "2", ["b", "a"], count=2) r.zadd("b", {"b1": 10, "ab": 9, "b3": 8}) - res = [b"b", [[b"b1", b"10"]]] - assert r.bzmpop(0, "2", ["b", "a"], max=True) == res + assert_resp_response( + r, + r.bzmpop(0, "2", ["b", "a"], max=True), + [b"b", [[b"b1", b"10"]]], + [b"b", [[b"b1", 10.0]]], + ) assert r.bzmpop(1, "2", ["foo", "bar"], max=True) is None def test_zrange(self, r): @@ -2457,14 +2716,24 @@ def test_zrange(self, r): assert r.zrange("a", 0, 2, desc=True) == [b"a3", b"a2", b"a1"] # withscores - assert r.zrange("a", 0, 1, withscores=True) == [(b"a1", 1.0), (b"a2", 2.0)] - assert r.zrange("a", 1, 2, withscores=True) == [(b"a2", 2.0), (b"a3", 3.0)] + assert_resp_response( + r, + r.zrange("a", 0, 1, withscores=True), + [(b"a1", 1.0), (b"a2", 2.0)], + [[b"a1", 1.0], [b"a2", 2.0]], + ) + assert_resp_response( + r, + r.zrange("a", 1, 2, withscores=True), + [(b"a2", 2.0), (b"a3", 3.0)], + [[b"a2", 2.0], [b"a3", 3.0]], + ) - # custom score function - assert r.zrange("a", 0, 1, withscores=True, score_cast_func=int) == [ - (b"a1", 1), - (b"a2", 2), - ] + # # custom score function + # assert r.zrange("a", 0, 1, withscores=True, score_cast_func=int) == [ + # (b"a1", 1), + # (b"a2", 2), + # ] def test_zrange_errors(self, r): with pytest.raises(exceptions.DataError): @@ -2496,14 +2765,20 @@ def test_zrange_params(self, r): b"a3", b"a2", ] - assert 
r.zrange("a", 2, 4, byscore=True, withscores=True) == [ - (b"a2", 2.0), - (b"a3", 3.0), - (b"a4", 4.0), - ] - assert r.zrange( - "a", 4, 2, desc=True, byscore=True, withscores=True, score_cast_func=int - ) == [(b"a4", 4), (b"a3", 3), (b"a2", 2)] + assert_resp_response( + r, + r.zrange("a", 2, 4, byscore=True, withscores=True), + [(b"a2", 2.0), (b"a3", 3.0), (b"a4", 4.0)], + [[b"a2", 2.0], [b"a3", 3.0], [b"a4", 4.0]], + ) + assert_resp_response( + r, + r.zrange( + "a", 4, 2, desc=True, byscore=True, withscores=True, score_cast_func=int + ), + [(b"a4", 4), (b"a3", 3), (b"a2", 2)], + [[b"a4", 4], [b"a3", 3], [b"a2", 2]], + ) # rev assert r.zrange("a", 0, 1, desc=True) == [b"a5", b"a4"] @@ -2516,7 +2791,12 @@ def test_zrangestore(self, r): assert r.zrange("b", 0, -1) == [b"a1", b"a2"] assert r.zrangestore("b", "a", 1, 2) assert r.zrange("b", 0, -1) == [b"a2", b"a3"] - assert r.zrange("b", 0, -1, withscores=True) == [(b"a2", 2), (b"a3", 3)] + assert_resp_response( + r, + r.zrange("b", 0, -1, withscores=True), + [(b"a2", 2), (b"a3", 3)], + [[b"a2", 2], [b"a3", 3]], + ) # reversed order assert r.zrangestore("b", "a", 1, 2, desc=True) assert r.zrange("b", 0, -1) == [b"a1", b"a2"] @@ -2551,16 +2831,18 @@ def test_zrangebyscore(self, r): # slicing with start/num assert r.zrangebyscore("a", 2, 4, start=1, num=2) == [b"a3", b"a4"] # withscores - assert r.zrangebyscore("a", 2, 4, withscores=True) == [ - (b"a2", 2.0), - (b"a3", 3.0), - (b"a4", 4.0), - ] - assert r.zrangebyscore("a", 2, 4, withscores=True, score_cast_func=int) == [ - (b"a2", 2), - (b"a3", 3), - (b"a4", 4), - ] + assert_resp_response( + r, + r.zrangebyscore("a", 2, 4, withscores=True), + [(b"a2", 2.0), (b"a3", 3.0), (b"a4", 4.0)], + [[b"a2", 2.0], [b"a3", 3.0], [b"a4", 4.0]], + ) + assert_resp_response( + r, + r.zrangebyscore("a", 2, 4, withscores=True, score_cast_func=int), + [(b"a2", 2), (b"a3", 3), (b"a4", 4)], + [[b"a2", 2], [b"a3", 3], [b"a4", 4]], + ) def test_zrank(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) @@ -2568,6 +2850,15 @@ def test_zrank(self, r): assert r.zrank("a", "a2") == 1 assert r.zrank("a", "a6") is None + @skip_if_server_version_lt("7.2.0") + def test_zrank_withscore(self, r: redis.Redis): + r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) + assert r.zrank("a", "a1") == 0 + assert r.zrank("a", "a2") == 1 + assert r.zrank("a", "a6") is None + assert r.zrank("a", "a3", withscore=True) == [2, "3"] + assert r.zrank("a", "a6", withscore=True) is None + def test_zrem(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) assert r.zrem("a", "a2") == 1 @@ -2608,32 +2899,45 @@ def test_zrevrange(self, r): assert r.zrevrange("a", 1, 2) == [b"a2", b"a1"] # withscores - assert r.zrevrange("a", 0, 1, withscores=True) == [(b"a3", 3.0), (b"a2", 2.0)] - assert r.zrevrange("a", 1, 2, withscores=True) == [(b"a2", 2.0), (b"a1", 1.0)] + assert_resp_response( + r, + r.zrevrange("a", 0, 1, withscores=True), + [(b"a3", 3.0), (b"a2", 2.0)], + [[b"a3", 3.0], [b"a2", 2.0]], + ) + assert_resp_response( + r, + r.zrevrange("a", 1, 2, withscores=True), + [(b"a2", 2.0), (b"a1", 1.0)], + [[b"a2", 2.0], [b"a1", 1.0]], + ) - # custom score function - assert r.zrevrange("a", 0, 1, withscores=True, score_cast_func=int) == [ - (b"a3", 3.0), - (b"a2", 2.0), - ] + # # custom score function + # assert r.zrevrange("a", 0, 1, withscores=True, score_cast_func=int) == [ + # (b"a3", 3.0), + # (b"a2", 2.0), + # ] def test_zrevrangebyscore(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) assert 
r.zrevrangebyscore("a", 4, 2) == [b"a4", b"a3", b"a2"] # slicing with start/num assert r.zrevrangebyscore("a", 4, 2, start=1, num=2) == [b"a3", b"a2"] + # withscores - assert r.zrevrangebyscore("a", 4, 2, withscores=True) == [ - (b"a4", 4.0), - (b"a3", 3.0), - (b"a2", 2.0), - ] + assert_resp_response( + r, + r.zrevrangebyscore("a", 4, 2, withscores=True), + [(b"a4", 4.0), (b"a3", 3.0), (b"a2", 2.0)], + [[b"a4", 4.0], [b"a3", 3.0], [b"a2", 2.0]], + ) # custom score function - assert r.zrevrangebyscore("a", 4, 2, withscores=True, score_cast_func=int) == [ - (b"a4", 4), - (b"a3", 3), - (b"a2", 2), - ] + assert_resp_response( + r, + r.zrevrangebyscore("a", 4, 2, withscores=True, score_cast_func=int), + [(b"a4", 4.0), (b"a3", 3.0), (b"a2", 2.0)], + [[b"a4", 4.0], [b"a3", 3.0], [b"a2", 2.0]], + ) def test_zrevrank(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) @@ -2641,6 +2945,15 @@ def test_zrevrank(self, r): assert r.zrevrank("a", "a2") == 3 assert r.zrevrank("a", "a6") is None + @skip_if_server_version_lt("7.2.0") + def test_zrevrank_withscore(self, r): + r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) + assert r.zrevrank("a", "a1") == 4 + assert r.zrevrank("a", "a2") == 3 + assert r.zrevrank("a", "a6") is None + assert r.zrevrank("a", "a3", withscore=True) == [2, "3"] + assert r.zrevrank("a", "a6", withscore=True) is None + def test_zscore(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) assert r.zscore("a", "a1") == 1.0 @@ -2655,33 +2968,33 @@ def test_zunion(self, r): r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) # sum assert r.zunion(["a", "b", "c"]) == [b"a2", b"a4", b"a3", b"a1"] - assert r.zunion(["a", "b", "c"], withscores=True) == [ - (b"a2", 3), - (b"a4", 4), - (b"a3", 8), - (b"a1", 9), - ] + assert_resp_response( + r, + r.zunion(["a", "b", "c"], withscores=True), + [(b"a2", 3), (b"a4", 4), (b"a3", 8), (b"a1", 9)], + [[b"a2", 3], [b"a4", 4], [b"a3", 8], [b"a1", 9]], + ) # max - assert r.zunion(["a", "b", "c"], aggregate="MAX", withscores=True) == [ - (b"a2", 2), - (b"a4", 4), - (b"a3", 5), - (b"a1", 6), - ] + assert_resp_response( + r, + r.zunion(["a", "b", "c"], aggregate="MAX", withscores=True), + [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)], + [[b"a2", 2], [b"a4", 4], [b"a3", 5], [b"a1", 6]], + ) # min - assert r.zunion(["a", "b", "c"], aggregate="MIN", withscores=True) == [ - (b"a1", 1), - (b"a2", 1), - (b"a3", 1), - (b"a4", 4), - ] + assert_resp_response( + r, + r.zunion(["a", "b", "c"], aggregate="MIN", withscores=True), + [(b"a1", 1), (b"a2", 1), (b"a3", 1), (b"a4", 4)], + [[b"a1", 1], [b"a2", 1], [b"a3", 1], [b"a4", 4]], + ) # with weight - assert r.zunion({"a": 1, "b": 2, "c": 3}, withscores=True) == [ - (b"a2", 5), - (b"a4", 12), - (b"a3", 20), - (b"a1", 23), - ] + assert_resp_response( + r, + r.zunion({"a": 1, "b": 2, "c": 3}, withscores=True), + [(b"a2", 5), (b"a4", 12), (b"a3", 20), (b"a1", 23)], + [[b"a2", 5], [b"a4", 12], [b"a3", 20], [b"a1", 23]], + ) @pytest.mark.onlynoncluster def test_zunionstore_sum(self, r): @@ -2689,12 +3002,12 @@ def test_zunionstore_sum(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zunionstore("d", ["a", "b", "c"]) == 4 - assert r.zrange("d", 0, -1, withscores=True) == [ - (b"a2", 3), - (b"a4", 4), - (b"a3", 8), - (b"a1", 9), - ] + assert_resp_response( + r, + r.zrange("d", 0, -1, withscores=True), + [(b"a2", 3), (b"a4", 4), (b"a3", 8), (b"a1", 9)], + [[b"a2", 3], [b"a4", 4], [b"a3", 8], [b"a1", 9]], + ) @pytest.mark.onlynoncluster def 
test_zunionstore_max(self, r): @@ -2702,12 +3015,12 @@ def test_zunionstore_max(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zunionstore("d", ["a", "b", "c"], aggregate="MAX") == 4 - assert r.zrange("d", 0, -1, withscores=True) == [ - (b"a2", 2), - (b"a4", 4), - (b"a3", 5), - (b"a1", 6), - ] + assert_resp_response( + r, + r.zrange("d", 0, -1, withscores=True), + [(b"a2", 2), (b"a4", 4), (b"a3", 5), (b"a1", 6)], + [[b"a2", 2], [b"a4", 4], [b"a3", 5], [b"a1", 6]], + ) @pytest.mark.onlynoncluster def test_zunionstore_min(self, r): @@ -2715,12 +3028,12 @@ def test_zunionstore_min(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 4}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zunionstore("d", ["a", "b", "c"], aggregate="MIN") == 4 - assert r.zrange("d", 0, -1, withscores=True) == [ - (b"a1", 1), - (b"a2", 2), - (b"a3", 3), - (b"a4", 4), - ] + assert_resp_response( + r, + r.zrange("d", 0, -1, withscores=True), + [(b"a1", 1), (b"a2", 2), (b"a3", 3), (b"a4", 4)], + [[b"a1", 1], [b"a2", 2], [b"a3", 3], [b"a4", 4]], + ) @pytest.mark.onlynoncluster def test_zunionstore_with_weight(self, r): @@ -2728,12 +3041,12 @@ def test_zunionstore_with_weight(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zunionstore("d", {"a": 1, "b": 2, "c": 3}) == 4 - assert r.zrange("d", 0, -1, withscores=True) == [ - (b"a2", 5), - (b"a4", 12), - (b"a3", 20), - (b"a1", 23), - ] + assert_resp_response( + r, + r.zrange("d", 0, -1, withscores=True), + [(b"a2", 5), (b"a4", 12), (b"a3", 20), (b"a1", 23)], + [[b"a2", 5], [b"a4", 12], [b"a3", 20], [b"a1", 23]], + ) @skip_if_server_version_lt("6.1.240") def test_zmscore(self, r): @@ -3248,11 +3561,12 @@ def test_geohash(self, r): "place2", ) r.geoadd("barcelona", values) - assert r.geohash("barcelona", "place1", "place2", "place3") == [ - "sp3e9yg3kd0", - "sp3e9cbc3t0", - None, - ] + assert_resp_response( + r, + r.geohash("barcelona", "place1", "place2", "place3"), + ["sp3e9yg3kd0", "sp3e9cbc3t0", None], + [b"sp3e9yg3kd0", b"sp3e9cbc3t0", None], + ) @skip_unless_arch_bits(64) @skip_if_server_version_lt("3.2.0") @@ -3264,10 +3578,18 @@ def test_geopos(self, r): ) r.geoadd("barcelona", values) # redis uses 52 bits precision, hereby small errors may be introduced. - assert r.geopos("barcelona", "place1", "place2") == [ - (2.19093829393386841, 41.43379028184083523), - (2.18737632036209106, 41.40634178640635099), - ] + assert_resp_response( + r, + r.geopos("barcelona", "place1", "place2"), + [ + (2.19093829393386841, 41.43379028184083523), + (2.18737632036209106, 41.40634178640635099), + ], + [ + [2.19093829393386841, 41.43379028184083523], + [2.18737632036209106, 41.40634178640635099], + ], + ) @skip_if_server_version_lt("4.0.0") def test_geopos_no_value(self, r): @@ -3508,6 +3830,12 @@ def test_geosearchstore_dist(self, r): # instead of save the geo score, the distance is saved. 
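With `storedist=True`, GEOSEARCHSTORE writes each matching member's distance from the search origin, expressed in the requested unit, into the destination sorted set, rather than the raw 52-bit geohash score. A hedged sketch of the kind of call this assertion depends on (keyword names follow redis-py's `geosearchstore` signature; the concrete coordinates and radius are illustrative, not copied from the test):

r.geosearchstore(
    "places_barcelona",  # destination zset; scores become distances
    "barcelona",         # source geo set populated earlier in the test
    longitude=2.191,
    latitude=41.433,
    radius=1000,
    unit="m",
    storedist=True,
)

That is why `zscore` on the destination returns a distance in meters in the assertion below.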
assert r.zscore("places_barcelona", "place1") == 88.05060698409301 + @skip_if_server_version_lt("3.2.0") + def test_georadius_Issue2609(self, r): + # test for issue #2609 (Geo search functions don't work with execute_command) + r.geoadd(name="my-key", values=[1, 2, "data"]) + assert r.execute_command("GEORADIUS", "my-key", 1, 2, 400, "m") == [b"data"] + @skip_if_server_version_lt("3.2.0") def test_georadius(self, r): values = (2.1909389952632, 41.433791470673, "place1") + ( @@ -3801,7 +4129,7 @@ def test_xadd_explicit_ms(self, r: redis.Redis): ms = message_id[: message_id.index(b"-")] assert ms == b"9999999999999999999" - @skip_if_server_version_lt("6.2.0") + @skip_if_server_version_lt("7.0.0") def test_xautoclaim(self, r): stream = "stream" group = "group" @@ -3816,7 +4144,7 @@ def test_xautoclaim(self, r): # trying to claim a message that isn't already pending doesn't # do anything response = r.xautoclaim(stream, group, consumer2, min_idle_time=0) - assert response == [b"0-0", []] + assert response == [b"0-0", [], []] # read the group as consumer1 to initially claim the messages r.xreadgroup(group, consumer1, streams={stream: ">"}) @@ -4060,7 +4388,7 @@ def test_xgroup_setid(self, r): ] assert r.xinfo_groups(stream) == expected - @skip_if_server_version_lt("5.0.0") + @skip_if_server_version_lt("7.2.0") def test_xinfo_consumers(self, r): stream = "stream" group = "group" @@ -4076,8 +4404,8 @@ def test_xinfo_consumers(self, r): info = r.xinfo_consumers(stream, group) assert len(info) == 2 expected = [ - {"name": consumer1.encode(), "pending": 1}, - {"name": consumer2.encode(), "pending": 2}, + {"name": consumer1.encode(), "pending": 1, "inactive": 2}, + {"name": consumer2.encode(), "pending": 2, "inactive": 2}, ] # we can't determine the idle time, so just make sure it's an int @@ -4108,7 +4436,12 @@ def test_xinfo_stream_full(self, r): info = r.xinfo_stream(stream, full=True) assert info["length"] == 1 - assert m1 in info["entries"] + assert_resp_response_in( + r, + m1, + info["entries"], + info["entries"].keys(), + ) assert len(info["groups"]) == 1 @skip_if_server_version_lt("5.0.0") @@ -4249,25 +4582,39 @@ def test_xread(self, r): m1 = r.xadd(stream, {"foo": "bar"}) m2 = r.xadd(stream, {"bing": "baz"}) - expected = [ - [ - stream.encode(), - [get_stream_message(r, stream, m1), get_stream_message(r, stream, m2)], - ] + stream_name = stream.encode() + expected_entries = [ + get_stream_message(r, stream, m1), + get_stream_message(r, stream, m2), ] # xread starting at 0 returns both messages - assert r.xread(streams={stream: 0}) == expected + assert_resp_response( + r, + r.xread(streams={stream: 0}), + [[stream_name, expected_entries]], + {stream_name: [expected_entries]}, + ) - expected = [[stream.encode(), [get_stream_message(r, stream, m1)]]] + expected_entries = [get_stream_message(r, stream, m1)] # xread starting at 0 and count=1 returns only the first message - assert r.xread(streams={stream: 0}, count=1) == expected + assert_resp_response( + r, + r.xread(streams={stream: 0}, count=1), + [[stream_name, expected_entries]], + {stream_name: [expected_entries]}, + ) - expected = [[stream.encode(), [get_stream_message(r, stream, m2)]]] + expected_entries = [get_stream_message(r, stream, m2)] # xread starting at m1 returns only the second message - assert r.xread(streams={stream: m1}) == expected + assert_resp_response( + r, + r.xread(streams={stream: m1}), + [[stream_name, expected_entries]], + {stream_name: [expected_entries]}, + ) # xread starting at the last message returns an empty 
list - assert r.xread(streams={stream: m2}) == [] + assert_resp_response(r, r.xread(streams={stream: m2}), [], {}) @skip_if_server_version_lt("5.0.0") def test_xreadgroup(self, r): @@ -4278,21 +4625,32 @@ def test_xreadgroup(self, r): m2 = r.xadd(stream, {"bing": "baz"}) r.xgroup_create(stream, group, 0) - expected = [ - [ - stream.encode(), - [get_stream_message(r, stream, m1), get_stream_message(r, stream, m2)], - ] + stream_name = stream.encode() + expected_entries = [ + get_stream_message(r, stream, m1), + get_stream_message(r, stream, m2), ] + # xread starting at 0 returns both messages - assert r.xreadgroup(group, consumer, streams={stream: ">"}) == expected + assert_resp_response( + r, + r.xreadgroup(group, consumer, streams={stream: ">"}), + [[stream_name, expected_entries]], + {stream_name: [expected_entries]}, + ) r.xgroup_destroy(stream, group) r.xgroup_create(stream, group, 0) - expected = [[stream.encode(), [get_stream_message(r, stream, m1)]]] + expected_entries = [get_stream_message(r, stream, m1)] + # xread with count=1 returns only the first message - assert r.xreadgroup(group, consumer, streams={stream: ">"}, count=1) == expected + assert_resp_response( + r, + r.xreadgroup(group, consumer, streams={stream: ">"}, count=1), + [[stream_name, expected_entries]], + {stream_name: [expected_entries]}, + ) r.xgroup_destroy(stream, group) @@ -4300,27 +4658,37 @@ def test_xreadgroup(self, r): # will only find messages added after this r.xgroup_create(stream, group, "$") - expected = [] # xread starting after the last message returns an empty message list - assert r.xreadgroup(group, consumer, streams={stream: ">"}) == expected + assert_resp_response( + r, r.xreadgroup(group, consumer, streams={stream: ">"}), [], {} + ) # xreadgroup with noack does not have any items in the PEL r.xgroup_destroy(stream, group) r.xgroup_create(stream, group, "0") - assert ( - len(r.xreadgroup(group, consumer, streams={stream: ">"}, noack=True)[0][1]) - == 2 - ) - # now there should be nothing pending - assert len(r.xreadgroup(group, consumer, streams={stream: "0"})[0][1]) == 0 + res = r.xreadgroup(group, consumer, streams={stream: ">"}, noack=True) + empty_res = r.xreadgroup(group, consumer, streams={stream: "0"}) + if is_resp2_connection(r): + assert len(res[0][1]) == 2 + # now there should be nothing pending + assert len(empty_res[0][1]) == 0 + else: + assert len(res[stream_name][0]) == 2 + # now there should be nothing pending + assert len(empty_res[stream_name][0]) == 0 r.xgroup_destroy(stream, group) r.xgroup_create(stream, group, "0") # delete all the messages in the stream - expected = [[stream.encode(), [(m1, {}), (m2, {})]]] + expected_entries = [(m1, {}), (m2, {})] r.xreadgroup(group, consumer, streams={stream: ">"}) r.xtrim(stream, 0) - assert r.xreadgroup(group, consumer, streams={stream: "0"}) == expected + assert_resp_response( + r, + r.xreadgroup(group, consumer, streams={stream: "0"}), + [[stream_name, expected_entries]], + {stream_name: [expected_entries]}, + ) @skip_if_server_version_lt("5.0.0") def test_xrevrange(self, r): @@ -4586,7 +4954,7 @@ def test_command_list(self, r: redis.Redis): @skip_if_redis_enterprise() def test_command_getkeys(self, r): res = r.command_getkeys("MSET", "a", "b", "c", "d", "e", "f") - assert res == ["a", "c", "e"] + assert_resp_response(r, res, ["a", "c", "e"], [b"a", b"c", b"e"]) res = r.command_getkeys( "EVAL", '"not consulted"', @@ -4599,7 +4967,9 @@ def test_command_getkeys(self, r): "arg3", "argN", ) - assert res == ["key1", "key2", "key3"] + 
assert_resp_response( + r, res, ["key1", "key2", "key3"], [b"key1", b"key2", b"key3"] + ) @skip_if_server_version_lt("2.8.13") def test_command(self, r): @@ -4613,12 +4983,17 @@ def test_command(self, r): @skip_if_server_version_lt("7.0.0") @skip_if_redis_enterprise() def test_command_getkeysandflags(self, r: redis.Redis): - res = [ - [b"mylist1", [b"RW", b"access", b"delete"]], - [b"mylist2", [b"RW", b"insert"]], - ] - assert res == r.command_getkeysandflags( - "LMOVE", "mylist1", "mylist2", "left", "left" + assert_resp_response( + r, + r.command_getkeysandflags("LMOVE", "mylist1", "mylist2", "left", "left"), + [ + [b"mylist1", [b"RW", b"access", b"delete"]], + [b"mylist2", [b"RW", b"insert"]], + ], + [ + [b"mylist1", {b"RW", b"access", b"delete"}], + [b"mylist2", {b"RW", b"insert"}], + ], ) @pytest.mark.onlynoncluster @@ -4711,9 +5086,12 @@ def test_shutdown_with_params(self, r: redis.Redis): r.execute_command.assert_called_with("SHUTDOWN", "ABORT") @pytest.mark.replica + @pytest.mark.xfail(strict=False) @skip_if_server_version_lt("2.8.0") @skip_if_redis_enterprise() def test_sync(self, r): + r.flushdb() + time.sleep(1) r2 = redis.Redis(port=6380, decode_responses=False) res = r2.sync() assert b"REDIS" in res @@ -4726,6 +5104,38 @@ def test_psync(self, r): res = r2.psync(r2.client_id(), 1) assert b"FULLRESYNC" in res + @pytest.mark.onlynoncluster + def test_interrupted_command(self, r: redis.Redis): + """ + Regression test for issue #1128: An Un-handled BaseException + will leave the socket with un-read response to a previous + command. + """ + + ok = False + + def helper(): + with pytest.raises(CancelledError): + # blocking pop + with patch.object( + r.connection._parser, "read_response", side_effect=CancelledError + ): + r.brpop(["nonexist"]) + # if all is well, we can continue. 
+ r.set("status", "down") # should not hang + nonlocal ok + ok = True + + thread = threading.Thread(target=helper) + thread.start() + thread.join(0.1) + try: + assert not thread.is_alive() + assert ok + finally: + # disconnect here so that fixture cleanup can proceed + r.connection.disconnect() + @pytest.mark.onlynoncluster class TestBinarySave: diff --git a/tests/test_connect.py b/tests/test_connect.py new file mode 100644 index 0000000000..f07750dc80 --- /dev/null +++ b/tests/test_connect.py @@ -0,0 +1,191 @@ +import logging +import re +import socket +import socketserver +import ssl +import threading + +import pytest +from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection + +from .ssl_utils import get_ssl_filename + +_logger = logging.getLogger(__name__) + + +_CLIENT_NAME = "test-suite-client" +_CMD_SEP = b"\r\n" +_SUCCESS_RESP = b"+OK" + _CMD_SEP +_ERROR_RESP = b"-ERR" + _CMD_SEP +_SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP} + + +@pytest.fixture +def tcp_address(): + with socket.socket() as sock: + sock.bind(("127.0.0.1", 0)) + return sock.getsockname() + + +@pytest.fixture +def uds_address(tmpdir): + return tmpdir / "uds.sock" + + +def test_tcp_connect(tcp_address): + host, port = tcp_address + conn = Connection(host=host, port=port, client_name=_CLIENT_NAME, socket_timeout=10) + _assert_connect(conn, tcp_address) + + +def test_uds_connect(uds_address): + path = str(uds_address) + conn = UnixDomainSocketConnection(path, client_name=_CLIENT_NAME, socket_timeout=10) + _assert_connect(conn, path) + + +@pytest.mark.ssl +def test_tcp_ssl_connect(tcp_address): + host, port = tcp_address + certfile = get_ssl_filename("server-cert.pem") + keyfile = get_ssl_filename("server-key.pem") + conn = SSLConnection( + host=host, + port=port, + client_name=_CLIENT_NAME, + ssl_ca_certs=certfile, + socket_timeout=10, + ) + _assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile) + + +def _assert_connect(conn, server_address, certfile=None, keyfile=None): + if isinstance(server_address, str): + if not _RedisUDSServer: + pytest.skip("Unix domain sockets are not supported on this platform") + server = _RedisUDSServer(server_address, _RedisRequestHandler) + else: + server = _RedisTCPServer( + server_address, _RedisRequestHandler, certfile=certfile, keyfile=keyfile + ) + with server as aserver: + t = threading.Thread(target=aserver.serve_forever) + t.start() + try: + aserver.wait_online() + conn.connect() + conn.disconnect() + finally: + aserver.stop() + t.join(timeout=5) + + +class _RedisTCPServer(socketserver.TCPServer): + def __init__(self, *args, certfile=None, keyfile=None, **kw) -> None: + self._ready_event = threading.Event() + self._stop_requested = False + self._certfile = certfile + self._keyfile = keyfile + super().__init__(*args, **kw) + + def service_actions(self): + self._ready_event.set() + + def wait_online(self): + self._ready_event.wait() + + def stop(self): + self._stop_requested = True + self.shutdown() + + def is_serving(self): + return not self._stop_requested + + def get_request(self): + if self._certfile is None: + return super().get_request() + newsocket, fromaddr = self.socket.accept() + connstream = ssl.wrap_socket( + newsocket, + server_side=True, + certfile=self._certfile, + keyfile=self._keyfile, + ssl_version=ssl.PROTOCOL_TLSv1_2, + ) + return connstream, fromaddr + + +if hasattr(socket, "UnixStreamServer"): + + class _RedisUDSServer(socketserver.UnixStreamServer): + def __init__(self, *args, **kw) -> None: + 
self._ready_event = threading.Event() + self._stop_requested = False + super().__init__(*args, **kw) + + def service_actions(self): + self._ready_event.set() + + def wait_online(self): + self._ready_event.wait() + + def stop(self): + self._stop_requested = True + self.shutdown() + + def is_serving(self): + return not self._stop_requested + +else: + _RedisUDSServer = None + + +class _RedisRequestHandler(socketserver.StreamRequestHandler): + def setup(self): + _logger.info("%s connected", self.client_address) + + def finish(self): + _logger.info("%s disconnected", self.client_address) + + def handle(self): + buffer = b"" + command = None + command_ptr = None + fragment_length = None + while self.server.is_serving() or buffer: + try: + buffer += self.request.recv(1024) + except socket.timeout: + continue + if not buffer: + continue + parts = re.split(_CMD_SEP, buffer) + buffer = parts[-1] + for fragment in parts[:-1]: + fragment = fragment.decode() + _logger.info("Command fragment: %s", fragment) + + if fragment.startswith("*") and command is None: + command = [None for _ in range(int(fragment[1:]))] + command_ptr = 0 + fragment_length = None + continue + + if fragment.startswith("$") and command[command_ptr] is None: + fragment_length = int(fragment[1:]) + continue + + assert len(fragment) == fragment_length + command[command_ptr] = fragment + command_ptr += 1 + + if command_ptr < len(command): + continue + + command = " ".join(command) + _logger.info("Command %s", command) + resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP) + _logger.info("Response %s", resp) + self.request.sendall(resp) + command = None + _logger.info("Exit handler") diff --git a/tests/test_connection.py b/tests/test_connection.py index e0b53cdf37..760b23c9c1 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -4,10 +4,10 @@ from unittest.mock import patch import pytest - import redis +from redis._parsers import _HiredisParser, _RESP2Parser, _RESP3Parser from redis.backoff import NoBackoff -from redis.connection import Connection, HiredisParser, PythonParser +from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError from redis.retry import Retry from redis.utils import HIREDIS_AVAILABLE @@ -29,22 +29,22 @@ def test_invalid_response(r): @skip_if_server_version_lt("4.0.0") @pytest.mark.redismod -def test_loading_external_modules(modclient): +def test_loading_external_modules(r): def inner(): pass - modclient.load_external_module("myfuncname", inner) - assert getattr(modclient, "myfuncname") == inner - assert isinstance(getattr(modclient, "myfuncname"), types.FunctionType) + r.load_external_module("myfuncname", inner) + assert getattr(r, "myfuncname") == inner + assert isinstance(getattr(r, "myfuncname"), types.FunctionType) # and call it from redis.commands import RedisModuleCommands j = RedisModuleCommands.json - modclient.load_external_module("sometestfuncname", j) + r.load_external_module("sometestfuncname", j) # d = {'hello': 'world!'} - # mod = j(modclient) + # mod = j(r) # mod.set("fookey", ".", d) # assert mod.get('fookey') == d @@ -128,7 +128,9 @@ def test_connect_timeout_error_without_retry(self): @pytest.mark.onlynoncluster @pytest.mark.parametrize( - "parser_class", [PythonParser, HiredisParser], ids=["PythonParser", "HiredisParser"] + "parser_class", + [_RESP2Parser, _RESP3Parser, _HiredisParser], + ids=["RESP2Parser", "RESP3Parser", "HiredisParser"], ) def 
test_connection_parse_response_resume(r: redis.Redis, parser_class):
    """
@@ -136,7 +138,7 @@ def test_connection_parse_response_resume(r: redis.Redis, parser_class):
     be that PythonParser or HiredisParser, can be interrupted at IO time
     and then resume parsing.
     """
-    if parser_class is HiredisParser and not HIREDIS_AVAILABLE:
+    if parser_class is _HiredisParser and not HIREDIS_AVAILABLE:
         pytest.skip("Hiredis not available")
     args = dict(r.connection_pool.connection_kwargs)
     args["parser_class"] = parser_class
@@ -148,13 +150,13 @@ def test_connection_parse_response_resume(r: redis.Redis, parser_class):
     )
     mock_socket = MockSocket(message, interrupt_every=2)
-    if isinstance(conn._parser, PythonParser):
+    if isinstance(conn._parser, (_RESP2Parser, _RESP3Parser)):
         conn._parser._buffer._sock = mock_socket
     else:
         conn._parser._sock = mock_socket
     for i in range(100):
         try:
-            response = conn.read_response()
+            response = conn.read_response(disconnect_on_error=False)
             break
         except MockSocket.TestError:
             pass
@@ -163,3 +165,47 @@ def test_connection_parse_response_resume(r: redis.Redis, parser_class):
         pytest.fail("didn't receive a response")
     assert response
     assert i > 0
+
+
+@pytest.mark.onlynoncluster
+@pytest.mark.parametrize(
+    "Class",
+    [
+        Connection,
+        SSLConnection,
+        UnixDomainSocketConnection,
+    ],
+)
+def test_pack_command(Class):
+    """
+    This test verifies that pack_command works
+    on all supported connections. #2581
+    """
+    cmd = (
+        "HSET",
+        "foo",
+        "key",
+        "value1",
+        b"key_b",
+        b"bytes str",
+        b"key_i",
+        67,
+        "key_f",
+        3.14159265359,
+    )
+    expected = (
+        b"*10\r\n$4\r\nHSET\r\n$3\r\nfoo\r\n$3\r\nkey\r\n$6\r\nvalue1\r\n"
+        b"$5\r\nkey_b\r\n$9\r\nbytes str\r\n$5\r\nkey_i\r\n$2\r\n67\r\n$5"
+        b"\r\nkey_f\r\n$13\r\n3.14159265359\r\n"
+    )
+
+    actual = Class().pack_command(*cmd)[0]
+    assert actual == expected, f"actual = {actual}, expected = {expected}"
+
+
+@pytest.mark.onlynoncluster
+def test_create_single_connection_client_from_url():
+    client = redis.Redis.from_url(
+        "redis://localhost:6379/0?", single_connection_client=True
+    )
+    assert client.connection is not None
diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py
index e8a42692a1..ab0fc6be98 100644
--- a/tests/test_connection_pool.py
+++ b/tests/test_connection_pool.py
@@ -5,9 +5,9 @@
 from unittest import mock

 import pytest
-
 import redis
-from redis.connection import ssl_available, to_bool
+from redis.connection import to_bool
+from redis.utils import SSL_AVAILABLE

 from .conftest import _get_client, skip_if_redis_enterprise, skip_if_server_version_lt
 from .test_pubsub import wait_for_message
@@ -425,7 +425,7 @@ class MyConnection(redis.UnixDomainSocketConnection):
         assert pool.connection_class == MyConnection

-@pytest.mark.skipif(not ssl_available, reason="SSL not installed")
+@pytest.mark.skipif(not SSL_AVAILABLE, reason="SSL not installed")
 class TestSSLConnectionURLParsing:
     def test_host(self):
         pool = redis.ConnectionPool.from_url("rediss://my.host")
@@ -528,10 +528,17 @@ def test_busy_loading_from_pipeline(self, r):
     @skip_if_server_version_lt("2.8.8")
     @skip_if_redis_enterprise()
     def test_read_only_error(self, r):
-        "READONLY errors get turned in ReadOnlyError exceptions"
+        "READONLY errors get turned into ReadOnlyError exceptions"
         with pytest.raises(redis.ReadOnlyError):
             r.execute_command("DEBUG", "ERROR", "READONLY blah blah")

+    def test_oom_error(self, r):
+        "OOM errors get turned into OutOfMemoryError exceptions"
+        with pytest.raises(redis.OutOfMemoryError):
+            # note:
don't use the DEBUG OOM command since it's not the same + # as the db being full + r.execute_command("DEBUG", "ERROR", "OOM blah blah") + def test_connect_from_url_tcp(self): connection = redis.Redis.from_url("redis://localhost") pool = connection.connection_pool diff --git a/tests/test_credentials.py b/tests/test_credentials.py index 9aeb1ef1d5..aade04e082 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -4,7 +4,6 @@ from typing import Optional, Tuple, Union import pytest - import redis from redis import AuthenticationError, DataError, ResponseError from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider @@ -198,6 +197,12 @@ def test_change_username_password_on_existing_connection(self, r, request): password = "origin_password" new_username = "new_username" new_password = "new_password" + + def teardown(): + r.acl_deluser(new_username) + + request.addfinalizer(teardown) + init_acl_user(r, request, username, password) r2 = _get_client( redis.Redis, request, flushdb=False, username=username, password=password diff --git a/tests/test_encoding.py b/tests/test_encoding.py index 2867640742..331cd5108c 100644 --- a/tests/test_encoding.py +++ b/tests/test_encoding.py @@ -1,7 +1,7 @@ import pytest - import redis from redis.connection import Connection +from redis.utils import HIREDIS_PACK_AVAILABLE from .conftest import _get_client @@ -75,6 +75,10 @@ def test_replace(self, request): assert r.get("a") == "foo\ufffd" +@pytest.mark.skipif( + HIREDIS_PACK_AVAILABLE, + reason="Packing via hiredis does not preserve memoryviews", +) class TestMemoryviewsAreNotPacked: def test_memoryviews_are_not_packed(self): c = Connection() diff --git a/tests/test_function.py b/tests/test_function.py index 7ce66a38e6..9d6712ecf7 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -1,8 +1,7 @@ import pytest - from redis.exceptions import ResponseError -from .conftest import skip_if_server_version_lt +from .conftest import assert_resp_response, skip_if_server_version_lt engine = "lua" lib = "mylib" @@ -64,12 +63,22 @@ def test_function_list(self, r): [[b"name", b"myfunc", b"description", None, b"flags", [b"no-writes"]]], ] ] - assert r.function_list() == res - assert r.function_list(library="*lib") == res - assert ( - r.function_list(withcode=True)[0][7] - == f"#!{engine} name={lib} \n {function}".encode() + resp3_res = [ + { + b"library_name": b"mylib", + b"engine": b"LUA", + b"functions": [ + {b"name": b"myfunc", b"description": None, b"flags": {b"no-writes"}} + ], + } + ] + assert_resp_response(r, r.function_list(), res, resp3_res) + assert_resp_response(r, r.function_list(library="*lib"), res, resp3_res) + res[0].extend( + [b"library_code", f"#!{engine} name={lib} \n {function}".encode()] ) + resp3_res[0][b"library_code"] = f"#!{engine} name={lib} \n {function}".encode() + assert_resp_response(r, r.function_list(withcode=True), res, resp3_res) @pytest.mark.onlycluster def test_function_list_on_cluster(self, r): @@ -84,17 +93,28 @@ def test_function_list_on_cluster(self, r): [[b"name", b"myfunc", b"description", None, b"flags", [b"no-writes"]]], ] ] + resp3_function_list = [ + { + b"library_name": b"mylib", + b"engine": b"LUA", + b"functions": [ + {b"name": b"myfunc", b"description": None, b"flags": {b"no-writes"}} + ], + } + ] primaries = r.get_primaries() res = {} + resp3_res = {} for node in primaries: res[node.name] = function_list - assert r.function_list() == res - assert r.function_list(library="*lib") == res + resp3_res[node.name] = 
resp3_function_list + assert_resp_response(r, r.function_list(), res, resp3_res) + assert_resp_response(r, r.function_list(library="*lib"), res, resp3_res) node = primaries[0].name - assert ( - r.function_list(withcode=True)[node][0][7] - == f"#!{engine} name={lib} \n {function}".encode() - ) + code = f"#!{engine} name={lib} \n {function}".encode() + res[node][0].extend([b"library_code", code]) + resp3_res[node][0][b"library_code"] = code + assert_resp_response(r, r.function_list(withcode=True), res, resp3_res) def test_fcall(self, r): r.function_load(f"#!{engine} name={lib} \n {set_function}") diff --git a/tests/test_graph.py b/tests/test_graph.py index 4721b2f4e2..6fa9977d98 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,7 +1,7 @@ from unittest.mock import patch import pytest - +from redis import Redis from redis.commands.graph import Edge, Node, Path from redis.commands.graph.execution_plan import Operation from redis.commands.graph.query_result import ( @@ -20,23 +20,22 @@ QueryResult, ) from redis.exceptions import ResponseError -from tests.conftest import skip_if_redis_enterprise +from tests.conftest import _get_client, skip_if_redis_enterprise @pytest.fixture -def client(modclient): - modclient.flushdb() - return modclient +def client(request): + r = _get_client(Redis, request, decode_responses=True) + r.flushdb() + return r -@pytest.mark.redismod def test_bulk(client): with pytest.raises(NotImplementedError): client.graph().bulk() client.graph().bulk(foo="bar!") -@pytest.mark.redismod def test_graph_creation(client): graph = client.graph() @@ -81,7 +80,6 @@ def test_graph_creation(client): graph.delete() -@pytest.mark.redismod def test_array_functions(client): query = """CREATE (p:person{name:'a',age:32, array:[0,1,2]})""" client.graph().query(query) @@ -102,7 +100,6 @@ def test_array_functions(client): assert [a] == result.result_set[0][0] -@pytest.mark.redismod def test_path(client): node0 = Node(node_id=0, label="L1") node1 = Node(node_id=1, label="L1") @@ -122,7 +119,6 @@ def test_path(client): assert expected_results == result.result_set -@pytest.mark.redismod def test_param(client): params = [1, 2.3, "str", True, False, None, [0, 1, 2], r"\" RETURN 1337 //"] query = "RETURN $param" @@ -132,7 +128,6 @@ def test_param(client): assert expected_results == result.result_set -@pytest.mark.redismod def test_map(client): query = "RETURN {a:1, b:'str', c:NULL, d:[1,2,3], e:True, f:{x:1, y:2}}" @@ -149,7 +144,6 @@ def test_map(client): assert actual == expected -@pytest.mark.redismod def test_point(client): query = "RETURN point({latitude: 32.070794860, longitude: 34.820751118})" expected_lat = 32.070794860 @@ -166,7 +160,6 @@ def test_point(client): assert abs(actual["longitude"] - expected_lon) < 0.001 -@pytest.mark.redismod def test_index_response(client): result_set = client.graph().query("CREATE INDEX ON :person(age)") assert 1 == result_set.indices_created @@ -181,7 +174,6 @@ def test_index_response(client): client.graph().query("DROP INDEX ON :person(age)") -@pytest.mark.redismod def test_stringify_query_result(client): graph = client.graph() @@ -235,7 +227,6 @@ def test_stringify_query_result(client): graph.delete() -@pytest.mark.redismod def test_optional_match(client): # Build a graph of form (a)-[R]->(b) node0 = Node(node_id=0, label="L1", properties={"value": "a"}) @@ -260,7 +251,6 @@ def test_optional_match(client): graph.delete() -@pytest.mark.redismod def test_cached_execution(client): client.graph().query("CREATE ()") @@ -278,7 +268,6 @@ def 
test_cached_execution(client): assert cached_result.cached_execution -@pytest.mark.redismod def test_slowlog(client): create_query = """CREATE (:Rider {name:'Valentino Rossi'})-[:rides]->(:Team {name:'Yamaha'}), @@ -291,7 +280,7 @@ def test_slowlog(client): assert results[0][2] == create_query -@pytest.mark.redismod +@pytest.mark.xfail(strict=False) def test_query_timeout(client): # Build a sample graph with 1000 nodes. client.graph().query("UNWIND range(0,1000) as val CREATE ({v: val})") @@ -305,7 +294,6 @@ def test_query_timeout(client): assert False is False -@pytest.mark.redismod def test_read_only_query(client): with pytest.raises(Exception): # Issue a write query, specifying read-only true, @@ -314,7 +302,6 @@ def test_read_only_query(client): assert False is False -@pytest.mark.redismod def test_profile(client): q = """UNWIND range(1, 3) AS x CREATE (p:Person {v:x})""" profile = client.graph().profile(q).result_set @@ -329,7 +316,6 @@ def test_profile(client): assert "Node By Label Scan | (p:Person) | Records produced: 3" in profile -@pytest.mark.redismod @skip_if_redis_enterprise() def test_config(client): config_name = "RESULTSET_SIZE" @@ -361,7 +347,6 @@ def test_config(client): client.graph().config("RESULTSET_SIZE", -100, set=True) -@pytest.mark.redismod @pytest.mark.onlynoncluster def test_list_keys(client): result = client.graph().list_keys() @@ -385,7 +370,6 @@ def test_list_keys(client): assert result == [] -@pytest.mark.redismod def test_multi_label(client): redis_graph = client.graph("g") @@ -411,7 +395,6 @@ def test_multi_label(client): assert True -@pytest.mark.redismod def test_cache_sync(client): pass return @@ -484,7 +467,6 @@ def test_cache_sync(client): assert A._relationship_types[1] == "R" -@pytest.mark.redismod def test_execution_plan(client): redis_graph = client.graph("execution_plan") create_query = """CREATE @@ -503,7 +485,6 @@ def test_execution_plan(client): redis_graph.delete() -@pytest.mark.redismod def test_explain(client): redis_graph = client.graph("execution_plan") # graph creation / population @@ -592,7 +573,6 @@ def test_explain(client): redis_graph.delete() -@pytest.mark.redismod def test_resultset_statistics(client): with patch.object(target=QueryResult, attribute="_get_stat") as mock_get_stats: result = client.graph().query("RETURN 1") diff --git a/tests/test_graph_utils/test_edge.py b/tests/test_graph_utils/test_edge.py index b5b7362389..581ebfab5d 100644 --- a/tests/test_graph_utils/test_edge.py +++ b/tests/test_graph_utils/test_edge.py @@ -1,5 +1,4 @@ import pytest - from redis.commands.graph import edge, node diff --git a/tests/test_graph_utils/test_node.py b/tests/test_graph_utils/test_node.py index cd4e936719..c3b34ac6ff 100644 --- a/tests/test_graph_utils/test_node.py +++ b/tests/test_graph_utils/test_node.py @@ -1,5 +1,4 @@ import pytest - from redis.commands.graph import node diff --git a/tests/test_graph_utils/test_path.py b/tests/test_graph_utils/test_path.py index d581269307..1bd38efab4 100644 --- a/tests/test_graph_utils/test_path.py +++ b/tests/test_graph_utils/test_path.py @@ -1,5 +1,4 @@ import pytest - from redis.commands.graph import edge, node, path diff --git a/tests/test_json.py b/tests/test_json.py index a776e9e736..be347f6677 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -1,20 +1,19 @@ import pytest - import redis -from redis import exceptions +from redis import Redis, exceptions from redis.commands.json.decoders import decode_list, unstring from redis.commands.json.path import Path -from .conftest import 
skip_ifmodversion_lt +from .conftest import _get_client, assert_resp_response, skip_ifmodversion_lt @pytest.fixture -def client(modclient): - modclient.flushdb() - return modclient +def client(request): + r = _get_client(Redis, request, decode_responses=True) + r.flushdb() + return r -@pytest.mark.redismod def test_json_setbinarykey(client): d = {"hello": "world", b"some": "value"} with pytest.raises(TypeError): @@ -22,40 +21,70 @@ def test_json_setbinarykey(client): assert client.json().set("somekey", Path.root_path(), d, decode_keys=True) -@pytest.mark.redismod def test_json_setgetdeleteforget(client): assert client.json().set("foo", Path.root_path(), "bar") - assert client.json().get("foo") == "bar" + assert_resp_response(client, client.json().get("foo"), "bar", [["bar"]]) assert client.json().get("baz") is None assert client.json().delete("foo") == 1 assert client.json().forget("foo") == 0 # second delete assert client.exists("foo") == 0 -@pytest.mark.redismod def test_jsonget(client): client.json().set("foo", Path.root_path(), "bar") - assert client.json().get("foo") == "bar" + assert_resp_response(client, client.json().get("foo"), "bar", [["bar"]]) -@pytest.mark.redismod def test_json_get_jset(client): assert client.json().set("foo", Path.root_path(), "bar") - assert "bar" == client.json().get("foo") + assert_resp_response(client, client.json().get("foo"), "bar", [["bar"]]) assert client.json().get("baz") is None assert 1 == client.json().delete("foo") assert client.exists("foo") == 0 -@pytest.mark.redismod +@skip_ifmodversion_lt("2.06.00", "ReJSON") # todo: update after the release +def test_json_merge(client): + # Test with root path $ + assert client.json().set( + "person_data", + "$", + {"person1": {"personal_data": {"name": "John"}}}, + ) + assert client.json().merge( + "person_data", "$", {"person1": {"personal_data": {"hobbies": "reading"}}} + ) + assert client.json().get("person_data") == { + "person1": {"personal_data": {"name": "John", "hobbies": "reading"}} + } + + # Test with root path path $.person1.personal_data + assert client.json().merge( + "person_data", "$.person1.personal_data", {"country": "Israel"} + ) + assert client.json().get("person_data") == { + "person1": { + "personal_data": {"name": "John", "hobbies": "reading", "country": "Israel"} + } + } + + # Test with null value to delete a value + assert client.json().merge("person_data", "$.person1.personal_data", {"name": None}) + assert client.json().get("person_data") == { + "person1": {"personal_data": {"country": "Israel", "hobbies": "reading"}} + } + + def test_nonascii_setgetdelete(client): assert client.json().set("notascii", Path.root_path(), "hyvää-élève") - assert "hyvää-élève" == client.json().get("notascii", no_escape=True) + res = "hyvää-élève" + assert_resp_response( + client, client.json().get("notascii", no_escape=True), res, [[res]] + ) assert 1 == client.json().delete("notascii") assert client.exists("notascii") == 0 -@pytest.mark.redismod def test_jsonsetexistentialmodifiersshouldsucceed(client): obj = {"foo": "bar"} assert client.json().set("obj", Path.root_path(), obj) @@ -73,7 +102,6 @@ def test_jsonsetexistentialmodifiersshouldsucceed(client): client.json().set("obj", Path("foo"), "baz", nx=True, xx=True) -@pytest.mark.redismod def test_mgetshouldsucceed(client): client.json().set("1", Path.root_path(), 1) client.json().set("2", Path.root_path(), 2) @@ -82,40 +110,58 @@ def test_mgetshouldsucceed(client): assert client.json().mget([1, 2], Path.root_path()) == [1, 2] -@pytest.mark.redismod 
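# ---------------------------------------------------------------------------
# A recurring pattern in the JSON tests below: under RESP2, JSON.GET returns
# the decoded value itself ("bar"), while under RESP3 the reply is wrapped in
# an array of per-path matches ([["bar"]]). A hedged sketch of a normalizing
# helper (hypothetical, not part of the library; assumes a standalone client
# whose protocol is recorded in the pool's connection kwargs):
def json_get_unwrapped(client, key, path="$"):
    res = client.json().get(key, path)
    protocol = client.connection_pool.connection_kwargs.get("protocol")
    if protocol in (3, "3") and res:
        # RESP3 shape is [[match, ...]]; unwrap the first match.
        res = res[0][0] if res[0] else None
    return res
# ---------------------------------------------------------------------------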
+@pytest.mark.onlynoncluster +@skip_ifmodversion_lt("2.06.00", "ReJSON") +def test_mset(client): + client.json().mset([("1", Path.root_path(), 1), ("2", Path.root_path(), 2)]) + + assert client.json().mget(["1"], Path.root_path()) == [1] + assert client.json().mget(["1", "2"], Path.root_path()) == [1, 2] + + @skip_ifmodversion_lt("99.99.99", "ReJSON") # todo: update after the release def test_clear(client): client.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) assert 1 == client.json().clear("arr", Path.root_path()) - assert [] == client.json().get("arr") + assert_resp_response(client, client.json().get("arr"), [], [[[]]]) -@pytest.mark.redismod def test_type(client): client.json().set("1", Path.root_path(), 1) - assert "integer" == client.json().type("1", Path.root_path()) - assert "integer" == client.json().type("1") + assert_resp_response( + client, client.json().type("1", Path.root_path()), "integer", ["integer"] + ) + assert_resp_response(client, client.json().type("1"), "integer", ["integer"]) -@pytest.mark.redismod def test_numincrby(client): client.json().set("num", Path.root_path(), 1) - assert 2 == client.json().numincrby("num", Path.root_path(), 1) - assert 2.5 == client.json().numincrby("num", Path.root_path(), 0.5) - assert 1.25 == client.json().numincrby("num", Path.root_path(), -1.25) + assert_resp_response( + client, client.json().numincrby("num", Path.root_path(), 1), 2, [2] + ) + assert_resp_response( + client, client.json().numincrby("num", Path.root_path(), 0.5), 2.5, [2.5] + ) + assert_resp_response( + client, client.json().numincrby("num", Path.root_path(), -1.25), 1.25, [1.25] + ) -@pytest.mark.redismod def test_nummultby(client): client.json().set("num", Path.root_path(), 1) with pytest.deprecated_call(): - assert 2 == client.json().nummultby("num", Path.root_path(), 2) - assert 5 == client.json().nummultby("num", Path.root_path(), 2.5) - assert 2.5 == client.json().nummultby("num", Path.root_path(), 0.5) + assert_resp_response( + client, client.json().nummultby("num", Path.root_path(), 2), 2, [2] + ) + assert_resp_response( + client, client.json().nummultby("num", Path.root_path(), 2.5), 5, [5] + ) + assert_resp_response( + client, client.json().nummultby("num", Path.root_path(), 0.5), 2.5, [2.5] + ) -@pytest.mark.redismod @skip_ifmodversion_lt("99.99.99", "ReJSON") # todo: update after the release def test_toggle(client): client.json().set("bool", Path.root_path(), False) @@ -127,15 +173,15 @@ def test_toggle(client): client.json().toggle("num", Path.root_path()) -@pytest.mark.redismod def test_strappend(client): client.json().set("jsonkey", Path.root_path(), "foo") assert 6 == client.json().strappend("jsonkey", "bar") - assert "foobar" == client.json().get("jsonkey", Path.root_path()) + assert_resp_response( + client, client.json().get("jsonkey", Path.root_path()), "foobar", [["foobar"]] + ) -# @pytest.mark.redismod -# def test_debug(client): +# # def test_debug(client): # client.json().set("str", Path.root_path(), "foo") # assert 24 == client.json().debug("MEMORY", "str", Path.root_path()) # assert 24 == client.json().debug("MEMORY", "str") @@ -144,7 +190,6 @@ def test_strappend(client): # assert isinstance(client.json().debug("HELP"), list) -@pytest.mark.redismod def test_strlen(client): client.json().set("str", Path.root_path(), "foo") assert 3 == client.json().strlen("str", Path.root_path()) @@ -153,7 +198,6 @@ def test_strlen(client): assert 6 == client.json().strlen("str") -@pytest.mark.redismod def test_arrappend(client): client.json().set("arr", 
Path.root_path(), [1]) assert 2 == client.json().arrappend("arr", Path.root_path(), 2) @@ -161,26 +205,30 @@ def test_arrappend(client): assert 7 == client.json().arrappend("arr", Path.root_path(), *[5, 6, 7]) -@pytest.mark.redismod def test_arrindex(client): client.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) assert 1 == client.json().arrindex("arr", Path.root_path(), 1) assert -1 == client.json().arrindex("arr", Path.root_path(), 1, 2) + assert 4 == client.json().arrindex("arr", Path.root_path(), 4) + assert 4 == client.json().arrindex("arr", Path.root_path(), 4, start=0) + assert 4 == client.json().arrindex("arr", Path.root_path(), 4, start=0, stop=5000) + assert -1 == client.json().arrindex("arr", Path.root_path(), 4, start=0, stop=-1) + assert -1 == client.json().arrindex("arr", Path.root_path(), 4, start=1, stop=3) -@pytest.mark.redismod def test_arrinsert(client): client.json().set("arr", Path.root_path(), [0, 4]) assert 5 - -client.json().arrinsert("arr", Path.root_path(), 1, *[1, 2, 3]) - assert [0, 1, 2, 3, 4] == client.json().get("arr") + res = [0, 1, 2, 3, 4] + assert_resp_response(client, client.json().get("arr"), res, [[res]]) # test prepends client.json().set("val2", Path.root_path(), [5, 6, 7, 8, 9]) client.json().arrinsert("val2", Path.root_path(), 0, ["some", "thing"]) - assert client.json().get("val2") == [["some", "thing"], 5, 6, 7, 8, 9] + res = [["some", "thing"], 5, 6, 7, 8, 9] + assert_resp_response(client, client.json().get("val2"), res, [[res]]) -@pytest.mark.redismod def test_arrlen(client): client.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) assert 5 == client.json().arrlen("arr", Path.root_path()) @@ -188,14 +236,13 @@ def test_arrlen(client): assert client.json().arrlen("fakekey") is None -@pytest.mark.redismod def test_arrpop(client): client.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) assert 4 == client.json().arrpop("arr", Path.root_path(), 4) assert 3 == client.json().arrpop("arr", Path.root_path(), -1) assert 2 == client.json().arrpop("arr", Path.root_path()) assert 0 == client.json().arrpop("arr", Path.root_path(), 0) - assert [1] == client.json().get("arr") + assert_resp_response(client, client.json().get("arr"), [1], [[[1]]]) # test out of bounds client.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) @@ -206,11 +253,10 @@ def test_arrpop(client): assert client.json().arrpop("arr") is None -@pytest.mark.redismod def test_arrtrim(client): client.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) assert 3 == client.json().arrtrim("arr", Path.root_path(), 1, 3) - assert [1, 2, 3] == client.json().get("arr") + assert_resp_response(client, client.json().get("arr"), [1, 2, 3], [[[1, 2, 3]]]) # <0 test, should be 0 equivalent client.json().set("arr", Path.root_path(), [0, 1, 2, 3, 4]) @@ -229,7 +275,6 @@ def test_arrtrim(client): assert 0 == client.json().arrtrim("arr", Path.root_path(), 9, 11) -@pytest.mark.redismod def test_resp(client): obj = {"foo": "bar", "baz": 1, "qaz": True} client.json().set("obj", Path.root_path(), obj) @@ -239,7 +284,6 @@ def test_resp(client): assert isinstance(client.json().resp("obj"), list) -@pytest.mark.redismod def test_objkeys(client): obj = {"foo": "bar", "baz": "qaz"} client.json().set("obj", Path.root_path(), obj) @@ -256,7 +300,6 @@ def test_objkeys(client): assert client.json().objkeys("fakekey") is None -@pytest.mark.redismod def test_objlen(client): obj = {"foo": "bar", "baz": "qaz"} client.json().set("obj", Path.root_path(), obj) @@ -266,13 +309,12 @@ def test_objlen(client): assert 
len(obj) == client.json().objlen("obj") -@pytest.mark.redismod def test_json_commands_in_pipeline(client): p = client.json().pipeline() p.set("foo", Path.root_path(), "bar") p.get("foo") p.delete("foo") - assert [True, "bar", 1] == p.execute() + assert_resp_response(client, p.execute(), [True, "bar", 1], [True, [["bar"]], 1]) assert client.keys() == [] assert client.get("foo") is None @@ -285,24 +327,23 @@ def test_json_commands_in_pipeline(client): p.jsonget("foo") p.exists("notarealkey") p.delete("foo") - assert [True, d, 0, 1] == p.execute() + assert_resp_response(client, p.execute(), [True, d, 0, 1], [True, [[d]], 0, 1]) assert client.keys() == [] assert client.get("foo") is None -@pytest.mark.redismod def test_json_delete_with_dollar(client): doc1 = {"a": 1, "nested": {"a": 2, "b": 3}} assert client.json().set("doc1", "$", doc1) assert client.json().delete("doc1", "$..a") == 2 - r = client.json().get("doc1", "$") - assert r == [{"nested": {"b": 3}}] + res = [{"nested": {"b": 3}}] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) doc2 = {"a": {"a": 2, "b": 3}, "b": ["a", "b"], "nested": {"b": [True, "a", "b"]}} assert client.json().set("doc2", "$", doc2) assert client.json().delete("doc2", "$..a") == 1 - res = client.json().get("doc2", "$") - assert res == [{"nested": {"b": [True, "a", "b"]}, "b": ["a", "b"]}] + res = [{"nested": {"b": [True, "a", "b"]}, "b": ["a", "b"]}] + assert_resp_response(client, client.json().get("doc2", "$"), res, [res]) doc3 = [ { @@ -333,8 +374,7 @@ def test_json_delete_with_dollar(client): } ] ] - res = client.json().get("doc3", "$") - assert res == doc3val + assert_resp_response(client, client.json().get("doc3", "$"), doc3val, [doc3val]) # Test default path assert client.json().delete("doc3") == 1 @@ -343,19 +383,18 @@ def test_json_delete_with_dollar(client): client.json().delete("not_a_document", "..a") -@pytest.mark.redismod def test_json_forget_with_dollar(client): doc1 = {"a": 1, "nested": {"a": 2, "b": 3}} assert client.json().set("doc1", "$", doc1) assert client.json().forget("doc1", "$..a") == 2 - r = client.json().get("doc1", "$") - assert r == [{"nested": {"b": 3}}] + res = [{"nested": {"b": 3}}] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) doc2 = {"a": {"a": 2, "b": 3}, "b": ["a", "b"], "nested": {"b": [True, "a", "b"]}} assert client.json().set("doc2", "$", doc2) assert client.json().forget("doc2", "$..a") == 1 - res = client.json().get("doc2", "$") - assert res == [{"nested": {"b": [True, "a", "b"]}, "b": ["a", "b"]}] + res = [{"nested": {"b": [True, "a", "b"]}, "b": ["a", "b"]}] + assert_resp_response(client, client.json().get("doc2", "$"), res, [res]) doc3 = [ { @@ -386,8 +425,7 @@ def test_json_forget_with_dollar(client): } ] ] - res = client.json().get("doc3", "$") - assert res == doc3val + assert_resp_response(client, client.json().get("doc3", "$"), doc3val, [doc3val]) # Test default path assert client.json().forget("doc3") == 1 @@ -396,7 +434,6 @@ def test_json_forget_with_dollar(client): client.json().forget("not_a_document", "..a") -@pytest.mark.redismod def test_json_mget_dollar(client): # Test mget with multi paths client.json().set( @@ -410,8 +447,10 @@ def test_json_mget_dollar(client): {"a": 4, "b": 5, "nested": {"a": 6}, "c": None, "nested2": {"a": [None]}}, ) # Compare also to single JSON.GET - assert client.json().get("doc1", "$..a") == [1, 3, None] - assert client.json().get("doc2", "$..a") == [4, 6, [None]] + res = [1, 3, None] + assert_resp_response(client, 
client.json().get("doc1", "$..a"), res, [res]) + res = [4, 6, [None]] + assert_resp_response(client, client.json().get("doc2", "$..a"), res, [res]) # Test mget with single path client.json().mget("doc1", "$..a") == [1, 3, None] @@ -424,7 +463,6 @@ def test_json_mget_dollar(client): assert res == [None, None] -@pytest.mark.redismod def test_numby_commands_dollar(client): # Test NUMINCRBY @@ -469,7 +507,6 @@ def test_numby_commands_dollar(client): client.json().nummultby("doc1", ".b[0].a", 3) == 6 -@pytest.mark.redismod def test_strappend_dollar(client): client.json().set( @@ -478,15 +515,14 @@ def test_strappend_dollar(client): # Test multi client.json().strappend("doc1", "bar", "$..a") == [6, 8, None] - client.json().get("doc1", "$") == [ - {"a": "foobar", "nested1": {"a": "hellobar"}, "nested2": {"a": 31}} - ] + # res = [{"a": "foobar", "nested1": {"a": "hellobar"}, "nested2": {"a": 31}}] + # assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) + # Test single client.json().strappend("doc1", "baz", "$.nested1.a") == [11] - client.json().get("doc1", "$") == [ - {"a": "foobar", "nested1": {"a": "hellobarbaz"}, "nested2": {"a": 31}} - ] + # res = [{"a": "foobar", "nested1": {"a": "hellobarbaz"}, "nested2": {"a": 31}}] + # assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): @@ -494,16 +530,14 @@ def test_strappend_dollar(client): # Test multi client.json().strappend("doc1", "bar", ".*.a") == 8 - client.json().get("doc1", "$") == [ - {"a": "foo", "nested1": {"a": "hellobar"}, "nested2": {"a": 31}} - ] + # res = [{"a": "foo", "nested1": {"a": "hellobar"}, "nested2": {"a": 31}}] + # assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # Test missing path with pytest.raises(exceptions.ResponseError): client.json().strappend("doc1", "piu") -@pytest.mark.redismod def test_strlen_dollar(client): # Test multi @@ -525,7 +559,6 @@ def test_strlen_dollar(client): client.json().strlen("non_existing_doc", "$..a") -@pytest.mark.redismod def test_arrappend_dollar(client): client.json().set( "doc1", @@ -538,23 +571,25 @@ def test_arrappend_dollar(client): ) # Test multi client.json().arrappend("doc1", "$..a", "bar", "racuda") == [3, 5, None] - assert client.json().get("doc1", "$") == [ + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", None, "world", "bar", "racuda"]}, "nested2": {"a": 31}, } ] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # Test single assert client.json().arrappend("doc1", "$.nested1.a", "baz") == [6] - assert client.json().get("doc1", "$") == [ + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", None, "world", "bar", "racuda", "baz"]}, "nested2": {"a": 31}, } ] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): @@ -573,29 +608,31 @@ def test_arrappend_dollar(client): # Test multi (all paths are updated, but return result of last path) assert client.json().arrappend("doc1", "..a", "bar", "racuda") == 5 - assert client.json().get("doc1", "$") == [ + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", None, "world", "bar", "racuda"]}, "nested2": {"a": 31}, } ] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) + # Test single assert client.json().arrappend("doc1", ".nested1.a", "baz") == 6 - assert client.json().get("doc1", "$") == [ + res = [ { "a": ["foo", "bar", "racuda"], 
"nested1": {"a": ["hello", None, "world", "bar", "racuda", "baz"]}, "nested2": {"a": 31}, } ] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): client.json().arrappend("non_existing_doc", "$..a") -@pytest.mark.redismod def test_arrinsert_dollar(client): client.json().set( "doc1", @@ -609,29 +646,31 @@ def test_arrinsert_dollar(client): # Test multi assert client.json().arrinsert("doc1", "$..a", "1", "bar", "racuda") == [3, 5, None] - assert client.json().get("doc1", "$") == [ + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", "bar", "racuda", None, "world"]}, "nested2": {"a": 31}, } ] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) + # Test single assert client.json().arrinsert("doc1", "$.nested1.a", -2, "baz") == [6] - assert client.json().get("doc1", "$") == [ + res = [ { "a": ["foo", "bar", "racuda"], "nested1": {"a": ["hello", "bar", "racuda", "baz", None, "world"]}, "nested2": {"a": 31}, } ] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): client.json().arrappend("non_existing_doc", "$..a") -@pytest.mark.redismod def test_arrlen_dollar(client): client.json().set( @@ -681,7 +720,6 @@ def test_arrlen_dollar(client): assert client.json().arrlen("non_existing_doc", "..a") is None -@pytest.mark.redismod def test_arrpop_dollar(client): client.json().set( "doc1", @@ -696,9 +734,8 @@ def test_arrpop_dollar(client): # # # Test multi assert client.json().arrpop("doc1", "$..a", 1) == ['"foo"', None, None] - assert client.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": ["hello", "world"]}, "nested2": {"a": 31}} - ] + res = [{"a": [], "nested1": {"a": ["hello", "world"]}, "nested2": {"a": 31}}] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): @@ -716,16 +753,14 @@ def test_arrpop_dollar(client): ) # Test multi (all paths are updated, but return result of last path) client.json().arrpop("doc1", "..a", "1") is None - assert client.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": ["hello", "world"]}, "nested2": {"a": 31}} - ] + res = [{"a": [], "nested1": {"a": ["hello", "world"]}, "nested2": {"a": 31}}] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # # Test missing key with pytest.raises(exceptions.ResponseError): client.json().arrpop("non_existing_doc", "..a") -@pytest.mark.redismod def test_arrtrim_dollar(client): client.json().set( @@ -739,19 +774,17 @@ def test_arrtrim_dollar(client): ) # Test multi assert client.json().arrtrim("doc1", "$..a", "1", -1) == [0, 2, None] - assert client.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": [None, "world"]}, "nested2": {"a": 31}} - ] + res = [{"a": [], "nested1": {"a": [None, "world"]}, "nested2": {"a": 31}}] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) assert client.json().arrtrim("doc1", "$..a", "1", "1") == [0, 1, None] - assert client.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": ["world"]}, "nested2": {"a": 31}} - ] + res = [{"a": [], "nested1": {"a": ["world"]}, "nested2": {"a": 31}}] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) + # Test single assert client.json().arrtrim("doc1", "$.nested1.a", 1, 0) == [0] - assert client.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": []}, "nested2": {"a": 31}} - 
] + res = [{"a": [], "nested1": {"a": []}, "nested2": {"a": 31}}] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): @@ -773,16 +806,14 @@ def test_arrtrim_dollar(client): # Test single assert client.json().arrtrim("doc1", ".nested1.a", "1", "1") == 1 - assert client.json().get("doc1", "$") == [ - {"a": [], "nested1": {"a": ["world"]}, "nested2": {"a": 31}} - ] + res = [{"a": [], "nested1": {"a": ["world"]}, "nested2": {"a": 31}}] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): client.json().arrtrim("non_existing_doc", "..a", 1, 1) -@pytest.mark.redismod def test_objkeys_dollar(client): client.json().set( "doc1", @@ -812,7 +843,6 @@ def test_objkeys_dollar(client): assert client.json().objkeys("doc1", "$..nowhere") == [] -@pytest.mark.redismod def test_objlen_dollar(client): client.json().set( "doc1", @@ -848,7 +878,6 @@ def test_objlen_dollar(client): client.json().objlen("doc1", ".nowhere") -@pytest.mark.redismod def load_types_data(nested_key_name): td = { "object": {}, @@ -868,21 +897,23 @@ def load_types_data(nested_key_name): return jdata, types -@pytest.mark.redismod def test_type_dollar(client): jdata, jtypes = load_types_data("a") client.json().set("doc1", "$", jdata) # Test multi - assert client.json().type("doc1", "$..a") == jtypes + assert_resp_response(client, client.json().type("doc1", "$..a"), jtypes, [jtypes]) # Test single - assert client.json().type("doc1", "$.nested2.a") == [jtypes[1]] + assert_resp_response( + client, client.json().type("doc1", "$.nested2.a"), [jtypes[1]], [[jtypes[1]]] + ) # Test missing key - assert client.json().type("non_existing_doc", "..a") is None + assert_resp_response( + client, client.json().type("non_existing_doc", "..a"), None, [None] + ) -@pytest.mark.redismod def test_clear_dollar(client): client.json().set( "doc1", @@ -897,9 +928,10 @@ def test_clear_dollar(client): # Test multi assert client.json().clear("doc1", "$..a") == 3 - assert client.json().get("doc1", "$") == [ + res = [ {"nested1": {"a": {}}, "a": [], "nested2": {"a": "claro"}, "nested3": {"a": {}}} ] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # Test single client.json().set( @@ -913,7 +945,7 @@ def test_clear_dollar(client): }, ) assert client.json().clear("doc1", "$.nested1.a") == 1 - assert client.json().get("doc1", "$") == [ + res = [ { "nested1": {"a": {}}, "a": ["foo"], @@ -921,17 +953,17 @@ def test_clear_dollar(client): "nested3": {"a": {"baz": 50}}, } ] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # Test missing path (defaults to root) assert client.json().clear("doc1") == 1 - assert client.json().get("doc1", "$") == [{}] + assert_resp_response(client, client.json().get("doc1", "$"), [{}], [[{}]]) # Test missing key with pytest.raises(exceptions.ResponseError): client.json().clear("non_existing_doc", "$..a") -@pytest.mark.redismod def test_toggle_dollar(client): client.json().set( "doc1", @@ -945,7 +977,7 @@ def test_toggle_dollar(client): ) # Test multi assert client.json().toggle("doc1", "$..a") == [None, 1, None, 0] - assert client.json().get("doc1", "$") == [ + res = [ { "a": ["foo"], "nested1": {"a": True}, @@ -953,14 +985,14 @@ def test_toggle_dollar(client): "nested3": {"a": False}, } ] + assert_resp_response(client, client.json().get("doc1", "$"), res, [res]) # Test missing key with pytest.raises(exceptions.ResponseError): 
client.json().toggle("non_existing_doc", "$..a") -# @pytest.mark.redismod -# def test_debug_dollar(client): +# # def test_debug_dollar(client): # # jdata, jtypes = load_types_data("a") # @@ -982,7 +1014,6 @@ def test_toggle_dollar(client): # assert client.json().debug("MEMORY", "non_existing_doc", "$..a") == [] -@pytest.mark.redismod def test_resp_dollar(client): data = { @@ -1028,7 +1059,7 @@ def test_resp_dollar(client): client.json().set("doc1", "$", data) # Test multi res = client.json().resp("doc1", "$..a") - assert res == [ + resp2 = [ [ "{", "A1_B1", @@ -1084,10 +1115,67 @@ def test_resp_dollar(client): ["{", "A2_B4_C1", "bar"], ], ] + resp3 = [ + [ + "{", + "A1_B1", + 10, + "A1_B2", + "false", + "A1_B3", + [ + "{", + "A1_B3_C1", + None, + "A1_B3_C2", + [ + "[", + "A1_B3_C2_D1_1", + "A1_B3_C2_D1_2", + -19.5, + "A1_B3_C2_D1_4", + "A1_B3_C2_D1_5", + ["{", "A1_B3_C2_D1_6_E1", "true"], + ], + "A1_B3_C3", + ["[", 1], + ], + "A1_B4", + ["{", "A1_B4_C1", "foo"], + ], + [ + "{", + "A2_B1", + 20, + "A2_B2", + "false", + "A2_B3", + [ + "{", + "A2_B3_C1", + None, + "A2_B3_C2", + [ + "[", + "A2_B3_C2_D1_1", + "A2_B3_C2_D1_2", + -37.5, + "A2_B3_C2_D1_4", + "A2_B3_C2_D1_5", + ["{", "A2_B3_C2_D1_6_E1", "false"], + ], + "A2_B3_C3", + ["[", 2], + ], + "A2_B4", + ["{", "A2_B4_C1", "bar"], + ], + ] + assert_resp_response(client, res, resp2, resp3) # Test single - resSingle = client.json().resp("doc1", "$.L1.a") - assert resSingle == [ + res = client.json().resp("doc1", "$.L1.a") + resp2 = [ [ "{", "A1_B1", @@ -1116,6 +1204,36 @@ def test_resp_dollar(client): ["{", "A1_B4_C1", "foo"], ] ] + resp3 = [ + [ + "{", + "A1_B1", + 10, + "A1_B2", + "false", + "A1_B3", + [ + "{", + "A1_B3_C1", + None, + "A1_B3_C2", + [ + "[", + "A1_B3_C2_D1_1", + "A1_B3_C2_D1_2", + -19.5, + "A1_B3_C2_D1_4", + "A1_B3_C2_D1_5", + ["{", "A1_B3_C2_D1_6_E1", "true"], + ], + "A1_B3_C3", + ["[", 1], + ], + "A1_B4", + ["{", "A1_B4_C1", "foo"], + ] + ] + assert_resp_response(client, res, resp2, resp3) # Test missing path client.json().resp("doc1", "$.nowhere") @@ -1125,7 +1243,6 @@ def test_resp_dollar(client): client.json().resp("non_existing_doc", "$..a") -@pytest.mark.redismod def test_arrindex_dollar(client): client.json().set( @@ -1170,10 +1287,13 @@ def test_arrindex_dollar(client): }, ) - assert client.json().get("store", "$.store.book[?(@.price<10)].size") == [ - [10, 20, 30, 40], - [5, 10, 20, 30], - ] + assert_resp_response( + client, + client.json().get("store", "$.store.book[?(@.price<10)].size"), + [[10, 20, 30, 40], [5, 10, 20, 30]], + [[[10, 20, 30, 40], [5, 10, 20, 30]]], + ) + assert client.json().arrindex( "store", "$.store.book[?(@.price<10)].size", "20" ) == [-1, -1] @@ -1194,13 +1314,14 @@ def test_arrindex_dollar(client): ], ) - assert client.json().get("test_num", "$..arr") == [ + res = [ [0, 1, 3.0, 3, 2, 1, 0, 3], [5, 4, 3, 2, 1, 0, 1, 2, 3.0, 2, 4, 5], [2, 4, 6], "3", [], ] + assert_resp_response(client, client.json().get("test_num", "$..arr"), res, [res]) assert client.json().arrindex("test_num", "$..arr", 3) == [3, 2, -1, None, -1] @@ -1226,13 +1347,14 @@ def test_arrindex_dollar(client): ], ], ) - assert client.json().get("test_string", "$..arr") == [ + res = [ ["bazzz", "bar", 2, "baz", 2, "ba", "baz", 3], [None, "baz2", "buzz", 2, 1, 0, 1, "2", "baz", 2, 4, 5], ["baz2", 4, 6], "3", [], ] + assert_resp_response(client, client.json().get("test_string", "$..arr"), res, [res]) assert client.json().arrindex("test_string", "$..arr", "baz") == [ 3, @@ -1318,13 +1440,14 @@ def test_arrindex_dollar(client): ], ], ) 
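# ---------------------------------------------------------------------------
# For reference while reading these assertions: a multi-match JSONPath such as
# "$..arr" or "$.store.book[?(@.price<10)].size" makes JSON.ARRINDEX return
# one index per matched array (-1 when the value is absent), while the legacy
# syntax ("..arr") collapses to a single scalar. Matching is type-sensitive,
# which is why searching for the string "20" above yields [-1, -1] even though
# the number 20 is present. A toy model of the multi-match behaviour in plain
# Python, with hypothetical data:
matched_arrays = [[10, 20, 30, 40], [5, 10, 20, 30]]
indexes = [arr.index(20) if 20 in arr else -1 for arr in matched_arrays]
assert indexes == [1, 2]
# ---------------------------------------------------------------------------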
-    assert client.json().get("test_None", "$..arr") == [
+    res = [
         ["bazzz", "None", 2, None, 2, "ba", "baz", 3],
         ["zaz", "baz2", "buzz", 2, 1, 0, 1, "2", None, 2, 4, 5],
         ["None", 4, 6],
         None,
         [],
     ]
+    assert_resp_response(client, client.json().get("test_None", "$..arr"), res, [res])

     # Test with none-scalar value
     assert client.json().arrindex(
@@ -1346,7 +1469,6 @@ def test_arrindex_dollar(client):
     assert client.json().arrindex("test_None", "..nested2_not_found.arr", "None") == 0

-@pytest.mark.redismod
 def test_decoders_and_unstring():
     assert unstring("4") == 4
     assert unstring("45.55") == 45.55
@@ -1357,7 +1479,6 @@ def test_decoders_and_unstring():
     assert decode_list(["hello", b"world"]) == ["hello", "world"]

-@pytest.mark.redismod
 def test_custom_decoder(client):
     import json

@@ -1365,7 +1486,7 @@ def test_custom_decoder(client):
     cj = client.json(encoder=ujson, decoder=ujson)
     assert cj.set("foo", Path.root_path(), "bar")
-    assert "bar" == cj.get("foo")
+    assert_resp_response(client, cj.get("foo"), "bar", [["bar"]])
     assert cj.get("baz") is None
     assert 1 == cj.delete("foo")
     assert client.exists("foo") == 0
@@ -1373,7 +1494,6 @@ def test_custom_decoder(client):
     assert not isinstance(cj.__decoder__, json.JSONDecoder)

-@pytest.mark.redismod
 def test_set_file(client):
     import json
     import tempfile
@@ -1387,12 +1507,11 @@ def test_set_file(client):
         nojsonfile.write(b"Hello World")

     assert client.json().set_file("test", Path.root_path(), jsonfile.name)
-    assert client.json().get("test") == obj
+    assert_resp_response(client, client.json().get("test"), obj, [[obj]])
     with pytest.raises(json.JSONDecodeError):
         client.json().set_file("test2", Path.root_path(), nojsonfile.name)

-@pytest.mark.redismod
 def test_set_path(client):
     import json
     import tempfile
@@ -1409,4 +1528,7 @@ def test_set_path(client):
     result = {jsonfile: True, nojsonfile: False}
     assert client.json().set_path(Path.root_path(), root) == result
-    assert client.json().get(jsonfile.rsplit(".")[0]) == {"hello": "world"}
+    res = {"hello": "world"}
+    assert_resp_response(
+        client, client.json().get(jsonfile.rsplit(".")[0]), res, [[res]]
+    )
diff --git a/tests/test_lock.py b/tests/test_lock.py
index 10ad7e1539..b34f7f0159 100644
--- a/tests/test_lock.py
+++ b/tests/test_lock.py
@@ -1,7 +1,6 @@
 import time

 import pytest
-
 from redis.client import Redis
 from redis.exceptions import LockError, LockNotOwnedError
 from redis.lock import Lock
@@ -102,11 +101,12 @@ def test_blocking_timeout(self, r):
         assert lock1.acquire(blocking=False)
         bt = 0.4
         sleep = 0.05
+        fudge_factor = 0.05
         lock2 = self.get_lock(r, "foo", sleep=sleep, blocking_timeout=bt)
         start = time.monotonic()
         assert not lock2.acquire()
         # The elapsed duration should be within a fudge factor of blocking_timeout
-        assert bt > (time.monotonic() - start) > bt - sleep
+        assert (bt + fudge_factor) > (time.monotonic() - start) > bt - sleep
         lock1.release()

     def test_context_manager(self, r):
@@ -120,11 +120,12 @@ def test_context_manager_blocking_timeout(self, r):
         with self.get_lock(r, "foo", blocking=False):
             bt = 0.4
             sleep = 0.05
+            fudge_factor = 0.05
             lock2 = self.get_lock(r, "foo", sleep=sleep, blocking_timeout=bt)
             start = time.monotonic()
             assert not lock2.acquire()
             # The elapsed duration should be within a fudge factor of blocking_timeout
-            assert bt > (time.monotonic() - start) > bt - sleep
+            assert (bt + fudge_factor) > (time.monotonic() - start) > bt - sleep

     def test_context_manager_raises_when_locked_not_acquired(self, r):
         r.set("foo", "bar")
diff --git a/tests/test_multiprocessing.py
b/tests/test_multiprocessing.py index 32f5e23d53..5cda3190a6 100644 --- a/tests/test_multiprocessing.py +++ b/tests/test_multiprocessing.py @@ -2,7 +2,6 @@ import multiprocessing import pytest - import redis from redis.connection import Connection, ConnectionPool from redis.exceptions import ConnectionError diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 716cd0fbf6..7b048eec01 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,5 +1,4 @@ import pytest - import redis from .conftest import skip_if_server_version_lt, wait_for_command @@ -19,7 +18,6 @@ def test_pipeline(self, r): .zadd("z", {"z1": 1}) .zadd("z", {"z2": 4}) .zincrby("z", 1, "z1") - .zrange("z", 0, 5, withscores=True) ) assert pipe.execute() == [ True, @@ -27,7 +25,6 @@ def test_pipeline(self, r): True, True, 2.0, - [(b"z1", 2.0), (b"z2", 4)], ] def test_pipeline_memoryview(self, r): diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index 5d86934de6..ba097e3194 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -3,24 +3,39 @@ import socket import threading import time +from collections import defaultdict from unittest import mock from unittest.mock import patch import pytest - import redis from redis.exceptions import ConnectionError +from redis.utils import HIREDIS_AVAILABLE -from .conftest import _get_client, skip_if_redis_enterprise, skip_if_server_version_lt +from .conftest import ( + _get_client, + is_resp2_connection, + skip_if_redis_enterprise, + skip_if_server_version_lt, +) -def wait_for_message(pubsub, timeout=0.5, ignore_subscribe_messages=False): +def wait_for_message( + pubsub, timeout=0.5, ignore_subscribe_messages=False, node=None, func=None +): now = time.time() timeout = now + timeout while now < timeout: - message = pubsub.get_message( - ignore_subscribe_messages=ignore_subscribe_messages - ) + if node: + message = pubsub.get_sharded_message( + ignore_subscribe_messages=ignore_subscribe_messages, target_node=node + ) + elif func: + message = func(ignore_subscribe_messages=ignore_subscribe_messages) + else: + message = pubsub.get_message( + ignore_subscribe_messages=ignore_subscribe_messages + ) if message is not None: return message time.sleep(0.01) @@ -47,6 +62,15 @@ def make_subscribe_test_data(pubsub, type): "unsub_func": pubsub.unsubscribe, "keys": ["foo", "bar", "uni" + chr(4456) + "code"], } + elif type == "shard_channel": + return { + "p": pubsub, + "sub_type": "ssubscribe", + "unsub_type": "sunsubscribe", + "sub_func": pubsub.ssubscribe, + "unsub_func": pubsub.sunsubscribe, + "keys": ["foo", "bar", "uni" + chr(4456) + "code"], + } elif type == "pattern": return { "p": pubsub, @@ -87,6 +111,44 @@ def test_pattern_subscribe_unsubscribe(self, r): kwargs = make_subscribe_test_data(r.pubsub(), "pattern") self._test_subscribe_unsubscribe(**kwargs) + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.0.0") + def test_shard_channel_subscribe_unsubscribe(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), "shard_channel") + self._test_subscribe_unsubscribe(**kwargs) + + @pytest.mark.onlycluster + @skip_if_server_version_lt("7.0.0") + def test_shard_channel_subscribe_unsubscribe_cluster(self, r): + node_channels = defaultdict(int) + p = r.pubsub() + keys = { + "foo": r.get_node_from_key("foo"), + "bar": r.get_node_from_key("bar"), + "uni" + chr(4456) + "code": r.get_node_from_key("uni" + chr(4456) + "code"), + } + + for key, node in keys.items(): + assert p.ssubscribe(key) is None + + # should be a message for each shard_channel we just 
subscribed to + for key, node in keys.items(): + node_channels[node.name] += 1 + assert wait_for_message(p, node=node) == make_message( + "ssubscribe", key, node_channels[node.name] + ) + + for key in keys.keys(): + assert p.sunsubscribe(key) is None + + # should be a message for each shard_channel we just unsubscribed + # from + for key, node in keys.items(): + node_channels[node.name] -= 1 + assert wait_for_message(p, node=node) == make_message( + "sunsubscribe", key, node_channels[node.name] + ) + def _test_resubscribe_on_reconnection( self, p, sub_type, unsub_type, sub_func, unsub_func, keys ): @@ -130,6 +192,12 @@ def test_resubscribe_to_patterns_on_reconnection(self, r): kwargs = make_subscribe_test_data(r.pubsub(), "pattern") self._test_resubscribe_on_reconnection(**kwargs) + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.0.0") + def test_resubscribe_to_shard_channels_on_reconnection(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), "shard_channel") + self._test_resubscribe_on_reconnection(**kwargs) + def _test_subscribed_property( self, p, sub_type, unsub_type, sub_func, unsub_func, keys ): @@ -186,38 +254,111 @@ def test_subscribe_property_with_patterns(self, r): kwargs = make_subscribe_test_data(r.pubsub(), "pattern") self._test_subscribed_property(**kwargs) + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.0.0") + def test_subscribe_property_with_shard_channels(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), "shard_channel") + self._test_subscribed_property(**kwargs) + + @pytest.mark.onlycluster + @skip_if_server_version_lt("7.0.0") + def test_subscribe_property_with_shard_channels_cluster(self, r): + p = r.pubsub() + keys = ["foo", "bar", "uni" + chr(4456) + "code"] + nodes = [r.get_node_from_key(key) for key in keys] + assert p.subscribed is False + p.ssubscribe(keys[0]) + # we're now subscribed even though we haven't processed the + # reply from the server just yet + assert p.subscribed is True + assert wait_for_message(p, node=nodes[0]) == make_message( + "ssubscribe", keys[0], 1 + ) + # we're still subscribed + assert p.subscribed is True + + # unsubscribe from all shard_channels + p.sunsubscribe() + # we're still technically subscribed until we process the + # response messages from the server + assert p.subscribed is True + assert wait_for_message(p, node=nodes[0]) == make_message( + "sunsubscribe", keys[0], 0 + ) + # now we're no longer subscribed as no more messages can be delivered + # to any channels we were listening to + assert p.subscribed is False + + # subscribing again flips the flag back + p.ssubscribe(keys[0]) + assert p.subscribed is True + assert wait_for_message(p, node=nodes[0]) == make_message( + "ssubscribe", keys[0], 1 + ) + + # unsubscribe again + p.sunsubscribe() + assert p.subscribed is True + # subscribe to another shard_channel before reading the unsubscribe response + p.ssubscribe(keys[1]) + assert p.subscribed is True + # read the unsubscribe for key1 + assert wait_for_message(p, node=nodes[0]) == make_message( + "sunsubscribe", keys[0], 0 + ) + # we're still subscribed to key2, so subscribed should still be True + assert p.subscribed is True + # read the key2 subscribe message + assert wait_for_message(p, node=nodes[1]) == make_message( + "ssubscribe", keys[1], 1 + ) + p.sunsubscribe() + # haven't read the message yet, so we're still subscribed + assert p.subscribed is True + assert wait_for_message(p, node=nodes[1]) == make_message( + "sunsubscribe", keys[1], 0 + ) + # now we're finally unsubscribed + 
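# ---------------------------------------------------------------------------
# Context for the sharded pub/sub tests in this file: SSUBSCRIBE and SPUBLISH
# (Redis 7.0+) bind a shard channel to the node that owns its key slot, which
# is why the cluster variants read replies from a specific node or use
# get_sharded_message. A minimal standalone round-trip sketch, assuming a
# local Redis 7+ on the default port (illustrative, not one of the tests):
import redis

def sharded_pubsub_roundtrip():
    r = redis.Redis()
    p = r.pubsub(ignore_subscribe_messages=True)
    p.ssubscribe("shard-channel")
    # Receiver count is 0 or 1 depending on whether the server has processed
    # the SSUBSCRIBE yet; the tests above poll with wait_for_message() rather
    # than relying on a single read.
    assert r.spublish("shard-channel", "payload") in (0, 1)
    message = p.get_sharded_message(timeout=1.0)
    assert message is None or message["type"] == "smessage"
# ---------------------------------------------------------------------------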
assert p.subscribed is False + + @skip_if_server_version_lt("7.0.0") def test_ignore_all_subscribe_messages(self, r): p = r.pubsub(ignore_subscribe_messages=True) checks = ( - (p.subscribe, "foo"), - (p.unsubscribe, "foo"), - (p.psubscribe, "f*"), - (p.punsubscribe, "f*"), + (p.subscribe, "foo", p.get_message), + (p.unsubscribe, "foo", p.get_message), + (p.psubscribe, "f*", p.get_message), + (p.punsubscribe, "f*", p.get_message), + (p.ssubscribe, "foo", p.get_sharded_message), + (p.sunsubscribe, "foo", p.get_sharded_message), ) assert p.subscribed is False - for func, channel in checks: + for func, channel, get_func in checks: assert func(channel) is None assert p.subscribed is True - assert wait_for_message(p) is None + assert wait_for_message(p, func=get_func) is None assert p.subscribed is False + @skip_if_server_version_lt("7.0.0") def test_ignore_individual_subscribe_messages(self, r): p = r.pubsub() checks = ( - (p.subscribe, "foo"), - (p.unsubscribe, "foo"), - (p.psubscribe, "f*"), - (p.punsubscribe, "f*"), + (p.subscribe, "foo", p.get_message), + (p.unsubscribe, "foo", p.get_message), + (p.psubscribe, "f*", p.get_message), + (p.punsubscribe, "f*", p.get_message), + (p.ssubscribe, "foo", p.get_sharded_message), + (p.sunsubscribe, "foo", p.get_sharded_message), ) assert p.subscribed is False - for func, channel in checks: + for func, channel, get_func in checks: assert func(channel) is None assert p.subscribed is True - message = wait_for_message(p, ignore_subscribe_messages=True) + message = wait_for_message(p, ignore_subscribe_messages=True, func=get_func) assert message is None assert p.subscribed is False @@ -230,6 +371,12 @@ def test_sub_unsub_resub_patterns(self, r): kwargs = make_subscribe_test_data(r.pubsub(), "pattern") self._test_sub_unsub_resub(**kwargs) + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.0.0") + def test_sub_unsub_resub_shard_channels(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), "shard_channel") + self._test_sub_unsub_resub(**kwargs) + def _test_sub_unsub_resub( self, p, sub_type, unsub_type, sub_func, unsub_func, keys ): @@ -244,6 +391,26 @@ def _test_sub_unsub_resub( assert wait_for_message(p) == make_message(sub_type, key, 1) assert p.subscribed is True + @pytest.mark.onlycluster + @skip_if_server_version_lt("7.0.0") + def test_sub_unsub_resub_shard_channels_cluster(self, r): + p = r.pubsub() + key = "foo" + p.ssubscribe(key) + p.sunsubscribe(key) + p.ssubscribe(key) + assert p.subscribed is True + assert wait_for_message(p, func=p.get_sharded_message) == make_message( + "ssubscribe", key, 1 + ) + assert wait_for_message(p, func=p.get_sharded_message) == make_message( + "sunsubscribe", key, 0 + ) + assert wait_for_message(p, func=p.get_sharded_message) == make_message( + "ssubscribe", key, 1 + ) + assert p.subscribed is True + def test_sub_unsub_all_resub_channels(self, r): kwargs = make_subscribe_test_data(r.pubsub(), "channel") self._test_sub_unsub_all_resub(**kwargs) @@ -252,6 +419,12 @@ def test_sub_unsub_all_resub_patterns(self, r): kwargs = make_subscribe_test_data(r.pubsub(), "pattern") self._test_sub_unsub_all_resub(**kwargs) + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.0.0") + def test_sub_unsub_all_resub_shard_channels(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), "shard_channel") + self._test_sub_unsub_all_resub(**kwargs) + def _test_sub_unsub_all_resub( self, p, sub_type, unsub_type, sub_func, unsub_func, keys ): @@ -266,6 +439,26 @@ def _test_sub_unsub_all_resub( assert 
wait_for_message(p) == make_message(sub_type, key, 1) assert p.subscribed is True + @pytest.mark.onlycluster + @skip_if_server_version_lt("7.0.0") + def test_sub_unsub_all_resub_shard_channels_cluster(self, r): + p = r.pubsub() + key = "foo" + p.ssubscribe(key) + p.sunsubscribe() + p.ssubscribe(key) + assert p.subscribed is True + assert wait_for_message(p, func=p.get_sharded_message) == make_message( + "ssubscribe", key, 1 + ) + assert wait_for_message(p, func=p.get_sharded_message) == make_message( + "sunsubscribe", key, 0 + ) + assert wait_for_message(p, func=p.get_sharded_message) == make_message( + "ssubscribe", key, 1 + ) + assert p.subscribed is True + class TestPubSubMessages: def setup_method(self, method): @@ -284,6 +477,32 @@ def test_published_message_to_channel(self, r): assert isinstance(message, dict) assert message == make_message("message", "foo", "test message") + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.0.0") + def test_published_message_to_shard_channel(self, r): + p = r.pubsub() + p.ssubscribe("foo") + assert wait_for_message(p) == make_message("ssubscribe", "foo", 1) + assert r.spublish("foo", "test message") == 1 + + message = wait_for_message(p) + assert isinstance(message, dict) + assert message == make_message("smessage", "foo", "test message") + + @pytest.mark.onlycluster + @skip_if_server_version_lt("7.0.0") + def test_published_message_to_shard_channel_cluster(self, r): + p = r.pubsub() + p.ssubscribe("foo") + assert wait_for_message(p, func=p.get_sharded_message) == make_message( + "ssubscribe", "foo", 1 + ) + assert r.spublish("foo", "test message") == 1 + + message = wait_for_message(p, func=p.get_sharded_message) + assert isinstance(message, dict) + assert message == make_message("smessage", "foo", "test message") + def test_published_message_to_pattern(self, r): p = r.pubsub() p.subscribe("foo") @@ -315,6 +534,15 @@ def test_channel_message_handler(self, r): assert wait_for_message(p) is None assert self.message == make_message("message", "foo", "test message") + @skip_if_server_version_lt("7.0.0") + def test_shard_channel_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + p.ssubscribe(foo=self.message_handler) + assert wait_for_message(p, func=p.get_sharded_message) is None + assert r.spublish("foo", "test message") == 1 + assert wait_for_message(p, func=p.get_sharded_message) is None + assert self.message == make_message("smessage", "foo", "test message") + @pytest.mark.onlynoncluster def test_pattern_message_handler(self, r): p = r.pubsub(ignore_subscribe_messages=True) @@ -336,6 +564,17 @@ def test_unicode_channel_message_handler(self, r): assert wait_for_message(p) is None assert self.message == make_message("message", channel, "test message") + @skip_if_server_version_lt("7.0.0") + def test_unicode_shard_channel_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + channel = "uni" + chr(4456) + "code" + channels = {channel: self.message_handler} + p.ssubscribe(**channels) + assert wait_for_message(p, func=p.get_sharded_message) is None + assert r.spublish(channel, "test message") == 1 + assert wait_for_message(p, func=p.get_sharded_message) is None + assert self.message == make_message("smessage", channel, "test message") + @pytest.mark.onlynoncluster # see: https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html # #known-limitations-with-pubsub @@ -352,6 +591,36 @@ def test_unicode_pattern_message_handler(self, r): ) +class TestPubSubRESP3Handler: + def my_handler(self, message): 
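+        # (editor's note, not part of the original change) push_handler_func
+        # is invoked for each RESP3 push message instead of the message being
+        # queued for get_message(); this handler simply records the raw
+        # payload so the assertions below can inspect it.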
+        self.message = ["my handler", message]
+
+    @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
+    def test_push_handler(self, r):
+        if is_resp2_connection(r):
+            return
+        p = r.pubsub(push_handler_func=self.my_handler)
+        p.subscribe("foo")
+        assert wait_for_message(p) is None
+        assert self.message == ["my handler", [b"subscribe", b"foo", 1]]
+        assert r.publish("foo", "test message") == 1
+        assert wait_for_message(p) is None
+        assert self.message == ["my handler", [b"message", b"foo", b"test message"]]
+
+    @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
+    @skip_if_server_version_lt("7.0.0")
+    def test_push_handler_sharded_pubsub(self, r):
+        if is_resp2_connection(r):
+            return
+        p = r.pubsub(push_handler_func=self.my_handler)
+        p.ssubscribe("foo")
+        assert wait_for_message(p, func=p.get_sharded_message) is None
+        assert self.message == ["my handler", [b"ssubscribe", b"foo", 1]]
+        assert r.spublish("foo", "test message") == 1
+        assert wait_for_message(p, func=p.get_sharded_message) is None
+        assert self.message == ["my handler", [b"smessage", b"foo", b"test message"]]
+
+
 class TestPubSubAutoDecoding:
     "These tests only validate that we get unicode values back"

@@ -388,6 +657,19 @@ def test_pattern_subscribe_unsubscribe(self, r):
         p.punsubscribe(self.pattern)
         assert wait_for_message(p) == self.make_message("punsubscribe", self.pattern, 0)

+    @skip_if_server_version_lt("7.0.0")
+    def test_shard_channel_subscribe_unsubscribe(self, r):
+        p = r.pubsub()
+        p.ssubscribe(self.channel)
+        assert wait_for_message(p, func=p.get_sharded_message) == self.make_message(
+            "ssubscribe", self.channel, 1
+        )
+
+        p.sunsubscribe(self.channel)
+        assert wait_for_message(p, func=p.get_sharded_message) == self.make_message(
+            "sunsubscribe", self.channel, 0
+        )
+
     def test_channel_publish(self, r):
         p = r.pubsub()
         p.subscribe(self.channel)
@@ -407,6 +689,18 @@ def test_pattern_publish(self, r):
             "pmessage", self.channel, self.data, pattern=self.pattern
         )

+    @skip_if_server_version_lt("7.0.0")
+    def test_shard_channel_publish(self, r):
+        p = r.pubsub()
+        p.ssubscribe(self.channel)
+        assert wait_for_message(p, func=p.get_sharded_message) == self.make_message(
+            "ssubscribe", self.channel, 1
+        )
+        r.spublish(self.channel, self.data)
+        assert wait_for_message(p, func=p.get_sharded_message) == self.make_message(
+            "smessage", self.channel, self.data
+        )
+
     def test_channel_message_handler(self, r):
         p = r.pubsub(ignore_subscribe_messages=True)
         p.subscribe(**{self.channel: self.message_handler})
@@ -445,6 +739,30 @@ def test_pattern_message_handler(self, r):
             "pmessage", self.channel, new_data, pattern=self.pattern
         )

+    @skip_if_server_version_lt("7.0.0")
+    def test_shard_channel_message_handler(self, r):
+        p = r.pubsub(ignore_subscribe_messages=True)
+        p.ssubscribe(**{self.channel: self.message_handler})
+        assert wait_for_message(p, func=p.get_sharded_message) is None
+        r.spublish(self.channel, self.data)
+        assert wait_for_message(p, func=p.get_sharded_message) is None
+        assert self.message == self.make_message("smessage", self.channel, self.data)
+
+        # test that we reconnected to the correct channel
+        self.message = None
+        try:
+            # cluster mode
+            p.disconnect()
+        except AttributeError:
+            # standalone mode
+            p.connection.disconnect()
+        # should reconnect
+        assert wait_for_message(p, func=p.get_sharded_message) is None
+        new_data = self.data + "new data"
+        r.spublish(self.channel, new_data)
+        assert wait_for_message(p, func=p.get_sharded_message) is None
+        assert self.message == self.make_message("smessage", self.channel, new_data)
+
     def test_context_manager(self, r):
         with r.pubsub() as pubsub:
             pubsub.subscribe("foo")
@@ -474,6 +792,38 @@ def test_pubsub_channels(self, r):
         expected = [b"bar", b"baz", b"foo", b"quux"]
         assert all([channel in r.pubsub_channels() for channel in expected])

+    @pytest.mark.onlynoncluster
+    @skip_if_server_version_lt("7.0.0")
+    def test_pubsub_shardchannels(self, r):
+        p = r.pubsub()
+        p.ssubscribe("foo", "bar", "baz", "quux")
+        for i in range(4):
+            assert wait_for_message(p)["type"] == "ssubscribe"
+        expected = [b"bar", b"baz", b"foo", b"quux"]
+        assert all([channel in r.pubsub_shardchannels() for channel in expected])
+
+    @pytest.mark.onlycluster
+    @skip_if_server_version_lt("7.0.0")
+    def test_pubsub_shardchannels_cluster(self, r):
+        channels = {
+            b"foo": r.get_node_from_key("foo"),
+            b"bar": r.get_node_from_key("bar"),
+            b"baz": r.get_node_from_key("baz"),
+            b"quux": r.get_node_from_key("quux"),
+        }
+        p = r.pubsub()
+        p.ssubscribe("foo", "bar", "baz", "quux")
+        for node in channels.values():
+            assert wait_for_message(p, node=node)["type"] == "ssubscribe"
+        for channel, node in channels.items():
+            assert channel in r.pubsub_shardchannels(target_nodes=node)
+        assert all(
+            [
+                channel in r.pubsub_shardchannels(target_nodes="all")
+                for channel in channels.keys()
+            ]
+        )
+
     @pytest.mark.onlynoncluster
     @skip_if_server_version_lt("2.8.0")
     def test_pubsub_numsub(self, r):
@@ -500,6 +850,32 @@ def test_pubsub_numpat(self, r):
         assert wait_for_message(p)["type"] == "psubscribe"
         assert r.pubsub_numpat() == 3

+    @pytest.mark.onlycluster
+    @skip_if_server_version_lt("7.0.0")
+    def test_pubsub_shardnumsub(self, r):
+        channels = {
+            b"foo": r.get_node_from_key("foo"),
+            b"bar": r.get_node_from_key("bar"),
+            b"baz": r.get_node_from_key("baz"),
+        }
+        p1 = r.pubsub()
+        p1.ssubscribe(*channels.keys())
+        for node in channels.values():
+            assert wait_for_message(p1, node=node)["type"] == "ssubscribe"
+        p2 = r.pubsub()
+        p2.ssubscribe("bar", "baz")
+        for i in range(2):
+            assert (
+                wait_for_message(p2, func=p2.get_sharded_message)["type"]
+                == "ssubscribe"
+            )
+        p3 = r.pubsub()
+        p3.ssubscribe("baz")
+        assert wait_for_message(p3, node=channels[b"baz"])["type"] == "ssubscribe"
+
+        channels = [(b"foo", 1), (b"bar", 2), (b"baz", 3)]
+        assert r.pubsub_shardnumsub("foo", "bar", "baz", target_nodes="all") == channels
+

 class TestPubSubPings:
     @skip_if_server_version_lt("3.0.0")
@@ -767,13 +1143,15 @@ def get_msg():
         assert msg is not None
         # timeout waiting for another message which never arrives
         assert is_connected()
-        with patch("redis.connection.PythonParser.read_response") as mock1:
+        with patch("redis._parsers._RESP2Parser.read_response") as mock1, patch(
+            "redis._parsers._HiredisParser.read_response"
+        ) as mock2, patch("redis._parsers._RESP3Parser.read_response") as mock3:
             mock1.side_effect = BaseException("boom")
-            with patch("redis.connection.HiredisParser.read_response") as mock2:
-                mock2.side_effect = BaseException("boom")
+            mock2.side_effect = BaseException("boom")
+            mock3.side_effect = BaseException("boom")

-                with pytest.raises(BaseException):
-                    get_msg()
+            with pytest.raises(BaseException):
+                get_msg()

         # the timeout on the read should not cause disconnect
         assert is_connected()
diff --git a/tests/test_retry.py b/tests/test_retry.py
index 3cfea5c09e..e9d3015897 100644
--- a/tests/test_retry.py
+++ b/tests/test_retry.py
@@ -1,7 +1,6 @@
 from unittest.mock import patch

 import pytest
-
 from redis.backoff import ExponentialBackoff, NoBackoff
 from
redis.client import Redis from redis.connection import Connection, UnixDomainSocketConnection diff --git a/tests/test_scripting.py b/tests/test_scripting.py index b6b5f9fb70..899dc69482 100644 --- a/tests/test_scripting.py +++ b/tests/test_scripting.py @@ -1,5 +1,4 @@ import pytest - import redis from redis import exceptions from redis.commands.core import Script diff --git a/tests/test_search.py b/tests/test_search.py index 57d4338ead..2e42aaba57 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -5,7 +5,6 @@ from io import TextIOWrapper import pytest - import redis import redis.commands.search import redis.commands.search.aggregation as aggregations @@ -24,7 +23,13 @@ from redis.commands.search.result import Result from redis.commands.search.suggestion import Suggestion -from .conftest import skip_if_redis_enterprise, skip_ifmodversion_lt +from .conftest import ( + _get_client, + assert_resp_response, + is_resp2_connection, + skip_if_redis_enterprise, + skip_ifmodversion_lt, +) WILL_PLAY_TEXT = os.path.abspath( os.path.join(os.path.dirname(__file__), "testdata", "will_play_text.csv.bz2") @@ -40,12 +45,16 @@ def waitForIndex(env, idx, timeout=None): while True: res = env.execute_command("FT.INFO", idx) try: - res.index("indexing") + if int(res[res.index("indexing") + 1]) == 0: + break except ValueError: break - - if int(res[res.index("indexing") + 1]) == 0: - break + except AttributeError: + try: + if int(res["indexing"]) == 0: + break + except ValueError: + break time.sleep(delay) if timeout is not None: @@ -98,9 +107,10 @@ def createIndex(client, num_docs=100, definition=None): @pytest.fixture -def client(modclient): - modclient.flushdb() - return modclient +def client(request): + r = _get_client(redis.Redis, request, decode_responses=True) + r.flushdb() + return r @pytest.mark.redismod @@ -133,82 +143,170 @@ def test_client(client): assert num_docs == int(info["num_docs"]) res = client.ft().search("henry iv") - assert isinstance(res, Result) - assert 225 == res.total - assert 10 == len(res.docs) - assert res.duration > 0 - - for doc in res.docs: - assert doc.id - assert doc.play == "Henry IV" + if is_resp2_connection(client): + assert isinstance(res, Result) + assert 225 == res.total + assert 10 == len(res.docs) + assert res.duration > 0 + + for doc in res.docs: + assert doc.id + assert doc["id"] + assert doc.play == "Henry IV" + assert doc["play"] == "Henry IV" + assert len(doc.txt) > 0 + + # test no content + res = client.ft().search(Query("king").no_content()) + assert 194 == res.total + assert 10 == len(res.docs) + for doc in res.docs: + assert "txt" not in doc.__dict__ + assert "play" not in doc.__dict__ + + # test verbatim vs no verbatim + total = client.ft().search(Query("kings").no_content()).total + vtotal = client.ft().search(Query("kings").no_content().verbatim()).total + assert total > vtotal + + # test in fields + txt_total = ( + client.ft().search(Query("henry").no_content().limit_fields("txt")).total + ) + play_total = ( + client.ft().search(Query("henry").no_content().limit_fields("play")).total + ) + both_total = ( + client.ft() + .search(Query("henry").no_content().limit_fields("play", "txt")) + .total + ) + assert 129 == txt_total + assert 494 == play_total + assert 494 == both_total + + # test load_document + doc = client.ft().load_document("henry vi part 3:62") + assert doc is not None + assert "henry vi part 3:62" == doc.id + assert doc.play == "Henry VI Part 3" assert len(doc.txt) > 0 - # test no content - res = 
client.ft().search(Query("king").no_content()) - assert 194 == res.total - assert 10 == len(res.docs) - for doc in res.docs: - assert "txt" not in doc.__dict__ - assert "play" not in doc.__dict__ - - # test verbatim vs no verbatim - total = client.ft().search(Query("kings").no_content()).total - vtotal = client.ft().search(Query("kings").no_content().verbatim()).total - assert total > vtotal - - # test in fields - txt_total = ( - client.ft().search(Query("henry").no_content().limit_fields("txt")).total - ) - play_total = ( - client.ft().search(Query("henry").no_content().limit_fields("play")).total - ) - both_total = ( - client.ft() - .search(Query("henry").no_content().limit_fields("play", "txt")) - .total - ) - assert 129 == txt_total - assert 494 == play_total - assert 494 == both_total - - # test load_document - doc = client.ft().load_document("henry vi part 3:62") - assert doc is not None - assert "henry vi part 3:62" == doc.id - assert doc.play == "Henry VI Part 3" - assert len(doc.txt) > 0 - - # test in-keys - ids = [x.id for x in client.ft().search(Query("henry")).docs] - assert 10 == len(ids) - subset = ids[:5] - docs = client.ft().search(Query("henry").limit_ids(*subset)) - assert len(subset) == docs.total - ids = [x.id for x in docs.docs] - assert set(ids) == set(subset) - - # test slop and in order - assert 193 == client.ft().search(Query("henry king")).total - assert 3 == client.ft().search(Query("henry king").slop(0).in_order()).total - assert 52 == client.ft().search(Query("king henry").slop(0).in_order()).total - assert 53 == client.ft().search(Query("henry king").slop(0)).total - assert 167 == client.ft().search(Query("henry king").slop(100)).total - - # test delete document - client.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) - res = client.ft().search(Query("death of a salesman")) - assert 1 == res.total - - assert 1 == client.ft().delete_document("doc-5ghs2") - res = client.ft().search(Query("death of a salesman")) - assert 0 == res.total - assert 0 == client.ft().delete_document("doc-5ghs2") - - client.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) - res = client.ft().search(Query("death of a salesman")) - assert 1 == res.total - client.ft().delete_document("doc-5ghs2") + # test in-keys + ids = [x.id for x in client.ft().search(Query("henry")).docs] + assert 10 == len(ids) + subset = ids[:5] + docs = client.ft().search(Query("henry").limit_ids(*subset)) + assert len(subset) == docs.total + ids = [x.id for x in docs.docs] + assert set(ids) == set(subset) + + # test slop and in order + assert 193 == client.ft().search(Query("henry king")).total + assert 3 == client.ft().search(Query("henry king").slop(0).in_order()).total + assert 52 == client.ft().search(Query("king henry").slop(0).in_order()).total + assert 53 == client.ft().search(Query("henry king").slop(0)).total + assert 167 == client.ft().search(Query("henry king").slop(100)).total + + # test delete document + client.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) + res = client.ft().search(Query("death of a salesman")) + assert 1 == res.total + + assert 1 == client.ft().delete_document("doc-5ghs2") + res = client.ft().search(Query("death of a salesman")) + assert 0 == res.total + assert 0 == client.ft().delete_document("doc-5ghs2") + + client.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) + res = client.ft().search(Query("death of a salesman")) + assert 1 == res.total + client.ft().delete_document("doc-5ghs2") + else: + assert isinstance(res, dict) + assert 225 == 
res["total_results"] + assert 10 == len(res["results"]) + + for doc in res["results"]: + assert doc["id"] + assert doc["extra_attributes"]["play"] == "Henry IV" + assert len(doc["extra_attributes"]["txt"]) > 0 + + # test no content + res = client.ft().search(Query("king").no_content()) + assert 194 == res["total_results"] + assert 10 == len(res["results"]) + for doc in res["results"]: + assert "extra_attributes" not in doc.keys() + + # test verbatim vs no verbatim + total = client.ft().search(Query("kings").no_content())["total_results"] + vtotal = client.ft().search(Query("kings").no_content().verbatim())[ + "total_results" + ] + assert total > vtotal + + # test in fields + txt_total = client.ft().search(Query("henry").no_content().limit_fields("txt"))[ + "total_results" + ] + play_total = client.ft().search( + Query("henry").no_content().limit_fields("play") + )["total_results"] + both_total = client.ft().search( + Query("henry").no_content().limit_fields("play", "txt") + )["total_results"] + assert 129 == txt_total + assert 494 == play_total + assert 494 == both_total + + # test load_document + doc = client.ft().load_document("henry vi part 3:62") + assert doc is not None + assert "henry vi part 3:62" == doc.id + assert doc.play == "Henry VI Part 3" + assert len(doc.txt) > 0 + + # test in-keys + ids = [x["id"] for x in client.ft().search(Query("henry"))["results"]] + assert 10 == len(ids) + subset = ids[:5] + docs = client.ft().search(Query("henry").limit_ids(*subset)) + assert len(subset) == docs["total_results"] + ids = [x["id"] for x in docs["results"]] + assert set(ids) == set(subset) + + # test slop and in order + assert 193 == client.ft().search(Query("henry king"))["total_results"] + assert ( + 3 + == client.ft().search(Query("henry king").slop(0).in_order())[ + "total_results" + ] + ) + assert ( + 52 + == client.ft().search(Query("king henry").slop(0).in_order())[ + "total_results" + ] + ) + assert 53 == client.ft().search(Query("henry king").slop(0))["total_results"] + assert 167 == client.ft().search(Query("henry king").slop(100))["total_results"] + + # test delete document + client.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) + res = client.ft().search(Query("death of a salesman")) + assert 1 == res["total_results"] + + assert 1 == client.ft().delete_document("doc-5ghs2") + res = client.ft().search(Query("death of a salesman")) + assert 0 == res["total_results"] + assert 0 == client.ft().delete_document("doc-5ghs2") + + client.hset("doc-5ghs2", mapping={"play": "Death of a Salesman"}) + res = client.ft().search(Query("death of a salesman")) + assert 1 == res["total_results"] + client.ft().delete_document("doc-5ghs2") @pytest.mark.redismod @@ -221,12 +319,16 @@ def test_scores(client): q = Query("foo ~bar").with_scores() res = client.ft().search(q) - assert 2 == res.total - assert "doc2" == res.docs[0].id - assert 3.0 == res.docs[0].score - assert "doc1" == res.docs[1].id - # todo: enable once new RS version is tagged - # self.assertEqual(0.2, res.docs[1].score) + if is_resp2_connection(client): + assert 2 == res.total + assert "doc2" == res.docs[0].id + assert 3.0 == res.docs[0].score + assert "doc1" == res.docs[1].id + else: + assert 2 == res["total_results"] + assert "doc2" == res["results"][0]["id"] + assert 3.0 == res["results"][0]["score"] + assert "doc1" == res["results"][1]["id"] @pytest.mark.redismod @@ -239,8 +341,12 @@ def test_stopwords(client): q1 = Query("foo bar").no_content() q2 = Query("foo bar hello world").no_content() res1, res2 = 
client.ft().search(q1), client.ft().search(q2) - assert 0 == res1.total - assert 1 == res2.total + if is_resp2_connection(client): + assert 0 == res1.total + assert 1 == res2.total + else: + assert 0 == res1["total_results"] + assert 1 == res2["total_results"] @pytest.mark.redismod @@ -260,25 +366,40 @@ def test_filters(client): .no_content() ) res1, res2 = client.ft().search(q1), client.ft().search(q2) - - assert 1 == res1.total - assert 1 == res2.total - assert "doc2" == res1.docs[0].id - assert "doc1" == res2.docs[0].id + if is_resp2_connection(client): + assert 1 == res1.total + assert 1 == res2.total + assert "doc2" == res1.docs[0].id + assert "doc1" == res2.docs[0].id + else: + assert 1 == res1["total_results"] + assert 1 == res2["total_results"] + assert "doc2" == res1["results"][0]["id"] + assert "doc1" == res2["results"][0]["id"] # Test geo filter q1 = Query("foo").add_filter(GeoFilter("loc", -0.44, 51.45, 10)).no_content() q2 = Query("foo").add_filter(GeoFilter("loc", -0.44, 51.45, 100)).no_content() res1, res2 = client.ft().search(q1), client.ft().search(q2) - assert 1 == res1.total - assert 2 == res2.total - assert "doc1" == res1.docs[0].id + if is_resp2_connection(client): + assert 1 == res1.total + assert 2 == res2.total + assert "doc1" == res1.docs[0].id - # Sort results, after RDB reload order may change - res = [res2.docs[0].id, res2.docs[1].id] - res.sort() - assert ["doc1", "doc2"] == res + # Sort results, after RDB reload order may change + res = [res2.docs[0].id, res2.docs[1].id] + res.sort() + assert ["doc1", "doc2"] == res + else: + assert 1 == res1["total_results"] + assert 2 == res2["total_results"] + assert "doc1" == res1["results"][0]["id"] + + # Sort results, after RDB reload order may change + res = [res2["results"][0]["id"], res2["results"][1]["id"]] + res.sort() + assert ["doc1", "doc2"] == res @pytest.mark.redismod @@ -293,14 +414,24 @@ def test_sort_by(client): q2 = Query("foo").sort_by("num", asc=False).no_content() res1, res2 = client.ft().search(q1), client.ft().search(q2) - assert 3 == res1.total - assert "doc1" == res1.docs[0].id - assert "doc2" == res1.docs[1].id - assert "doc3" == res1.docs[2].id - assert 3 == res2.total - assert "doc1" == res2.docs[2].id - assert "doc2" == res2.docs[1].id - assert "doc3" == res2.docs[0].id + if is_resp2_connection(client): + assert 3 == res1.total + assert "doc1" == res1.docs[0].id + assert "doc2" == res1.docs[1].id + assert "doc3" == res1.docs[2].id + assert 3 == res2.total + assert "doc1" == res2.docs[2].id + assert "doc2" == res2.docs[1].id + assert "doc3" == res2.docs[0].id + else: + assert 3 == res1["total_results"] + assert "doc1" == res1["results"][0]["id"] + assert "doc2" == res1["results"][1]["id"] + assert "doc3" == res1["results"][2]["id"] + assert 3 == res2["total_results"] + assert "doc1" == res2["results"][2]["id"] + assert "doc2" == res2["results"][1]["id"] + assert "doc3" == res2["results"][0]["id"] @pytest.mark.redismod @@ -415,27 +546,50 @@ def test_no_index(client): ) waitForIndex(client, getattr(client.ft(), "index_name", "idx")) - res = client.ft().search(Query("@text:aa*")) - assert 0 == res.total + if is_resp2_connection(client): + res = client.ft().search(Query("@text:aa*")) + assert 0 == res.total + + res = client.ft().search(Query("@field:aa*")) + assert 2 == res.total - res = client.ft().search(Query("@field:aa*")) - assert 2 == res.total + res = client.ft().search(Query("*").sort_by("text", asc=False)) + assert 2 == res.total + assert "doc2" == res.docs[0].id - res = 
client.ft().search(Query("*").sort_by("text", asc=False)) - assert 2 == res.total - assert "doc2" == res.docs[0].id + res = client.ft().search(Query("*").sort_by("text", asc=True)) + assert "doc1" == res.docs[0].id - res = client.ft().search(Query("*").sort_by("text", asc=True)) - assert "doc1" == res.docs[0].id + res = client.ft().search(Query("*").sort_by("numeric", asc=True)) + assert "doc1" == res.docs[0].id - res = client.ft().search(Query("*").sort_by("numeric", asc=True)) - assert "doc1" == res.docs[0].id + res = client.ft().search(Query("*").sort_by("geo", asc=True)) + assert "doc1" == res.docs[0].id - res = client.ft().search(Query("*").sort_by("geo", asc=True)) - assert "doc1" == res.docs[0].id + res = client.ft().search(Query("*").sort_by("tag", asc=True)) + assert "doc1" == res.docs[0].id + else: + res = client.ft().search(Query("@text:aa*")) + assert 0 == res["total_results"] - res = client.ft().search(Query("*").sort_by("tag", asc=True)) - assert "doc1" == res.docs[0].id + res = client.ft().search(Query("@field:aa*")) + assert 2 == res["total_results"] + + res = client.ft().search(Query("*").sort_by("text", asc=False)) + assert 2 == res["total_results"] + assert "doc2" == res["results"][0]["id"] + + res = client.ft().search(Query("*").sort_by("text", asc=True)) + assert "doc1" == res["results"][0]["id"] + + res = client.ft().search(Query("*").sort_by("numeric", asc=True)) + assert "doc1" == res["results"][0]["id"] + + res = client.ft().search(Query("*").sort_by("geo", asc=True)) + assert "doc1" == res["results"][0]["id"] + + res = client.ft().search(Query("*").sort_by("tag", asc=True)) + assert "doc1" == res["results"][0]["id"] # Ensure exception is raised for non-indexable, non-sortable fields with pytest.raises(Exception): @@ -470,21 +624,38 @@ def test_summarize(client): q.highlight(fields=("play", "txt"), tags=("", "")) q.summarize("txt") - doc = sorted(client.ft().search(q).docs)[0] - assert "Henry IV" == doc.play - assert ( - "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa - == doc.txt - ) + if is_resp2_connection(client): + doc = sorted(client.ft().search(q).docs)[0] + assert "Henry IV" == doc.play + assert ( + "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa + == doc.txt + ) - q = Query("king henry").paging(0, 1).summarize().highlight() + q = Query("king henry").paging(0, 1).summarize().highlight() - doc = sorted(client.ft().search(q).docs)[0] - assert "Henry ... " == doc.play - assert ( - "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa - == doc.txt - ) + doc = sorted(client.ft().search(q).docs)[0] + assert "Henry ... " == doc.play + assert ( + "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa + == doc.txt + ) + else: + doc = sorted(client.ft().search(q)["results"])[0] + assert "Henry IV" == doc["extra_attributes"]["play"] + assert ( + "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa + == doc["extra_attributes"]["txt"] + ) + + q = Query("king henry").paging(0, 1).summarize().highlight() + + doc = sorted(client.ft().search(q)["results"])[0] + assert "Henry ... " == doc["extra_attributes"]["play"] + assert ( + "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... 
" # noqa + == doc["extra_attributes"]["txt"] + ) @pytest.mark.redismod @@ -504,25 +675,46 @@ def test_alias(client): index1.hset("index1:lonestar", mapping={"name": "lonestar"}) index2.hset("index2:yogurt", mapping={"name": "yogurt"}) - res = ftindex1.search("*").docs[0] - assert "index1:lonestar" == res.id + if is_resp2_connection(client): + res = ftindex1.search("*").docs[0] + assert "index1:lonestar" == res.id - # create alias and check for results - ftindex1.aliasadd("spaceballs") - alias_client = getClient(client).ft("spaceballs") - res = alias_client.search("*").docs[0] - assert "index1:lonestar" == res.id + # create alias and check for results + ftindex1.aliasadd("spaceballs") + alias_client = getClient(client).ft("spaceballs") + res = alias_client.search("*").docs[0] + assert "index1:lonestar" == res.id - # Throw an exception when trying to add an alias that already exists - with pytest.raises(Exception): - ftindex2.aliasadd("spaceballs") + # Throw an exception when trying to add an alias that already exists + with pytest.raises(Exception): + ftindex2.aliasadd("spaceballs") + + # update alias and ensure new results + ftindex2.aliasupdate("spaceballs") + alias_client2 = getClient(client).ft("spaceballs") + + res = alias_client2.search("*").docs[0] + assert "index2:yogurt" == res.id + else: + res = ftindex1.search("*")["results"][0] + assert "index1:lonestar" == res["id"] - # update alias and ensure new results - ftindex2.aliasupdate("spaceballs") - alias_client2 = getClient(client).ft("spaceballs") + # create alias and check for results + ftindex1.aliasadd("spaceballs") + alias_client = getClient(client).ft("spaceballs") + res = alias_client.search("*")["results"][0] + assert "index1:lonestar" == res["id"] - res = alias_client2.search("*").docs[0] - assert "index2:yogurt" == res.id + # Throw an exception when trying to add an alias that already exists + with pytest.raises(Exception): + ftindex2.aliasadd("spaceballs") + + # update alias and ensure new results + ftindex2.aliasupdate("spaceballs") + alias_client2 = getClient(client).ft("spaceballs") + + res = alias_client2.search("*")["results"][0] + assert "index2:yogurt" == res["id"] ftindex2.aliasdel("spaceballs") with pytest.raises(Exception): @@ -530,9 +722,9 @@ def test_alias(client): @pytest.mark.redismod +@pytest.mark.xfail(strict=False) def test_alias_basic(client): # Creating a client with one index - getClient(client).flushdb() index1 = getClient(client).ft("testAlias") index1.create_index((TextField("txt"),)) @@ -545,18 +737,32 @@ def test_alias_basic(client): # add the actual alias and check index1.aliasadd("myalias") alias_client = getClient(client).ft("myalias") - res = sorted(alias_client.search("*").docs, key=lambda x: x.id) - assert "doc1" == res[0].id - - # Throw an exception when trying to add an alias that already exists - with pytest.raises(Exception): - index2.aliasadd("myalias") - - # update the alias and ensure we get doc2 - index2.aliasupdate("myalias") - alias_client2 = getClient(client).ft("myalias") - res = sorted(alias_client2.search("*").docs, key=lambda x: x.id) - assert "doc1" == res[0].id + if is_resp2_connection(client): + res = sorted(alias_client.search("*").docs, key=lambda x: x.id) + assert "doc1" == res[0].id + + # Throw an exception when trying to add an alias that already exists + with pytest.raises(Exception): + index2.aliasadd("myalias") + + # update the alias and ensure we get doc2 + index2.aliasupdate("myalias") + alias_client2 = getClient(client).ft("myalias") + res = 
sorted(alias_client2.search("*").docs, key=lambda x: x.id) + assert "doc1" == res[0].id + else: + res = sorted(alias_client.search("*")["results"], key=lambda x: x["id"]) + assert "doc1" == res[0]["id"] + + # Throw an exception when trying to add an alias that already exists + with pytest.raises(Exception): + index2.aliasadd("myalias") + + # update the alias and ensure we get doc2 + index2.aliasupdate("myalias") + alias_client2 = getClient(client).ft("myalias") + res = sorted(alias_client2.search("*")["results"], key=lambda x: x["id"]) + assert "doc1" == res[0]["id"] # delete the alias and expect an error if we try to query again index2.aliasdel("myalias") @@ -571,8 +777,12 @@ def test_textfield_sortable_nostem(client): # Now get the index info to confirm its contents response = client.ft().info() - assert "SORTABLE" in response["attributes"][0] - assert "NOSTEM" in response["attributes"][0] + if is_resp2_connection(client): + assert "SORTABLE" in response["attributes"][0] + assert "NOSTEM" in response["attributes"][0] + else: + assert "SORTABLE" in response["attributes"][0]["flags"] + assert "NOSTEM" in response["attributes"][0]["flags"] @pytest.mark.redismod @@ -593,7 +803,10 @@ def test_alter_schema_add(client): # Ensure we find the result searching on the added body field res = client.ft().search(q) - assert 1 == res.total + if is_resp2_connection(client): + assert 1 == res.total + else: + assert 1 == res["total_results"] @pytest.mark.redismod @@ -606,33 +819,64 @@ def test_spell_check(client): client.hset("doc2", mapping={"f1": "very important", "f2": "lorem ipsum"}) waitForIndex(client, getattr(client.ft(), "index_name", "idx")) - # test spellcheck - res = client.ft().spellcheck("impornant") - assert "important" == res["impornant"][0]["suggestion"] - - res = client.ft().spellcheck("contnt") - assert "content" == res["contnt"][0]["suggestion"] - - # test spellcheck with Levenshtein distance - res = client.ft().spellcheck("vlis") - assert res == {} - res = client.ft().spellcheck("vlis", distance=2) - assert "valid" == res["vlis"][0]["suggestion"] - - # test spellcheck include - client.ft().dict_add("dict", "lore", "lorem", "lorm") - res = client.ft().spellcheck("lorm", include="dict") - assert len(res["lorm"]) == 3 - assert ( - res["lorm"][0]["suggestion"], - res["lorm"][1]["suggestion"], - res["lorm"][2]["suggestion"], - ) == ("lorem", "lore", "lorm") - assert (res["lorm"][0]["score"], res["lorm"][1]["score"]) == ("0.5", "0") - - # test spellcheck exclude - res = client.ft().spellcheck("lorm", exclude="dict") - assert res == {} + if is_resp2_connection(client): + + # test spellcheck + res = client.ft().spellcheck("impornant") + assert "important" == res["impornant"][0]["suggestion"] + + res = client.ft().spellcheck("contnt") + assert "content" == res["contnt"][0]["suggestion"] + + # test spellcheck with Levenshtein distance + res = client.ft().spellcheck("vlis") + assert res == {} + res = client.ft().spellcheck("vlis", distance=2) + assert "valid" == res["vlis"][0]["suggestion"] + + # test spellcheck include + client.ft().dict_add("dict", "lore", "lorem", "lorm") + res = client.ft().spellcheck("lorm", include="dict") + assert len(res["lorm"]) == 3 + assert ( + res["lorm"][0]["suggestion"], + res["lorm"][1]["suggestion"], + res["lorm"][2]["suggestion"], + ) == ("lorem", "lore", "lorm") + assert (res["lorm"][0]["score"], res["lorm"][1]["score"]) == ("0.5", "0") + + # test spellcheck exclude + res = client.ft().spellcheck("lorm", exclude="dict") + assert res == {} + else: + # test 
spellcheck + res = client.ft().spellcheck("impornant") + assert "important" in res["results"]["impornant"][0].keys() + + res = client.ft().spellcheck("contnt") + assert "content" in res["results"]["contnt"][0].keys() + + # test spellcheck with Levenshtein distance + res = client.ft().spellcheck("vlis") + assert res == {"results": {"vlis": []}} + res = client.ft().spellcheck("vlis", distance=2) + assert "valid" in res["results"]["vlis"][0].keys() + + # test spellcheck include + client.ft().dict_add("dict", "lore", "lorem", "lorm") + res = client.ft().spellcheck("lorm", include="dict") + assert len(res["results"]["lorm"]) == 3 + assert "lorem" in res["results"]["lorm"][0].keys() + assert "lore" in res["results"]["lorm"][1].keys() + assert "lorm" in res["results"]["lorm"][2].keys() + assert ( + res["results"]["lorm"][0]["lorem"], + res["results"]["lorm"][1]["lore"], + ) == (0.5, 0) + + # test spellcheck exclude + res = client.ft().spellcheck("lorm", exclude="dict") + assert res == {"results": {}} @pytest.mark.redismod @@ -648,7 +892,7 @@ def test_dict_operations(client): # Dump dict and inspect content res = client.ft().dict_dump("custom_dict") - assert ["item1", "item3"] == res + assert_resp_response(client, res, ["item1", "item3"], {"item1", "item3"}) # Remove rest of the items before reload client.ft().dict_del("custom_dict", *res) @@ -661,8 +905,12 @@ def test_phonetic_matcher(client): client.hset("doc2", mapping={"name": "John"}) res = client.ft().search(Query("Jon")) - assert 1 == len(res.docs) - assert "Jon" == res.docs[0].name + if is_resp2_connection(client): + assert 1 == len(res.docs) + assert "Jon" == res.docs[0].name + else: + assert 1 == res["total_results"] + assert "Jon" == res["results"][0]["extra_attributes"]["name"] # Drop and create index with phonetic matcher client.flushdb() @@ -672,8 +920,14 @@ def test_phonetic_matcher(client): client.hset("doc2", mapping={"name": "John"}) res = client.ft().search(Query("Jon")) - assert 2 == len(res.docs) - assert ["John", "Jon"] == sorted(d.name for d in res.docs) + if is_resp2_connection(client): + assert 2 == len(res.docs) + assert ["John", "Jon"] == sorted(d.name for d in res.docs) + else: + assert 2 == res["total_results"] + assert ["John", "Jon"] == sorted( + d["extra_attributes"]["name"] for d in res["results"] + ) @pytest.mark.redismod @@ -692,20 +946,36 @@ def test_scorer(client): ) # default scorer is TFIDF - res = client.ft().search(Query("quick").with_scores()) - assert 1.0 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("TFIDF").with_scores()) - assert 1.0 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("TFIDF.DOCNORM").with_scores()) - assert 0.1111111111111111 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("BM25").with_scores()) - assert 0.17699114465425977 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("DISMAX").with_scores()) - assert 2.0 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("DOCSCORE").with_scores()) - assert 1.0 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("HAMMING").with_scores()) - assert 0.0 == res.docs[0].score + if is_resp2_connection(client): + res = client.ft().search(Query("quick").with_scores()) + assert 1.0 == res.docs[0].score + res = client.ft().search(Query("quick").scorer("TFIDF").with_scores()) + assert 1.0 == res.docs[0].score + res = client.ft().search(Query("quick").scorer("TFIDF.DOCNORM").with_scores()) + assert 0.1111111111111111 == res.docs[0].score + 
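+        # (editor's note; the arithmetic is inferred, not stated in the
+        # original change) TFIDF.DOCNORM divides the TFIDF score by document
+        # length, which is why the single-hit score above comes out at
+        # 1 / 9 = 0.111... for a nine-token document.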
res = client.ft().search(Query("quick").scorer("BM25").with_scores()) + assert 0.17699114465425977 == res.docs[0].score + res = client.ft().search(Query("quick").scorer("DISMAX").with_scores()) + assert 2.0 == res.docs[0].score + res = client.ft().search(Query("quick").scorer("DOCSCORE").with_scores()) + assert 1.0 == res.docs[0].score + res = client.ft().search(Query("quick").scorer("HAMMING").with_scores()) + assert 0.0 == res.docs[0].score + else: + res = client.ft().search(Query("quick").with_scores()) + assert 1.0 == res["results"][0]["score"] + res = client.ft().search(Query("quick").scorer("TFIDF").with_scores()) + assert 1.0 == res["results"][0]["score"] + res = client.ft().search(Query("quick").scorer("TFIDF.DOCNORM").with_scores()) + assert 0.1111111111111111 == res["results"][0]["score"] + res = client.ft().search(Query("quick").scorer("BM25").with_scores()) + assert 0.17699114465425977 == res["results"][0]["score"] + res = client.ft().search(Query("quick").scorer("DISMAX").with_scores()) + assert 2.0 == res["results"][0]["score"] + res = client.ft().search(Query("quick").scorer("DOCSCORE").with_scores()) + assert 1.0 == res["results"][0]["score"] + res = client.ft().search(Query("quick").scorer("HAMMING").with_scores()) + assert 0.0 == res["results"][0]["score"] @pytest.mark.redismod @@ -786,101 +1056,212 @@ def test_aggregations_groupby(client): }, ) - req = aggregations.AggregateRequest("redis").group_by("@parent", reducers.count()) + if is_resp2_connection(client): + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.count() + ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "3" + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[3] == "3" - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.count_distinct("@title") - ) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.count_distinct("@title") + ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "3" + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[3] == "3" - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.count_distinctish("@title") - ) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.count_distinctish("@title") + ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "3" + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[3] == "3" - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.sum("@random_num") - ) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.sum("@random_num") + ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "21" # 10+8+3 + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[3] == "21" # 10+8+3 - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.min("@random_num") - ) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.min("@random_num") + ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "3" # min(10,8,3) + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[3] == "3" # min(10,8,3) - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.max("@random_num") - ) + req = 
aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.max("@random_num") + ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "10" # max(10,8,3) + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[3] == "10" # max(10,8,3) - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.avg("@random_num") - ) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.avg("@random_num") + ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - index = res.index("__generated_aliasavgrandom_num") - assert res[index + 1] == "7" # (10+3+8)/3 + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + index = res.index("__generated_aliasavgrandom_num") + assert res[index + 1] == "7" # (10+3+8)/3 - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.stddev("random_num") - ) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.stddev("random_num") + ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "3.60555127546" + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[3] == "3.60555127546" - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.quantile("@random_num", 0.5) - ) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.quantile("@random_num", 0.5) + ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[3] == "8" # median of 3,8,10 + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[3] == "8" # median of 3,8,10 - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.tolist("@title") - ) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.tolist("@title") + ) - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert set(res[3]) == {"RediSearch", "RedisAI", "RedisJson"} + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert set(res[3]) == {"RediSearch", "RedisAI", "RedisJson"} - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.first_value("@title").alias("first") - ) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.first_value("@title").alias("first") + ) - res = client.ft().aggregate(req).rows[0] - assert res == ["parent", "redis", "first", "RediSearch"] + res = client.ft().aggregate(req).rows[0] + assert res == ["parent", "redis", "first", "RediSearch"] - req = aggregations.AggregateRequest("redis").group_by( - "@parent", reducers.random_sample("@title", 2).alias("random") - ) + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.random_sample("@title", 2).alias("random") + ) + + res = client.ft().aggregate(req).rows[0] + assert res[1] == "redis" + assert res[2] == "random" + assert len(res[3]) == 2 + assert res[3][0] in ["RediSearch", "RedisAI", "RedisJson"] + else: + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.count() + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliascount"] == "3" + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.count_distinct("@title") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + 
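+        # (editor's note, not part of the original change) under RESP3 each
+        # FT.AGGREGATE row arrives as a dict, shaped roughly like
+        #     {"results": [{"extra_attributes": {"parent": ..., ...}}], ...}
+        # with every reducer stored under its generated alias, in contrast
+        # to the flat RESP2 rows asserted in the branch above.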
assert res["extra_attributes"]["__generated_aliascount_distincttitle"] == "3" + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.count_distinctish("@title") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliascount_distinctishtitle"] == "3" + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.sum("@random_num") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliassumrandom_num"] == "21" + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.min("@random_num") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliasminrandom_num"] == "3" - res = client.ft().aggregate(req).rows[0] - assert res[1] == "redis" - assert res[2] == "random" - assert len(res[3]) == 2 - assert res[3][0] in ["RediSearch", "RedisAI", "RedisJson"] + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.max("@random_num") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliasmaxrandom_num"] == "10" + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.avg("@random_num") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliasavgrandom_num"] == "7" + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.stddev("random_num") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert ( + res["extra_attributes"]["__generated_aliasstddevrandom_num"] + == "3.60555127546" + ) + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.quantile("@random_num", 0.5) + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert res["extra_attributes"]["__generated_aliasquantilerandom_num,0.5"] == "8" + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.tolist("@title") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert set(res["extra_attributes"]["__generated_aliastolisttitle"]) == { + "RediSearch", + "RedisAI", + "RedisJson", + } + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.first_value("@title").alias("first") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"] == {"parent": "redis", "first": "RediSearch"} + + req = aggregations.AggregateRequest("redis").group_by( + "@parent", reducers.random_sample("@title", 2).alias("random") + ) + + res = client.ft().aggregate(req)["results"][0] + assert res["extra_attributes"]["parent"] == "redis" + assert "random" in res["extra_attributes"].keys() + assert len(res["extra_attributes"]["random"]) == 2 + assert res["extra_attributes"]["random"][0] in [ + "RediSearch", + "RedisAI", + "RedisJson", + ] @pytest.mark.redismod @@ -890,30 +1271,56 @@ def test_aggregations_sort_by_and_limit(client): client.ft().client.hset("doc1", mapping={"t1": "a", "t2": "b"}) client.ft().client.hset("doc2", mapping={"t1": "b", "t2": "a"}) - # 
test sort_by using SortDirection - req = aggregations.AggregateRequest("*").sort_by( - aggregations.Asc("@t2"), aggregations.Desc("@t1") - ) - res = client.ft().aggregate(req) - assert res.rows[0] == ["t2", "a", "t1", "b"] - assert res.rows[1] == ["t2", "b", "t1", "a"] + if is_resp2_connection(client): + # test sort_by using SortDirection + req = aggregations.AggregateRequest("*").sort_by( + aggregations.Asc("@t2"), aggregations.Desc("@t1") + ) + res = client.ft().aggregate(req) + assert res.rows[0] == ["t2", "a", "t1", "b"] + assert res.rows[1] == ["t2", "b", "t1", "a"] - # test sort_by without SortDirection - req = aggregations.AggregateRequest("*").sort_by("@t1") - res = client.ft().aggregate(req) - assert res.rows[0] == ["t1", "a"] - assert res.rows[1] == ["t1", "b"] + # test sort_by without SortDirection + req = aggregations.AggregateRequest("*").sort_by("@t1") + res = client.ft().aggregate(req) + assert res.rows[0] == ["t1", "a"] + assert res.rows[1] == ["t1", "b"] - # test sort_by with max - req = aggregations.AggregateRequest("*").sort_by("@t1", max=1) - res = client.ft().aggregate(req) - assert len(res.rows) == 1 + # test sort_by with max + req = aggregations.AggregateRequest("*").sort_by("@t1", max=1) + res = client.ft().aggregate(req) + assert len(res.rows) == 1 - # test limit - req = aggregations.AggregateRequest("*").sort_by("@t1").limit(1, 1) - res = client.ft().aggregate(req) - assert len(res.rows) == 1 - assert res.rows[0] == ["t1", "b"] + # test limit + req = aggregations.AggregateRequest("*").sort_by("@t1").limit(1, 1) + res = client.ft().aggregate(req) + assert len(res.rows) == 1 + assert res.rows[0] == ["t1", "b"] + else: + # test sort_by using SortDirection + req = aggregations.AggregateRequest("*").sort_by( + aggregations.Asc("@t2"), aggregations.Desc("@t1") + ) + res = client.ft().aggregate(req)["results"] + assert res[0]["extra_attributes"] == {"t2": "a", "t1": "b"} + assert res[1]["extra_attributes"] == {"t2": "b", "t1": "a"} + + # test sort_by without SortDirection + req = aggregations.AggregateRequest("*").sort_by("@t1") + res = client.ft().aggregate(req)["results"] + assert res[0]["extra_attributes"] == {"t1": "a"} + assert res[1]["extra_attributes"] == {"t1": "b"} + + # test sort_by with max + req = aggregations.AggregateRequest("*").sort_by("@t1", max=1) + res = client.ft().aggregate(req) + assert len(res["results"]) == 1 + + # test limit + req = aggregations.AggregateRequest("*").sort_by("@t1").limit(1, 1) + res = client.ft().aggregate(req) + assert len(res["results"]) == 1 + assert res["results"][0]["extra_attributes"] == {"t1": "b"} @pytest.mark.redismod @@ -922,20 +1329,36 @@ def test_aggregations_load(client): client.ft().client.hset("doc1", mapping={"t1": "hello", "t2": "world"}) - # load t1 - req = aggregations.AggregateRequest("*").load("t1") - res = client.ft().aggregate(req) - assert res.rows[0] == ["t1", "hello"] + if is_resp2_connection(client): + # load t1 + req = aggregations.AggregateRequest("*").load("t1") + res = client.ft().aggregate(req) + assert res.rows[0] == ["t1", "hello"] - # load t2 - req = aggregations.AggregateRequest("*").load("t2") - res = client.ft().aggregate(req) - assert res.rows[0] == ["t2", "world"] + # load t2 + req = aggregations.AggregateRequest("*").load("t2") + res = client.ft().aggregate(req) + assert res.rows[0] == ["t2", "world"] - # load all - req = aggregations.AggregateRequest("*").load() - res = client.ft().aggregate(req) - assert res.rows[0] == ["t1", "hello", "t2", "world"] + # load all + req = 
aggregations.AggregateRequest("*").load() + res = client.ft().aggregate(req) + assert res.rows[0] == ["t1", "hello", "t2", "world"] + else: + # load t1 + req = aggregations.AggregateRequest("*").load("t1") + res = client.ft().aggregate(req) + assert res["results"][0]["extra_attributes"] == {"t1": "hello"} + + # load t2 + req = aggregations.AggregateRequest("*").load("t2") + res = client.ft().aggregate(req) + assert res["results"][0]["extra_attributes"] == {"t2": "world"} + + # load all + req = aggregations.AggregateRequest("*").load() + res = client.ft().aggregate(req) + assert res["results"][0]["extra_attributes"] == {"t1": "hello", "t2": "world"} @pytest.mark.redismod @@ -960,8 +1383,17 @@ def test_aggregations_apply(client): CreatedDateTimeUTC="@CreatedDateTimeUTC * 10" ) res = client.ft().aggregate(req) - res_set = set([res.rows[0][1], res.rows[1][1]]) - assert res_set == set(["6373878785249699840", "6373878758592700416"]) + if is_resp2_connection(client): + res_set = set([res.rows[0][1], res.rows[1][1]]) + assert res_set == set(["6373878785249699840", "6373878758592700416"]) + else: + res_set = set( + [ + res["results"][0]["extra_attributes"]["CreatedDateTimeUTC"], + res["results"][1]["extra_attributes"]["CreatedDateTimeUTC"], + ], + ) + assert res_set == set(["6373878785249699840", "6373878758592700416"]) @pytest.mark.redismod @@ -980,19 +1412,34 @@ def test_aggregations_filter(client): .dialect(dialect) ) res = client.ft().aggregate(req) - assert len(res.rows) == 1 - assert res.rows[0] == ["name", "foo", "age", "19"] - - req = ( - aggregations.AggregateRequest("*") - .filter("@age > 15") - .sort_by("@age") - .dialect(dialect) - ) - res = client.ft().aggregate(req) - assert len(res.rows) == 2 - assert res.rows[0] == ["age", "19"] - assert res.rows[1] == ["age", "25"] + if is_resp2_connection(client): + assert len(res.rows) == 1 + assert res.rows[0] == ["name", "foo", "age", "19"] + + req = ( + aggregations.AggregateRequest("*") + .filter("@age > 15") + .sort_by("@age") + .dialect(dialect) + ) + res = client.ft().aggregate(req) + assert len(res.rows) == 2 + assert res.rows[0] == ["age", "19"] + assert res.rows[1] == ["age", "25"] + else: + assert len(res["results"]) == 1 + assert res["results"][0]["extra_attributes"] == {"name": "foo", "age": "19"} + + req = ( + aggregations.AggregateRequest("*") + .filter("@age > 15") + .sort_by("@age") + .dialect(dialect) + ) + res = client.ft().aggregate(req) + assert len(res["results"]) == 2 + assert res["results"][0]["extra_attributes"] == {"age": "19"} + assert res["results"][1]["extra_attributes"] == {"age": "25"} @pytest.mark.redismod @@ -1058,7 +1505,11 @@ def test_skip_initial_scan(client): q = Query("@foo:bar") client.ft().create_index((TextField("foo"),), skip_initial_scan=True) - assert 0 == client.ft().search(q).total + res = client.ft().search(q) + if is_resp2_connection(client): + assert res.total == 0 + else: + assert res["total_results"] == 0 @pytest.mark.redismod @@ -1146,10 +1597,15 @@ def test_create_client_definition_json(client): client.json().set("king:2", Path.root_path(), {"name": "james"}) res = client.ft().search("henry") - assert res.docs[0].id == "king:1" - assert res.docs[0].payload is None - assert res.docs[0].json == '{"name":"henry"}' - assert res.total == 1 + if is_resp2_connection(client): + assert res.docs[0].id == "king:1" + assert res.docs[0].payload is None + assert res.docs[0].json == '{"name":"henry"}' + assert res.total == 1 + else: + assert res["results"][0]["id"] == "king:1" + assert 
res["results"][0]["extra_attributes"]["$"] == '{"name":"henry"}' + assert res["total_results"] == 1 @pytest.mark.redismod @@ -1167,11 +1623,17 @@ def test_fields_as_name(client): res = client.json().set("doc:1", Path.root_path(), {"name": "Jon", "age": 25}) assert res - total = client.ft().search(Query("Jon").return_fields("name", "just_a_number")).docs - assert 1 == len(total) - assert "doc:1" == total[0].id - assert "Jon" == total[0].name - assert "25" == total[0].just_a_number + res = client.ft().search(Query("Jon").return_fields("name", "just_a_number")) + if is_resp2_connection(client): + assert 1 == len(res.docs) + assert "doc:1" == res.docs[0].id + assert "Jon" == res.docs[0].name + assert "25" == res.docs[0].just_a_number + else: + assert 1 == len(res["results"]) + assert "doc:1" == res["results"][0]["id"] + assert "Jon" == res["results"][0]["extra_attributes"]["name"] + assert "25" == res["results"][0]["extra_attributes"]["just_a_number"] @pytest.mark.redismod @@ -1182,11 +1644,16 @@ def test_casesensitive(client): client.ft().client.hset("1", "t", "HELLO") client.ft().client.hset("2", "t", "hello") - res = client.ft().search("@t:{HELLO}").docs + res = client.ft().search("@t:{HELLO}") - assert 2 == len(res) - assert "1" == res[0].id - assert "2" == res[1].id + if is_resp2_connection(client): + assert 2 == len(res.docs) + assert "1" == res.docs[0].id + assert "2" == res.docs[1].id + else: + assert 2 == len(res["results"]) + assert "1" == res["results"][0]["id"] + assert "2" == res["results"][1]["id"] # create casesensitive index client.ft().dropindex() @@ -1194,9 +1661,13 @@ def test_casesensitive(client): client.ft().create_index(SCHEMA) waitForIndex(client, getattr(client.ft(), "index_name", "idx")) - res = client.ft().search("@t:{HELLO}").docs - assert 1 == len(res) - assert "1" == res[0].id + res = client.ft().search("@t:{HELLO}") + if is_resp2_connection(client): + assert 1 == len(res.docs) + assert "1" == res.docs[0].id + else: + assert 1 == len(res["results"]) + assert "1" == res["results"][0]["id"] @pytest.mark.redismod @@ -1215,15 +1686,26 @@ def test_search_return_fields(client): client.ft().create_index(SCHEMA, definition=definition) waitForIndex(client, getattr(client.ft(), "index_name", "idx")) - total = client.ft().search(Query("*").return_field("$.t", as_field="txt")).docs - assert 1 == len(total) - assert "doc:1" == total[0].id - assert "riceratops" == total[0].txt + if is_resp2_connection(client): + total = client.ft().search(Query("*").return_field("$.t", as_field="txt")).docs + assert 1 == len(total) + assert "doc:1" == total[0].id + assert "riceratops" == total[0].txt - total = client.ft().search(Query("*").return_field("$.t2", as_field="txt")).docs - assert 1 == len(total) - assert "doc:1" == total[0].id - assert "telmatosaurus" == total[0].txt + total = client.ft().search(Query("*").return_field("$.t2", as_field="txt")).docs + assert 1 == len(total) + assert "doc:1" == total[0].id + assert "telmatosaurus" == total[0].txt + else: + total = client.ft().search(Query("*").return_field("$.t", as_field="txt")) + assert 1 == len(total["results"]) + assert "doc:1" == total["results"][0]["id"] + assert "riceratops" == total["results"][0]["extra_attributes"]["txt"] + + total = client.ft().search(Query("*").return_field("$.t2", as_field="txt")) + assert 1 == len(total["results"]) + assert "doc:1" == total["results"][0]["id"] + assert "telmatosaurus" == total["results"][0]["extra_attributes"]["txt"] @pytest.mark.redismod @@ -1240,9 +1722,14 @@ def test_synupdate(client): 
client.hset("doc2", mapping={"title": "he is another baby", "body": "another test"}) res = client.ft().search(Query("child").expander("SYNONYM")) - assert res.docs[0].id == "doc2" - assert res.docs[0].title == "he is another baby" - assert res.docs[0].body == "another test" + if is_resp2_connection(client): + assert res.docs[0].id == "doc2" + assert res.docs[0].title == "he is another baby" + assert res.docs[0].body == "another test" + else: + assert res["results"][0]["id"] == "doc2" + assert res["results"][0]["extra_attributes"]["title"] == "he is another baby" + assert res["results"][0]["extra_attributes"]["body"] == "another test" @pytest.mark.redismod @@ -1282,15 +1769,28 @@ def test_create_json_with_alias(client): client.json().set("king:1", Path.root_path(), {"name": "henry", "num": 42}) client.json().set("king:2", Path.root_path(), {"name": "james", "num": 3.14}) - res = client.ft().search("@name:henry") - assert res.docs[0].id == "king:1" - assert res.docs[0].json == '{"name":"henry","num":42}' - assert res.total == 1 - - res = client.ft().search("@num:[0 10]") - assert res.docs[0].id == "king:2" - assert res.docs[0].json == '{"name":"james","num":3.14}' - assert res.total == 1 + if is_resp2_connection(client): + res = client.ft().search("@name:henry") + assert res.docs[0].id == "king:1" + assert res.docs[0].json == '{"name":"henry","num":42}' + assert res.total == 1 + + res = client.ft().search("@num:[0 10]") + assert res.docs[0].id == "king:2" + assert res.docs[0].json == '{"name":"james","num":3.14}' + assert res.total == 1 + else: + res = client.ft().search("@name:henry") + assert res["results"][0]["id"] == "king:1" + assert res["results"][0]["extra_attributes"]["$"] == '{"name":"henry","num":42}' + assert res["total_results"] == 1 + + res = client.ft().search("@num:[0 10]") + assert res["results"][0]["id"] == "king:2" + assert ( + res["results"][0]["extra_attributes"]["$"] == '{"name":"james","num":3.14}' + ) + assert res["total_results"] == 1 # Tests returns an error if path contain special characters (user should # use an alias) @@ -1314,15 +1814,32 @@ def test_json_with_multipath(client): "king:1", Path.root_path(), {"name": "henry", "country": {"name": "england"}} ) - res = client.ft().search("@name:{henry}") - assert res.docs[0].id == "king:1" - assert res.docs[0].json == '{"name":"henry","country":{"name":"england"}}' - assert res.total == 1 + if is_resp2_connection(client): + res = client.ft().search("@name:{henry}") + assert res.docs[0].id == "king:1" + assert res.docs[0].json == '{"name":"henry","country":{"name":"england"}}' + assert res.total == 1 + + res = client.ft().search("@name:{england}") + assert res.docs[0].id == "king:1" + assert res.docs[0].json == '{"name":"henry","country":{"name":"england"}}' + assert res.total == 1 + else: + res = client.ft().search("@name:{henry}") + assert res["results"][0]["id"] == "king:1" + assert ( + res["results"][0]["extra_attributes"]["$"] + == '{"name":"henry","country":{"name":"england"}}' + ) + assert res["total_results"] == 1 - res = client.ft().search("@name:{england}") - assert res.docs[0].id == "king:1" - assert res.docs[0].json == '{"name":"henry","country":{"name":"england"}}' - assert res.total == 1 + res = client.ft().search("@name:{england}") + assert res["results"][0]["id"] == "king:1" + assert ( + res["results"][0]["extra_attributes"]["$"] + == '{"name":"henry","country":{"name":"england"}}' + ) + assert res["total_results"] == 1 @pytest.mark.redismod @@ -1339,21 +1856,40 @@ def test_json_with_jsonpath(client): 
client.json().set("doc:1", Path.root_path(), {"prod:name": "RediSearch"}) - # query for a supported field succeeds - res = client.ft().search(Query("@name:RediSearch")) - assert res.total == 1 - assert res.docs[0].id == "doc:1" - assert res.docs[0].json == '{"prod:name":"RediSearch"}' + if is_resp2_connection(client): + # query for a supported field succeeds + res = client.ft().search(Query("@name:RediSearch")) + assert res.total == 1 + assert res.docs[0].id == "doc:1" + assert res.docs[0].json == '{"prod:name":"RediSearch"}' + + # query for an unsupported field + res = client.ft().search("@name_unsupported:RediSearch") + assert res.total == 1 + + # return of a supported field succeeds + res = client.ft().search(Query("@name:RediSearch").return_field("name")) + assert res.total == 1 + assert res.docs[0].id == "doc:1" + assert res.docs[0].name == "RediSearch" + else: + # query for a supported field succeeds + res = client.ft().search(Query("@name:RediSearch")) + assert res["total_results"] == 1 + assert res["results"][0]["id"] == "doc:1" + assert ( + res["results"][0]["extra_attributes"]["$"] == '{"prod:name":"RediSearch"}' + ) - # query for an unsupported field - res = client.ft().search("@name_unsupported:RediSearch") - assert res.total == 1 + # query for an unsupported field + res = client.ft().search("@name_unsupported:RediSearch") + assert res["total_results"] == 1 - # return of a supported field succeeds - res = client.ft().search(Query("@name:RediSearch").return_field("name")) - assert res.total == 1 - assert res.docs[0].id == "doc:1" - assert res.docs[0].name == "RediSearch" + # return of a supported field succeeds + res = client.ft().search(Query("@name:RediSearch").return_field("name")) + assert res["total_results"] == 1 + assert res["results"][0]["id"] == "doc:1" + assert res["results"][0]["extra_attributes"]["name"] == "RediSearch" @pytest.mark.redismod @@ -1366,24 +1902,43 @@ def test_profile(client): # check using Query q = Query("hello|world").no_content() - res, det = client.ft().profile(q) - assert det["Iterators profile"]["Counter"] == 2.0 - assert len(det["Iterators profile"]["Child iterators"]) == 2 - assert det["Iterators profile"]["Type"] == "UNION" - assert det["Parsing time"] < 0.5 - assert len(res.docs) == 2 # check also the search result - - # check using AggregateRequest - req = ( - aggregations.AggregateRequest("*") - .load("t") - .apply(prefix="startswith(@t, 'hel')") - ) - res, det = client.ft().profile(req) - assert det["Iterators profile"]["Counter"] == 2.0 - assert det["Iterators profile"]["Type"] == "WILDCARD" - assert isinstance(det["Parsing time"], float) - assert len(res.rows) == 2 # check also the search result + if is_resp2_connection(client): + res, det = client.ft().profile(q) + assert det["Iterators profile"]["Counter"] == 2.0 + assert len(det["Iterators profile"]["Child iterators"]) == 2 + assert det["Iterators profile"]["Type"] == "UNION" + assert det["Parsing time"] < 0.5 + assert len(res.docs) == 2 # check also the search result + + # check using AggregateRequest + req = ( + aggregations.AggregateRequest("*") + .load("t") + .apply(prefix="startswith(@t, 'hel')") + ) + res, det = client.ft().profile(req) + assert det["Iterators profile"]["Counter"] == 2 + assert det["Iterators profile"]["Type"] == "WILDCARD" + assert isinstance(det["Parsing time"], float) + assert len(res.rows) == 2 # check also the search result + else: + res = client.ft().profile(q) + assert res["profile"]["Iterators profile"][0]["Counter"] == 2.0 + assert 
res["profile"]["Iterators profile"][0]["Type"] == "UNION" + assert res["profile"]["Parsing time"] < 0.5 + assert len(res["results"]) == 2 # check also the search result + + # check using AggregateRequest + req = ( + aggregations.AggregateRequest("*") + .load("t") + .apply(prefix="startswith(@t, 'hel')") + ) + res = client.ft().profile(req) + assert res["profile"]["Iterators profile"][0]["Counter"] == 2 + assert res["profile"]["Iterators profile"][0]["Type"] == "WILDCARD" + assert isinstance(res["profile"]["Parsing time"], float) + assert len(res["results"]) == 2 # check also the search result @pytest.mark.redismod @@ -1396,134 +1951,174 @@ def test_profile_limited(client): client.ft().client.hset("4", "t", "helowa") q = Query("%hell% hel*") - res, det = client.ft().profile(q, limited=True) - assert ( - det["Iterators profile"]["Child iterators"][0]["Child iterators"] - == "The number of iterators in the union is 3" - ) - assert ( - det["Iterators profile"]["Child iterators"][1]["Child iterators"] - == "The number of iterators in the union is 4" - ) - assert det["Iterators profile"]["Type"] == "INTERSECT" - assert len(res.docs) == 3 # check also the search result + if is_resp2_connection(client): + res, det = client.ft().profile(q, limited=True) + assert ( + det["Iterators profile"]["Child iterators"][0]["Child iterators"] + == "The number of iterators in the union is 3" + ) + assert ( + det["Iterators profile"]["Child iterators"][1]["Child iterators"] + == "The number of iterators in the union is 4" + ) + assert det["Iterators profile"]["Type"] == "INTERSECT" + assert len(res.docs) == 3 # check also the search result + else: + res = client.ft().profile(q, limited=True) + iterators_profile = res["profile"]["Iterators profile"] + assert ( + iterators_profile[0]["Child iterators"][0]["Child iterators"] + == "The number of iterators in the union is 3" + ) + assert ( + iterators_profile[0]["Child iterators"][1]["Child iterators"] + == "The number of iterators in the union is 4" + ) + assert iterators_profile[0]["Type"] == "INTERSECT" + assert len(res["results"]) == 3 # check also the search result @pytest.mark.redismod @skip_ifmodversion_lt("2.4.3", "search") -def test_profile_query_params(modclient: redis.Redis): - modclient.flushdb() - modclient.ft().create_index( +def test_profile_query_params(client): + client.ft().create_index( ( VectorField( "v", "HNSW", {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "L2"} ), ) ) - modclient.hset("a", "v", "aaaaaaaa") - modclient.hset("b", "v", "aaaabaaa") - modclient.hset("c", "v", "aaaaabaa") + client.hset("a", "v", "aaaaaaaa") + client.hset("b", "v", "aaaabaaa") + client.hset("c", "v", "aaaaabaa") query = "*=>[KNN 2 @v $vec]" q = Query(query).return_field("__v_score").sort_by("__v_score", True).dialect(2) - res, det = modclient.ft().profile(q, query_params={"vec": "aaaaaaaa"}) - assert det["Iterators profile"]["Counter"] == 2.0 - assert det["Iterators profile"]["Type"] == "VECTOR" - assert res.total == 2 - assert "a" == res.docs[0].id - assert "0" == res.docs[0].__getattribute__("__v_score") + if is_resp2_connection(client): + res, det = client.ft().profile(q, query_params={"vec": "aaaaaaaa"}) + assert det["Iterators profile"]["Counter"] == 2.0 + assert det["Iterators profile"]["Type"] == "VECTOR" + assert res.total == 2 + assert "a" == res.docs[0].id + assert "0" == res.docs[0].__getattribute__("__v_score") + else: + res = client.ft().profile(q, query_params={"vec": "aaaaaaaa"}) + assert res["profile"]["Iterators profile"][0]["Counter"] == 2 + 
assert res["profile"]["Iterators profile"][0]["Type"] == "VECTOR" + assert res["total_results"] == 2 + assert "a" == res["results"][0]["id"] + assert "0" == res["results"][0]["extra_attributes"]["__v_score"] @pytest.mark.redismod @skip_ifmodversion_lt("2.4.3", "search") -def test_vector_field(modclient): - modclient.flushdb() - modclient.ft().create_index( +def test_vector_field(client): + client.flushdb() + client.ft().create_index( ( VectorField( "v", "HNSW", {"TYPE": "FLOAT32", "DIM": 2, "DISTANCE_METRIC": "L2"} ), ) ) - modclient.hset("a", "v", "aaaaaaaa") - modclient.hset("b", "v", "aaaabaaa") - modclient.hset("c", "v", "aaaaabaa") + client.hset("a", "v", "aaaaaaaa") + client.hset("b", "v", "aaaabaaa") + client.hset("c", "v", "aaaaabaa") query = "*=>[KNN 2 @v $vec]" q = Query(query).return_field("__v_score").sort_by("__v_score", True).dialect(2) - res = modclient.ft().search(q, query_params={"vec": "aaaaaaaa"}) + res = client.ft().search(q, query_params={"vec": "aaaaaaaa"}) - assert "a" == res.docs[0].id - assert "0" == res.docs[0].__getattribute__("__v_score") + if is_resp2_connection(client): + assert "a" == res.docs[0].id + assert "0" == res.docs[0].__getattribute__("__v_score") + else: + assert "a" == res["results"][0]["id"] + assert "0" == res["results"][0]["extra_attributes"]["__v_score"] @pytest.mark.redismod @skip_ifmodversion_lt("2.4.3", "search") -def test_vector_field_error(modclient): - modclient.flushdb() +def test_vector_field_error(r): + r.flushdb() # sortable tag with pytest.raises(Exception): - modclient.ft().create_index((VectorField("v", "HNSW", {}, sortable=True),)) + r.ft().create_index((VectorField("v", "HNSW", {}, sortable=True),)) # not supported algorithm with pytest.raises(Exception): - modclient.ft().create_index((VectorField("v", "SORT", {}),)) + r.ft().create_index((VectorField("v", "SORT", {}),)) @pytest.mark.redismod @skip_ifmodversion_lt("2.4.3", "search") -def test_text_params(modclient): - modclient.flushdb() - modclient.ft().create_index((TextField("name"),)) +def test_text_params(client): + client.flushdb() + client.ft().create_index((TextField("name"),)) - modclient.hset("doc1", mapping={"name": "Alice"}) - modclient.hset("doc2", mapping={"name": "Bob"}) - modclient.hset("doc3", mapping={"name": "Carol"}) + client.hset("doc1", mapping={"name": "Alice"}) + client.hset("doc2", mapping={"name": "Bob"}) + client.hset("doc3", mapping={"name": "Carol"}) params_dict = {"name1": "Alice", "name2": "Bob"} q = Query("@name:($name1 | $name2 )").dialect(2) - res = modclient.ft().search(q, query_params=params_dict) - assert 2 == res.total - assert "doc1" == res.docs[0].id - assert "doc2" == res.docs[1].id + res = client.ft().search(q, query_params=params_dict) + if is_resp2_connection(client): + assert 2 == res.total + assert "doc1" == res.docs[0].id + assert "doc2" == res.docs[1].id + else: + assert 2 == res["total_results"] + assert "doc1" == res["results"][0]["id"] + assert "doc2" == res["results"][1]["id"] @pytest.mark.redismod @skip_ifmodversion_lt("2.4.3", "search") -def test_numeric_params(modclient): - modclient.flushdb() - modclient.ft().create_index((NumericField("numval"),)) +def test_numeric_params(client): + client.flushdb() + client.ft().create_index((NumericField("numval"),)) - modclient.hset("doc1", mapping={"numval": 101}) - modclient.hset("doc2", mapping={"numval": 102}) - modclient.hset("doc3", mapping={"numval": 103}) + client.hset("doc1", mapping={"numval": 101}) + client.hset("doc2", mapping={"numval": 102}) + client.hset("doc3", 
mapping={"numval": 103}) params_dict = {"min": 101, "max": 102} q = Query("@numval:[$min $max]").dialect(2) - res = modclient.ft().search(q, query_params=params_dict) + res = client.ft().search(q, query_params=params_dict) - assert 2 == res.total - assert "doc1" == res.docs[0].id - assert "doc2" == res.docs[1].id + if is_resp2_connection(client): + assert 2 == res.total + assert "doc1" == res.docs[0].id + assert "doc2" == res.docs[1].id + else: + assert 2 == res["total_results"] + assert "doc1" == res["results"][0]["id"] + assert "doc2" == res["results"][1]["id"] @pytest.mark.redismod @skip_ifmodversion_lt("2.4.3", "search") -def test_geo_params(modclient): +def test_geo_params(client): - modclient.flushdb() - modclient.ft().create_index((GeoField("g"))) - modclient.hset("doc1", mapping={"g": "29.69465, 34.95126"}) - modclient.hset("doc2", mapping={"g": "29.69350, 34.94737"}) - modclient.hset("doc3", mapping={"g": "29.68746, 34.94882"}) + client.ft().create_index((GeoField("g"))) + client.hset("doc1", mapping={"g": "29.69465, 34.95126"}) + client.hset("doc2", mapping={"g": "29.69350, 34.94737"}) + client.hset("doc3", mapping={"g": "29.68746, 34.94882"}) params_dict = {"lat": "34.95126", "lon": "29.69465", "radius": 1000, "units": "km"} q = Query("@g:[$lon $lat $radius $units]").dialect(2) - res = modclient.ft().search(q, query_params=params_dict) - assert 3 == res.total - assert "doc1" == res.docs[0].id - assert "doc2" == res.docs[1].id - assert "doc3" == res.docs[2].id + res = client.ft().search(q, query_params=params_dict) + if is_resp2_connection(client): + assert 3 == res.total + assert "doc1" == res.docs[0].id + assert "doc2" == res.docs[1].id + assert "doc3" == res.docs[2].id + else: + assert 3 == res["total_results"] + assert "doc1" == res["results"][0]["id"] + assert "doc2" == res["results"][1]["id"] + assert "doc3" == res["results"][2]["id"] @pytest.mark.redismod @@ -1536,29 +2131,41 @@ def test_search_commands_in_pipeline(client): q = Query("foo bar").with_payloads() p.search(q) res = p.execute() - assert res[:3] == ["OK", True, True] - assert 2 == res[3][0] - assert "doc1" == res[3][1] - assert "doc2" == res[3][4] - assert res[3][5] is None - assert res[3][3] == res[3][6] == ["txt", "foo bar"] + if is_resp2_connection(client): + assert res[:3] == ["OK", True, True] + assert 2 == res[3][0] + assert "doc1" == res[3][1] + assert "doc2" == res[3][4] + assert res[3][5] is None + assert res[3][3] == res[3][6] == ["txt", "foo bar"] + else: + assert res[:3] == ["OK", True, True] + assert 2 == res[3]["total_results"] + assert "doc1" == res[3]["results"][0]["id"] + assert "doc2" == res[3]["results"][1]["id"] + assert res[3]["results"][0]["payload"] is None + assert ( + res[3]["results"][0]["extra_attributes"] + == res[3]["results"][1]["extra_attributes"] + == {"txt": "foo bar"} + ) @pytest.mark.redismod @pytest.mark.onlynoncluster @skip_ifmodversion_lt("2.4.3", "search") -def test_dialect_config(modclient: redis.Redis): - assert modclient.ft().config_get("DEFAULT_DIALECT") == {"DEFAULT_DIALECT": "1"} - assert modclient.ft().config_set("DEFAULT_DIALECT", 2) - assert modclient.ft().config_get("DEFAULT_DIALECT") == {"DEFAULT_DIALECT": "2"} +def test_dialect_config(client): + assert client.ft().config_get("DEFAULT_DIALECT") + client.ft().config_set("DEFAULT_DIALECT", 2) + assert client.ft().config_get("DEFAULT_DIALECT") == {"DEFAULT_DIALECT": "2"} with pytest.raises(redis.ResponseError): - modclient.ft().config_set("DEFAULT_DIALECT", 0) + client.ft().config_set("DEFAULT_DIALECT", 0) 
@pytest.mark.redismod @skip_ifmodversion_lt("2.4.3", "search") -def test_dialect(modclient: redis.Redis): - modclient.ft().create_index( +def test_dialect(client): + client.ft().create_index( ( TagField("title"), TextField("t1"), @@ -1569,68 +2176,94 @@ def test_dialect(modclient: redis.Redis): ), ) ) - modclient.hset("h", "t1", "hello") + client.hset("h", "t1", "hello") with pytest.raises(redis.ResponseError) as err: - modclient.ft().explain(Query("(*)").dialect(1)) + client.ft().explain(Query("(*)").dialect(1)) assert "Syntax error" in str(err) - assert "WILDCARD" in modclient.ft().explain(Query("(*)").dialect(2)) + assert "WILDCARD" in client.ft().explain(Query("(*)").dialect(2)) with pytest.raises(redis.ResponseError) as err: - modclient.ft().explain(Query("$hello").dialect(1)) + client.ft().explain(Query("$hello").dialect(1)) assert "Syntax error" in str(err) q = Query("$hello").dialect(2) expected = "UNION {\n hello\n +hello(expanded)\n}\n" - assert expected in modclient.ft().explain(q, query_params={"hello": "hello"}) + assert expected in client.ft().explain(q, query_params={"hello": "hello"}) expected = "NUMERIC {0.000000 <= @num <= 10.000000}\n" - assert expected in modclient.ft().explain(Query("@title:(@num:[0 10])").dialect(1)) + assert expected in client.ft().explain(Query("@title:(@num:[0 10])").dialect(1)) with pytest.raises(redis.ResponseError) as err: - modclient.ft().explain(Query("@title:(@num:[0 10])").dialect(2)) + client.ft().explain(Query("@title:(@num:[0 10])").dialect(2)) assert "Syntax error" in str(err) @pytest.mark.redismod -def test_expire_while_search(modclient: redis.Redis): - modclient.ft().create_index((TextField("txt"),)) - modclient.hset("hset:1", "txt", "a") - modclient.hset("hset:2", "txt", "b") - modclient.hset("hset:3", "txt", "c") - assert 3 == modclient.ft().search(Query("*")).total - modclient.pexpire("hset:2", 300) - for _ in range(500): - modclient.ft().search(Query("*")).docs[1] - time.sleep(1) - assert 2 == modclient.ft().search(Query("*")).total +def test_expire_while_search(client: redis.Redis): + client.ft().create_index((TextField("txt"),)) + client.hset("hset:1", "txt", "a") + client.hset("hset:2", "txt", "b") + client.hset("hset:3", "txt", "c") + if is_resp2_connection(client): + assert 3 == client.ft().search(Query("*")).total + client.pexpire("hset:2", 300) + for _ in range(500): + client.ft().search(Query("*")).docs[1] + time.sleep(1) + assert 2 == client.ft().search(Query("*")).total + else: + assert 3 == client.ft().search(Query("*"))["total_results"] + client.pexpire("hset:2", 300) + for _ in range(500): + client.ft().search(Query("*"))["results"][1] + time.sleep(1) + assert 2 == client.ft().search(Query("*"))["total_results"] @pytest.mark.redismod @pytest.mark.experimental -def test_withsuffixtrie(modclient: redis.Redis): +def test_withsuffixtrie(client: redis.Redis): # create index - assert modclient.ft().create_index((TextField("txt"),)) - waitForIndex(modclient, getattr(modclient.ft(), "index_name", "idx")) - info = modclient.ft().info() - assert "WITHSUFFIXTRIE" not in info["attributes"][0] - assert modclient.ft().dropindex("idx") - - # create withsuffixtrie index (text fiels) - assert modclient.ft().create_index((TextField("t", withsuffixtrie=True))) - waitForIndex(modclient, getattr(modclient.ft(), "index_name", "idx")) - info = modclient.ft().info() - assert "WITHSUFFIXTRIE" in info["attributes"][0] - assert modclient.ft().dropindex("idx") - - # create withsuffixtrie index (tag field) - assert 
modclient.ft().create_index((TagField("t", withsuffixtrie=True))) - waitForIndex(modclient, getattr(modclient.ft(), "index_name", "idx")) - info = modclient.ft().info() - assert "WITHSUFFIXTRIE" in info["attributes"][0] + assert client.ft().create_index((TextField("txt"),)) + waitForIndex(client, getattr(client.ft(), "index_name", "idx")) + if is_resp2_connection(client): + info = client.ft().info() + assert "WITHSUFFIXTRIE" not in info["attributes"][0] + assert client.ft().dropindex("idx") + + # create withsuffixtrie index (text fields) + assert client.ft().create_index((TextField("t", withsuffixtrie=True))) + waitForIndex(client, getattr(client.ft(), "index_name", "idx")) + info = client.ft().info() + assert "WITHSUFFIXTRIE" in info["attributes"][0] + assert client.ft().dropindex("idx") + + # create withsuffixtrie index (tag field) + assert client.ft().create_index((TagField("t", withsuffixtrie=True))) + waitForIndex(client, getattr(client.ft(), "index_name", "idx")) + info = client.ft().info() + assert "WITHSUFFIXTRIE" in info["attributes"][0] + else: + info = client.ft().info() + assert "WITHSUFFIXTRIE" not in info["attributes"][0]["flags"] + assert client.ft().dropindex("idx") + + # create withsuffixtrie index (text fields) + assert client.ft().create_index((TextField("t", withsuffixtrie=True))) + waitForIndex(client, getattr(client.ft(), "index_name", "idx")) + info = client.ft().info() + assert "WITHSUFFIXTRIE" in info["attributes"][0]["flags"] + assert client.ft().dropindex("idx") + + # create withsuffixtrie index (tag field) + assert client.ft().create_index((TagField("t", withsuffixtrie=True))) + waitForIndex(client, getattr(client.ft(), "index_name", "idx")) + info = client.ft().info() + assert "WITHSUFFIXTRIE" in info["attributes"][0]["flags"] @pytest.mark.redismod -def test_query_timeout(modclient: redis.Redis): +def test_query_timeout(r: redis.Redis): q1 = Query("foo").timeout(5000) assert q1.get_args() == ["foo", "TIMEOUT", 5000, "LIMIT", 0, 10] q2 = Query("foo").timeout("not_a_number") with pytest.raises(redis.ResponseError): - modclient.ft().search(q2) + r.ft().search(q2) diff --git a/tests/test_sentinel.py b/tests/test_sentinel.py index 8542a0bfc3..b7bcc27de2 100644 --- a/tests/test_sentinel.py +++ b/tests/test_sentinel.py @@ -1,7 +1,6 @@ import socket import pytest - import redis.sentinel from redis import exceptions from redis.sentinel import ( @@ -98,6 +97,15 @@ def test_discover_master_error(sentinel): sentinel.discover_master("xxx") +@pytest.mark.onlynoncluster +def test_dead_pool(sentinel): + master = sentinel.master_for("mymaster", db=9) + conn = master.connection_pool.get_connection("_") + conn.disconnect() + del master + conn.connect() + + @pytest.mark.onlynoncluster def test_discover_master_sentinel_down(cluster, sentinel, master_ip): # Put first sentinel 'foo' down diff --git a/tests/test_ssl.py b/tests/test_ssl.py index ed38a3166b..465fdabb89 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -1,14 +1,13 @@ -import os import socket import ssl from urllib.parse import urlparse import pytest - import redis from redis.exceptions import ConnectionError, RedisError from .conftest import skip_if_cryptography, skip_if_nocryptography +from .ssl_utils import get_ssl_filename @pytest.mark.ssl @@ -19,17 +18,8 @@ class TestSSL: and connecting to the appropriate port.
""" - ROOT = os.path.join(os.path.dirname(__file__), "..") - CERT_DIR = os.path.abspath(os.path.join(ROOT, "docker", "stunnel", "keys")) - if not os.path.isdir(CERT_DIR): # github actions package validation case - CERT_DIR = os.path.abspath( - os.path.join(ROOT, "..", "docker", "stunnel", "keys") - ) - if not os.path.isdir(CERT_DIR): - raise IOError(f"No SSL certificates found. They should be in {CERT_DIR}") - - SERVER_CERT = os.path.join(CERT_DIR, "server-cert.pem") - SERVER_KEY = os.path.join(CERT_DIR, "server-key.pem") + SERVER_CERT = get_ssl_filename("server-cert.pem") + SERVER_KEY = get_ssl_filename("server-key.pem") def test_ssl_with_invalid_cert(self, request): ssl_url = request.config.option.redis_ssl_url diff --git a/tests/test_timeseries.py b/tests/test_timeseries.py index 6ced5359f7..4ab86cd56e 100644 --- a/tests/test_timeseries.py +++ b/tests/test_timeseries.py @@ -3,35 +3,34 @@ from time import sleep import pytest - import redis -from .conftest import skip_ifmodversion_lt +from .conftest import assert_resp_response, is_resp2_connection, skip_ifmodversion_lt @pytest.fixture -def client(modclient): - modclient.flushdb() - return modclient +def client(decoded_r): + decoded_r.flushdb() + return decoded_r -@pytest.mark.redismod def test_create(client): assert client.ts().create(1) assert client.ts().create(2, retention_msecs=5) assert client.ts().create(3, labels={"Redis": "Labs"}) assert client.ts().create(4, retention_msecs=20, labels={"Time": "Series"}) info = client.ts().info(4) - assert 20 == info.retention_msecs - assert "Series" == info.labels["Time"] + assert_resp_response( + client, 20, info.get("retention_msecs"), info.get("retentionTime") + ) + assert "Series" == info["labels"]["Time"] # Test for a chunk size of 128 Bytes assert client.ts().create("time-serie-1", chunk_size=128) info = client.ts().info("time-serie-1") - assert 128, info.chunk_size + assert_resp_response(client, 128, info.get("chunk_size"), info.get("chunkSize")) -@pytest.mark.redismod @skip_ifmodversion_lt("1.4.0", "timeseries") def test_create_duplicate_policy(client): # Test for duplicate policy @@ -39,33 +38,48 @@ def test_create_duplicate_policy(client): ts_name = f"time-serie-ooo-{duplicate_policy}" assert client.ts().create(ts_name, duplicate_policy=duplicate_policy) info = client.ts().info(ts_name) - assert duplicate_policy == info.duplicate_policy + assert_resp_response( + client, + duplicate_policy, + info.get("duplicate_policy"), + info.get("duplicatePolicy"), + ) -@pytest.mark.redismod def test_alter(client): assert client.ts().create(1) - assert 0 == client.ts().info(1).retention_msecs + info = client.ts().info(1) + assert_resp_response( + client, 0, info.get("retention_msecs"), info.get("retentionTime") + ) assert client.ts().alter(1, retention_msecs=10) - assert {} == client.ts().info(1).labels - assert 10, client.ts().info(1).retention_msecs + assert {} == client.ts().info(1)["labels"] + info = client.ts().info(1) + assert_resp_response( + client, 10, info.get("retention_msecs"), info.get("retentionTime") + ) assert client.ts().alter(1, labels={"Time": "Series"}) - assert "Series" == client.ts().info(1).labels["Time"] - assert 10 == client.ts().info(1).retention_msecs + assert "Series" == client.ts().info(1)["labels"]["Time"] + info = client.ts().info(1) + assert_resp_response( + client, 10, info.get("retention_msecs"), info.get("retentionTime") + ) -@pytest.mark.redismod @skip_ifmodversion_lt("1.4.0", "timeseries") def test_alter_diplicate_policy(client): assert client.ts().create(1) info 
= client.ts().info(1) - assert info.duplicate_policy is None + assert_resp_response( + client, None, info.get("duplicate_policy"), info.get("duplicatePolicy") + ) assert client.ts().alter(1, duplicate_policy="min") info = client.ts().info(1) - assert "min" == info.duplicate_policy + assert_resp_response( + client, "min", info.get("duplicate_policy"), info.get("duplicatePolicy") + ) -@pytest.mark.redismod def test_add(client): assert 1 == client.ts().add(1, 1, 1) assert 2 == client.ts().add(2, 2, 3, retention_msecs=10) @@ -77,16 +91,17 @@ def test_add(client): assert abs(time.time() - float(client.ts().add(5, "*", 1)) / 1000) < 1.0 info = client.ts().info(4) - assert 10 == info.retention_msecs - assert "Labs" == info.labels["Redis"] + assert_resp_response( + client, 10, info.get("retention_msecs"), info.get("retentionTime") + ) + assert "Labs" == info["labels"]["Redis"] # Test for a chunk size of 128 Bytes on TS.ADD assert client.ts().add("time-serie-1", 1, 10.0, chunk_size=128) info = client.ts().info("time-serie-1") - assert 128 == info.chunk_size + assert_resp_response(client, 128, info.get("chunk_size"), info.get("chunkSize")) -@pytest.mark.redismod @skip_ifmodversion_lt("1.4.0", "timeseries") def test_add_duplicate_policy(client): @@ -124,13 +139,11 @@ def test_add_duplicate_policy(client): assert 5.0 == client.ts().get("time-serie-add-ooo-min")[1] -@pytest.mark.redismod def test_madd(client): client.ts().create("a") assert [1, 2, 3] == client.ts().madd([("a", 1, 5), ("a", 2, 10), ("a", 3, 15)]) -@pytest.mark.redismod def test_incrby_decrby(client): for _ in range(100): assert client.ts().incrby(1, 1) @@ -142,24 +155,23 @@ def test_incrby_decrby(client): assert 0 == client.ts().get(1)[1] assert client.ts().incrby(2, 1.5, timestamp=5) - assert (5, 1.5) == client.ts().get(2) + assert_resp_response(client, client.ts().get(2), (5, 1.5), [5, 1.5]) assert client.ts().incrby(2, 2.25, timestamp=7) - assert (7, 3.75) == client.ts().get(2) + assert_resp_response(client, client.ts().get(2), (7, 3.75), [7, 3.75]) assert client.ts().decrby(2, 1.5, timestamp=15) - assert (15, 2.25) == client.ts().get(2) + assert_resp_response(client, client.ts().get(2), (15, 2.25), [15, 2.25]) # Test for a chunk size of 128 Bytes on TS.INCRBY assert client.ts().incrby("time-serie-1", 10, chunk_size=128) info = client.ts().info("time-serie-1") - assert 128 == info.chunk_size + assert_resp_response(client, 128, info.get("chunk_size"), info.get("chunkSize")) # Test for a chunk size of 128 Bytes on TS.DECRBY assert client.ts().decrby("time-serie-2", 10, chunk_size=128) info = client.ts().info("time-serie-2") - assert 128 == info.chunk_size + assert_resp_response(client, 128, info.get("chunk_size"), info.get("chunkSize")) -@pytest.mark.redismod def test_create_and_delete_rule(client): # test rule creation time = 100 @@ -172,15 +184,17 @@ def test_create_and_delete_rule(client): client.ts().add(1, time * 2, 1.5) assert round(client.ts().get(2)[1], 5) == 1.5 info = client.ts().info(1) - assert info.rules[0][1] == 100 + if is_resp2_connection(client): + assert info.rules[0][1] == 100 + else: + assert info["rules"]["2"][0] == 100 # test rule deletion client.ts().deleterule(1, 2) info = client.ts().info(1) - assert not info.rules + assert not info["rules"] -@pytest.mark.redismod @skip_ifmodversion_lt("99.99.99", "timeseries") def test_del_range(client): try: @@ -192,10 +206,9 @@ def test_del_range(client): client.ts().add(1, i, i % 7) assert 22 == client.ts().delete(1, 0, 21) assert [] == client.ts().range(1, 0, 21) - assert 
[(22, 1.0)] == client.ts().range(1, 22, 22) + assert_resp_response(client, client.ts().range(1, 22, 22), [(22, 1.0)], [[22, 1.0]]) -@pytest.mark.redismod def test_range(client): for i in range(100): client.ts().add(1, i, i % 7) @@ -210,7 +223,6 @@ def test_range(client): assert 10 == len(client.ts().range(1, 0, 500, count=10)) -@pytest.mark.redismod @skip_ifmodversion_lt("99.99.99", "timeseries") def test_range_advanced(client): for i in range(100): @@ -227,18 +239,19 @@ def test_range_advanced(client): filter_by_max_value=2, ) ) - assert [(0, 10.0), (10, 1.0)] == client.ts().range( + res = client.ts().range( 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align="+" ) - assert [(0, 5.0), (5, 6.0)] == client.ts().range( + assert_resp_response(client, res, [(0, 10.0), (10, 1.0)], [[0, 10.0], [10, 1.0]]) + res = client.ts().range( 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align=5 ) - assert [(0, 2.55), (10, 3.0)] == client.ts().range( - 1, 0, 10, aggregation_type="twa", bucket_size_msec=10 - ) + assert_resp_response(client, res, [(0, 5.0), (5, 6.0)], [[0, 5.0], [5, 6.0]]) + res = client.ts().range(1, 0, 10, aggregation_type="twa", bucket_size_msec=10) + assert_resp_response(client, res, [(0, 2.55), (10, 3.0)], [[0, 2.55], [10, 3.0]]) -@pytest.mark.redismod +@pytest.mark.onlynoncluster @skip_ifmodversion_lt("1.8.0", "timeseries") def test_range_latest(client: redis.Redis): timeseries = client.ts() @@ -249,17 +262,20 @@ def test_range_latest(client: redis.Redis): timeseries.add("t1", 2, 3) timeseries.add("t1", 11, 7) timeseries.add("t1", 13, 1) - res = timeseries.range("t1", 0, 20) - assert res == [(1, 1.0), (2, 3.0), (11, 7.0), (13, 1.0)] - res = timeseries.range("t2", 0, 10) - assert res == [(0, 4.0)] + assert_resp_response( + client, + timeseries.range("t1", 0, 20), + [(1, 1.0), (2, 3.0), (11, 7.0), (13, 1.0)], + [[1, 1.0], [2, 3.0], [11, 7.0], [13, 1.0]], + ) + assert_resp_response(client, timeseries.range("t2", 0, 10), [(0, 4.0)], [[0, 4.0]]) res = timeseries.range("t2", 0, 10, latest=True) - assert res == [(0, 4.0), (10, 8.0)] - res = timeseries.range("t2", 0, 9, latest=True) - assert res == [(0, 4.0)] + assert_resp_response(client, res, [(0, 4.0), (10, 8.0)], [[0, 4.0], [10, 8.0]]) + assert_resp_response( + client, timeseries.range("t2", 0, 9, latest=True), [(0, 4.0)], [[0, 4.0]] + ) -@pytest.mark.redismod @skip_ifmodversion_lt("1.8.0", "timeseries") def test_range_bucket_timestamp(client: redis.Redis): timeseries = client.ts() @@ -269,21 +285,30 @@ def test_range_bucket_timestamp(client: redis.Redis): timeseries.add("t1", 51, 3) timeseries.add("t1", 73, 5) timeseries.add("t1", 75, 3) - assert [(10, 4.0), (50, 3.0), (70, 5.0)] == timeseries.range( - "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10 - ) - assert [(20, 4.0), (60, 3.0), (80, 5.0)] == timeseries.range( - "t1", - 0, - 100, - align=0, - aggregation_type="max", - bucket_size_msec=10, - bucket_timestamp="+", + assert_resp_response( + client, + timeseries.range( + "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10 + ), + [(10, 4.0), (50, 3.0), (70, 5.0)], + [[10, 4.0], [50, 3.0], [70, 5.0]], + ) + assert_resp_response( + client, + timeseries.range( + "t1", + 0, + 100, + align=0, + aggregation_type="max", + bucket_size_msec=10, + bucket_timestamp="+", + ), + [(20, 4.0), (60, 3.0), (80, 5.0)], + [[20, 4.0], [60, 3.0], [80, 5.0]], ) -@pytest.mark.redismod @skip_ifmodversion_lt("1.8.0", "timeseries") def test_range_empty(client: redis.Redis): timeseries = client.ts() @@ -293,8 
+318,13 @@ def test_range_empty(client: redis.Redis): timeseries.add("t1", 51, 3) timeseries.add("t1", 73, 5) timeseries.add("t1", 75, 3) - assert [(10, 4.0), (50, 3.0), (70, 5.0)] == timeseries.range( - "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10 + assert_resp_response( + client, + timeseries.range( + "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10 + ), + [(10, 4.0), (50, 3.0), (70, 5.0)], + [[10, 4.0], [50, 3.0], [70, 5.0]], ) res = timeseries.range( "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10, empty=True @@ -302,7 +332,7 @@ def test_range_empty(client: redis.Redis): for i in range(len(res)): if math.isnan(res[i][1]): res[i] = (res[i][0], None) - assert [ + resp2_expected = [ (10, 4.0), (20, None), (30, None), @@ -310,10 +340,19 @@ def test_range_empty(client: redis.Redis): (50, 3.0), (60, None), (70, 5.0), - ] == res + ] + resp3_expected = [ + [10, 4.0], + (20, None), + (30, None), + (40, None), + [50, 3.0], + (60, None), + [70, 5.0], + ] + assert_resp_response(client, res, resp2_expected, resp3_expected) -@pytest.mark.redismod @skip_ifmodversion_lt("99.99.99", "timeseries") def test_rev_range(client): for i in range(100): @@ -337,18 +376,31 @@ def test_rev_range(client): filter_by_max_value=2, ) ) - assert [(10, 1.0), (0, 10.0)] == client.ts().revrange( - 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align="+" + assert_resp_response( + client, + client.ts().revrange( + 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align="+" + ), + [(10, 1.0), (0, 10.0)], + [[10, 1.0], [0, 10.0]], ) - assert [(1, 10.0), (0, 1.0)] == client.ts().revrange( - 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align=1 + assert_resp_response( + client, + client.ts().revrange( + 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align=1 + ), + [(1, 10.0), (0, 1.0)], + [[1, 10.0], [0, 1.0]], ) - assert [(10, 3.0), (0, 2.55)] == client.ts().revrange( - 1, 0, 10, aggregation_type="twa", bucket_size_msec=10 + assert_resp_response( + client, + client.ts().revrange(1, 0, 10, aggregation_type="twa", bucket_size_msec=10), + [(10, 3.0), (0, 2.55)], + [[10, 3.0], [0, 2.55]], ) -@pytest.mark.redismod +@pytest.mark.onlynoncluster @skip_ifmodversion_lt("1.8.0", "timeseries") def test_revrange_latest(client: redis.Redis): timeseries = client.ts() @@ -360,14 +412,13 @@ def test_revrange_latest(client: redis.Redis): timeseries.add("t1", 11, 7) timeseries.add("t1", 13, 1) res = timeseries.revrange("t2", 0, 10) - assert res == [(0, 4.0)] + assert_resp_response(client, res, [(0, 4.0)], [[0, 4.0]]) res = timeseries.revrange("t2", 0, 10, latest=True) - assert res == [(10, 8.0), (0, 4.0)] + assert_resp_response(client, res, [(10, 8.0), (0, 4.0)], [[10, 8.0], [0, 4.0]]) res = timeseries.revrange("t2", 0, 9, latest=True) - assert res == [(0, 4.0)] + assert_resp_response(client, res, [(0, 4.0)], [[0, 4.0]]) -@pytest.mark.redismod @skip_ifmodversion_lt("1.8.0", "timeseries") def test_revrange_bucket_timestamp(client: redis.Redis): timeseries = client.ts() @@ -377,21 +428,30 @@ def test_revrange_bucket_timestamp(client: redis.Redis): timeseries.add("t1", 51, 3) timeseries.add("t1", 73, 5) timeseries.add("t1", 75, 3) - assert [(70, 5.0), (50, 3.0), (10, 4.0)] == timeseries.revrange( - "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10 - ) - assert [(20, 4.0), (60, 3.0), (80, 5.0)] == timeseries.range( - "t1", - 0, - 100, - align=0, - aggregation_type="max", - bucket_size_msec=10, - bucket_timestamp="+", + 
assert_resp_response( + client, + timeseries.revrange( + "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10 + ), + [(70, 5.0), (50, 3.0), (10, 4.0)], + [[70, 5.0], [50, 3.0], [10, 4.0]], + ) + assert_resp_response( + client, + timeseries.range( + "t1", + 0, + 100, + align=0, + aggregation_type="max", + bucket_size_msec=10, + bucket_timestamp="+", + ), + [(20, 4.0), (60, 3.0), (80, 5.0)], + [[20, 4.0], [60, 3.0], [80, 5.0]], ) -@pytest.mark.redismod @skip_ifmodversion_lt("1.8.0", "timeseries") def test_revrange_empty(client: redis.Redis): timeseries = client.ts() @@ -401,8 +461,13 @@ def test_revrange_empty(client: redis.Redis): timeseries.add("t1", 51, 3) timeseries.add("t1", 73, 5) timeseries.add("t1", 75, 3) - assert [(70, 5.0), (50, 3.0), (10, 4.0)] == timeseries.revrange( - "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10 + assert_resp_response( + client, + timeseries.revrange( + "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10 + ), + [(70, 5.0), (50, 3.0), (10, 4.0)], + [[70, 5.0], [50, 3.0], [10, 4.0]], ) res = timeseries.revrange( "t1", 0, 100, align=0, aggregation_type="max", bucket_size_msec=10, empty=True @@ -410,7 +475,7 @@ def test_revrange_empty(client: redis.Redis): for i in range(len(res)): if math.isnan(res[i][1]): res[i] = (res[i][0], None) - assert [ + resp2_expected = [ (70, 5.0), (60, None), (50, 3.0), @@ -418,10 +483,19 @@ def test_revrange_empty(client: redis.Redis): (30, None), (20, None), (10, 4.0), - ] == res + ] + resp3_expected = [ + [70, 5.0], + (60, None), + [50, 3.0], + (40, None), + (30, None), + (20, None), + [10, 4.0], + ] + assert_resp_response(client, res, resp2_expected, resp3_expected) -@pytest.mark.redismod @pytest.mark.onlynoncluster def test_mrange(client): client.ts().create(1, labels={"Test": "This", "team": "ny"}) @@ -432,26 +506,44 @@ def test_mrange(client): res = client.ts().mrange(0, 200, filters=["Test=This"]) assert 2 == len(res) - assert 100 == len(res[0]["1"][1]) + if is_resp2_connection(client): + assert 100 == len(res[0]["1"][1]) - res = client.ts().mrange(0, 200, filters=["Test=This"], count=10) - assert 10 == len(res[0]["1"][1]) + res = client.ts().mrange(0, 200, filters=["Test=This"], count=10) + assert 10 == len(res[0]["1"][1]) - for i in range(100): - client.ts().add(1, i + 200, i % 7) - res = client.ts().mrange( - 0, 500, filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 - ) - assert 2 == len(res) - assert 20 == len(res[0]["1"][1]) + for i in range(100): + client.ts().add(1, i + 200, i % 7) + res = client.ts().mrange( + 0, 500, filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 + ) + assert 2 == len(res) + assert 20 == len(res[0]["1"][1]) + + # test withlabels + assert {} == res[0]["1"][0] + res = client.ts().mrange(0, 200, filters=["Test=This"], with_labels=True) + assert {"Test": "This", "team": "ny"} == res[0]["1"][0] + else: + assert 100 == len(res["1"][2]) + + res = client.ts().mrange(0, 200, filters=["Test=This"], count=10) + assert 10 == len(res["1"][2]) + + for i in range(100): + client.ts().add(1, i + 200, i % 7) + res = client.ts().mrange( + 0, 500, filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 + ) + assert 2 == len(res) + assert 20 == len(res["1"][2]) - # test withlabels - assert {} == res[0]["1"][0] - res = client.ts().mrange(0, 200, filters=["Test=This"], with_labels=True) - assert {"Test": "This", "team": "ny"} == res[0]["1"][0] + # test withlabels + assert {} == res["1"][0] + res = client.ts().mrange(0, 200, 
filters=["Test=This"], with_labels=True) + assert {"Test": "This", "team": "ny"} == res["1"][0] -@pytest.mark.redismod @pytest.mark.onlynoncluster @skip_ifmodversion_lt("99.99.99", "timeseries") def test_multi_range_advanced(client): @@ -463,52 +555,108 @@ def test_multi_range_advanced(client): # test with selected labels res = client.ts().mrange(0, 200, filters=["Test=This"], select_labels=["team"]) - assert {"team": "ny"} == res[0]["1"][0] - assert {"team": "sf"} == res[1]["2"][0] + if is_resp2_connection(client): + assert {"team": "ny"} == res[0]["1"][0] + assert {"team": "sf"} == res[1]["2"][0] - # test with filterby - res = client.ts().mrange( - 0, - 200, - filters=["Test=This"], - filter_by_ts=[i for i in range(10, 20)], - filter_by_min_value=1, - filter_by_max_value=2, - ) - assert [(15, 1.0), (16, 2.0)] == res[0]["1"][1] + # test with filterby + res = client.ts().mrange( + 0, + 200, + filters=["Test=This"], + filter_by_ts=[i for i in range(10, 20)], + filter_by_min_value=1, + filter_by_max_value=2, + ) + assert [(15, 1.0), (16, 2.0)] == res[0]["1"][1] - # test groupby - res = client.ts().mrange(0, 3, filters=["Test=This"], groupby="Test", reduce="sum") - assert [(0, 0.0), (1, 2.0), (2, 4.0), (3, 6.0)] == res[0]["Test=This"][1] - res = client.ts().mrange(0, 3, filters=["Test=This"], groupby="Test", reduce="max") - assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["Test=This"][1] - res = client.ts().mrange(0, 3, filters=["Test=This"], groupby="team", reduce="min") - assert 2 == len(res) - assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["team=ny"][1] - assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[1]["team=sf"][1] + # test groupby + res = client.ts().mrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="sum" + ) + assert [(0, 0.0), (1, 2.0), (2, 4.0), (3, 6.0)] == res[0]["Test=This"][1] + res = client.ts().mrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="max" + ) + assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["Test=This"][1] + res = client.ts().mrange( + 0, 3, filters=["Test=This"], groupby="team", reduce="min" + ) + assert 2 == len(res) + assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["team=ny"][1] + assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[1]["team=sf"][1] - # test align - res = client.ts().mrange( - 0, - 10, - filters=["team=ny"], - aggregation_type="count", - bucket_size_msec=10, - align="-", - ) - assert [(0, 10.0), (10, 1.0)] == res[0]["1"][1] - res = client.ts().mrange( - 0, - 10, - filters=["team=ny"], - aggregation_type="count", - bucket_size_msec=10, - align=5, - ) - assert [(0, 5.0), (5, 6.0)] == res[0]["1"][1] + # test align + res = client.ts().mrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align="-", + ) + assert [(0, 10.0), (10, 1.0)] == res[0]["1"][1] + res = client.ts().mrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align=5, + ) + assert [(0, 5.0), (5, 6.0)] == res[0]["1"][1] + else: + assert {"team": "ny"} == res["1"][0] + assert {"team": "sf"} == res["2"][0] + + # test with filterby + res = client.ts().mrange( + 0, + 200, + filters=["Test=This"], + filter_by_ts=[i for i in range(10, 20)], + filter_by_min_value=1, + filter_by_max_value=2, + ) + assert [[15, 1.0], [16, 2.0]] == res["1"][2] + + # test groupby + res = client.ts().mrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="sum" + ) + assert [[0, 0.0], [1, 2.0], [2, 4.0], [3, 6.0]] == res["Test=This"][3] + res = 
client.ts().mrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="max" + ) + assert [[0, 0.0], [1, 1.0], [2, 2.0], [3, 3.0]] == res["Test=This"][3] + res = client.ts().mrange( + 0, 3, filters=["Test=This"], groupby="team", reduce="min" + ) + assert 2 == len(res) + assert [[0, 0.0], [1, 1.0], [2, 2.0], [3, 3.0]] == res["team=ny"][3] + assert [[0, 0.0], [1, 1.0], [2, 2.0], [3, 3.0]] == res["team=sf"][3] + + # test align + res = client.ts().mrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align="-", + ) + assert [[0, 10.0], [10, 1.0]] == res["1"][2] + res = client.ts().mrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align=5, + ) + assert [[0, 5.0], [5, 6.0]] == res["1"][2] -@pytest.mark.redismod @pytest.mark.onlynoncluster @skip_ifmodversion_lt("1.8.0", "timeseries") def test_mrange_latest(client: redis.Redis): @@ -527,13 +675,17 @@ def test_mrange_latest(client: redis.Redis): timeseries.add("t3", 2, 3) timeseries.add("t3", 11, 7) timeseries.add("t3", 13, 1) - assert client.ts().mrange(0, 10, filters=["is_compaction=true"], latest=True) == [ - {"t2": [{}, [(0, 4.0), (10, 8.0)]]}, - {"t4": [{}, [(0, 4.0), (10, 8.0)]]}, - ] + assert_resp_response( + client, + client.ts().mrange(0, 10, filters=["is_compaction=true"], latest=True), + [{"t2": [{}, [(0, 4.0), (10, 8.0)]]}, {"t4": [{}, [(0, 4.0), (10, 8.0)]]}], + { + "t2": [{}, {"aggregators": []}, [[0, 4.0], [10, 8.0]]], + "t4": [{}, {"aggregators": []}, [[0, 4.0], [10, 8.0]]], + }, + ) -@pytest.mark.redismod @pytest.mark.onlynoncluster @skip_ifmodversion_lt("99.99.99", "timeseries") def test_multi_reverse_range(client): @@ -545,10 +697,16 @@ def test_multi_reverse_range(client): res = client.ts().mrange(0, 200, filters=["Test=This"]) assert 2 == len(res) - assert 100 == len(res[0]["1"][1]) + if is_resp2_connection(client): + assert 100 == len(res[0]["1"][1]) + else: + assert 100 == len(res["1"][2]) res = client.ts().mrange(0, 200, filters=["Test=This"], count=10) - assert 10 == len(res[0]["1"][1]) + if is_resp2_connection(client): + assert 10 == len(res[0]["1"][1]) + else: + assert 10 == len(res["1"][2]) for i in range(100): client.ts().add(1, i + 200, i % 7) @@ -556,17 +714,28 @@ def test_multi_reverse_range(client): 0, 500, filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 ) assert 2 == len(res) - assert 20 == len(res[0]["1"][1]) - assert {} == res[0]["1"][0] + if is_resp2_connection(client): + assert 20 == len(res[0]["1"][1]) + assert {} == res[0]["1"][0] + else: + assert 20 == len(res["1"][2]) + assert {} == res["1"][0] # test withlabels res = client.ts().mrevrange(0, 200, filters=["Test=This"], with_labels=True) - assert {"Test": "This", "team": "ny"} == res[0]["1"][0] + if is_resp2_connection(client): + assert {"Test": "This", "team": "ny"} == res[0]["1"][0] + else: + assert {"Test": "This", "team": "ny"} == res["1"][0] # test with selected labels res = client.ts().mrevrange(0, 200, filters=["Test=This"], select_labels=["team"]) - assert {"team": "ny"} == res[0]["1"][0] - assert {"team": "sf"} == res[1]["2"][0] + if is_resp2_connection(client): + assert {"team": "ny"} == res[0]["1"][0] + assert {"team": "sf"} == res[1]["2"][0] + else: + assert {"team": "ny"} == res["1"][0] + assert {"team": "sf"} == res["2"][0] # test filterby res = client.ts().mrevrange( @@ -577,23 +746,36 @@ def test_multi_reverse_range(client): filter_by_min_value=1, filter_by_max_value=2, ) - assert [(16, 2.0), (15, 1.0)] == res[0]["1"][1] + if 
is_resp2_connection(client): + assert [(16, 2.0), (15, 1.0)] == res[0]["1"][1] + else: + assert [[16, 2.0], [15, 1.0]] == res["1"][2] # test groupby res = client.ts().mrevrange( 0, 3, filters=["Test=This"], groupby="Test", reduce="sum" ) - assert [(3, 6.0), (2, 4.0), (1, 2.0), (0, 0.0)] == res[0]["Test=This"][1] + if is_resp2_connection(client): + assert [(3, 6.0), (2, 4.0), (1, 2.0), (0, 0.0)] == res[0]["Test=This"][1] + else: + assert [[3, 6.0], [2, 4.0], [1, 2.0], [0, 0.0]] == res["Test=This"][3] res = client.ts().mrevrange( 0, 3, filters=["Test=This"], groupby="Test", reduce="max" ) - assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[0]["Test=This"][1] + if is_resp2_connection(client): + assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[0]["Test=This"][1] + else: + assert [[3, 3.0], [2, 2.0], [1, 1.0], [0, 0.0]] == res["Test=This"][3] res = client.ts().mrevrange( 0, 3, filters=["Test=This"], groupby="team", reduce="min" ) assert 2 == len(res) - assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[0]["team=ny"][1] - assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[1]["team=sf"][1] + if is_resp2_connection(client): + assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[0]["team=ny"][1] + assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[1]["team=sf"][1] + else: + assert [[3, 3.0], [2, 2.0], [1, 1.0], [0, 0.0]] == res["team=ny"][3] + assert [[3, 3.0], [2, 2.0], [1, 1.0], [0, 0.0]] == res["team=sf"][3] # test align res = client.ts().mrevrange( @@ -604,7 +786,10 @@ def test_multi_reverse_range(client): bucket_size_msec=10, align="-", ) - assert [(10, 1.0), (0, 10.0)] == res[0]["1"][1] + if is_resp2_connection(client): + assert [(10, 1.0), (0, 10.0)] == res[0]["1"][1] + else: + assert [[10, 1.0], [0, 10.0]] == res["1"][2] res = client.ts().mrevrange( 0, 10, @@ -613,10 +798,12 @@ def test_multi_reverse_range(client): bucket_size_msec=10, align=1, ) - assert [(1, 10.0), (0, 1.0)] == res[0]["1"][1] + if is_resp2_connection(client): + assert [(1, 10.0), (0, 1.0)] == res[0]["1"][1] + else: + assert [[1, 10.0], [0, 1.0]] == res["1"][2] -@pytest.mark.redismod @pytest.mark.onlynoncluster @skip_ifmodversion_lt("1.8.0", "timeseries") def test_mrevrange_latest(client: redis.Redis): @@ -635,23 +822,28 @@ def test_mrevrange_latest(client: redis.Redis): timeseries.add("t3", 2, 3) timeseries.add("t3", 11, 7) timeseries.add("t3", 13, 1) - assert client.ts().mrevrange( - 0, 10, filters=["is_compaction=true"], latest=True - ) == [{"t2": [{}, [(10, 8.0), (0, 4.0)]]}, {"t4": [{}, [(10, 8.0), (0, 4.0)]]}] + assert_resp_response( + client, + client.ts().mrevrange(0, 10, filters=["is_compaction=true"], latest=True), + [{"t2": [{}, [(10, 8.0), (0, 4.0)]]}, {"t4": [{}, [(10, 8.0), (0, 4.0)]]}], + { + "t2": [{}, {"aggregators": []}, [[10, 8.0], [0, 4.0]]], + "t4": [{}, {"aggregators": []}, [[10, 8.0], [0, 4.0]]], + }, + ) -@pytest.mark.redismod def test_get(client): name = "test" client.ts().create(name) - assert client.ts().get(name) is None + assert not client.ts().get(name) client.ts().add(name, 2, 3) assert 2 == client.ts().get(name)[0] client.ts().add(name, 3, 4) assert 4 == client.ts().get(name)[1] -@pytest.mark.redismod +@pytest.mark.onlynoncluster @skip_ifmodversion_lt("1.8.0", "timeseries") def test_get_latest(client: redis.Redis): timeseries = client.ts() @@ -662,33 +854,47 @@ def test_get_latest(client: redis.Redis): timeseries.add("t1", 2, 3) timeseries.add("t1", 11, 7) timeseries.add("t1", 13, 1) - assert (0, 4.0) == timeseries.get("t2") - assert (10, 8.0) == timeseries.get("t2", 
latest=True) + assert_resp_response(client, timeseries.get("t2"), (0, 4.0), [0, 4.0]) + assert_resp_response( + client, timeseries.get("t2", latest=True), (10, 8.0), [10, 8.0] + ) -@pytest.mark.redismod @pytest.mark.onlynoncluster def test_mget(client): client.ts().create(1, labels={"Test": "This"}) client.ts().create(2, labels={"Test": "This", "Taste": "That"}) act_res = client.ts().mget(["Test=This"]) exp_res = [{"1": [{}, None, None]}, {"2": [{}, None, None]}] - assert act_res == exp_res + exp_res_resp3 = {"1": [{}, []], "2": [{}, []]} + assert_resp_response(client, act_res, exp_res, exp_res_resp3) client.ts().add(1, "*", 15) client.ts().add(2, "*", 25) res = client.ts().mget(["Test=This"]) - assert 15 == res[0]["1"][2] - assert 25 == res[1]["2"][2] + if is_resp2_connection(client): + assert 15 == res[0]["1"][2] + assert 25 == res[1]["2"][2] + else: + assert 15 == res["1"][1][1] + assert 25 == res["2"][1][1] res = client.ts().mget(["Taste=That"]) - assert 25 == res[0]["2"][2] + if is_resp2_connection(client): + assert 25 == res[0]["2"][2] + else: + assert 25 == res["2"][1][1] # test with_labels - assert {} == res[0]["2"][0] + if is_resp2_connection(client): + assert {} == res[0]["2"][0] + else: + assert {} == res["2"][0] res = client.ts().mget(["Taste=That"], with_labels=True) - assert {"Taste": "That", "Test": "This"} == res[0]["2"][0] + if is_resp2_connection(client): + assert {"Taste": "That", "Test": "This"} == res[0]["2"][0] + else: + assert {"Taste": "That", "Test": "This"} == res["2"][0] -@pytest.mark.redismod @pytest.mark.onlynoncluster @skip_ifmodversion_lt("1.8.0", "timeseries") def test_mget_latest(client: redis.Redis): @@ -700,43 +906,45 @@ def test_mget_latest(client: redis.Redis): timeseries.add("t1", 2, 3) timeseries.add("t1", 11, 7) timeseries.add("t1", 13, 1) - assert timeseries.mget(filters=["is_compaction=true"]) == [{"t2": [{}, 0, 4.0]}] - assert [{"t2": [{}, 10, 8.0]}] == timeseries.mget( - filters=["is_compaction=true"], latest=True - ) + res = timeseries.mget(filters=["is_compaction=true"]) + assert_resp_response(client, res, [{"t2": [{}, 0, 4.0]}], {"t2": [{}, [0, 4.0]]}) + res = timeseries.mget(filters=["is_compaction=true"], latest=True) + assert_resp_response(client, res, [{"t2": [{}, 10, 8.0]}], {"t2": [{}, [10, 8.0]]}) -@pytest.mark.redismod def test_info(client): client.ts().create(1, retention_msecs=5, labels={"currentLabel": "currentData"}) info = client.ts().info(1) - assert 5 == info.retention_msecs - assert info.labels["currentLabel"] == "currentData" + assert_resp_response( + client, 5, info.get("retention_msecs"), info.get("retentionTime") + ) + assert info["labels"]["currentLabel"] == "currentData" -@pytest.mark.redismod @skip_ifmodversion_lt("1.4.0", "timeseries") def testInfoDuplicatePolicy(client): client.ts().create(1, retention_msecs=5, labels={"currentLabel": "currentData"}) info = client.ts().info(1) - assert info.duplicate_policy is None + assert_resp_response( + client, None, info.get("duplicate_policy"), info.get("duplicatePolicy") + ) client.ts().create("time-serie-2", duplicate_policy="min") info = client.ts().info("time-serie-2") - assert "min" == info.duplicate_policy + assert_resp_response( + client, "min", info.get("duplicate_policy"), info.get("duplicatePolicy") + ) -@pytest.mark.redismod @pytest.mark.onlynoncluster def test_query_index(client): client.ts().create(1, labels={"Test": "This"}) client.ts().create(2, labels={"Test": "This", "Taste": "That"}) assert 2 == len(client.ts().queryindex(["Test=This"])) assert 1 == 
-@pytest.mark.redismod
 def test_info(client):
     client.ts().create(1, retention_msecs=5, labels={"currentLabel": "currentData"})
     info = client.ts().info(1)
-    assert 5 == info.retention_msecs
-    assert info.labels["currentLabel"] == "currentData"
+    assert_resp_response(
+        client, 5, info.get("retention_msecs"), info.get("retentionTime")
+    )
+    assert info["labels"]["currentLabel"] == "currentData"


-@pytest.mark.redismod
 @skip_ifmodversion_lt("1.4.0", "timeseries")
 def testInfoDuplicatePolicy(client):
     client.ts().create(1, retention_msecs=5, labels={"currentLabel": "currentData"})
     info = client.ts().info(1)
-    assert info.duplicate_policy is None
+    assert_resp_response(
+        client, None, info.get("duplicate_policy"), info.get("duplicatePolicy")
+    )
     client.ts().create("time-serie-2", duplicate_policy="min")
     info = client.ts().info("time-serie-2")
-    assert "min" == info.duplicate_policy
+    assert_resp_response(
+        client, "min", info.get("duplicate_policy"), info.get("duplicatePolicy")
+    )


-@pytest.mark.redismod
 @pytest.mark.onlynoncluster
 def test_query_index(client):
     client.ts().create(1, labels={"Test": "This"})
     client.ts().create(2, labels={"Test": "This", "Taste": "That"})
     assert 2 == len(client.ts().queryindex(["Test=This"]))
     assert 1 == len(client.ts().queryindex(["Taste=That"]))
-    assert [2] == client.ts().queryindex(["Taste=That"])
+    assert_resp_response(client, client.ts().queryindex(["Taste=That"]), [2], {"2"})


-@pytest.mark.redismod
 def test_pipeline(client):
     pipeline = client.ts().pipeline()
     pipeline.create("with_pipeline")
@@ -745,15 +953,21 @@ def test_pipeline(client):
     pipeline.execute()

     info = client.ts().info("with_pipeline")
-    assert info.last_timestamp == 99
-    assert info.total_samples == 100
+    assert_resp_response(
+        client, 99, info.get("last_timestamp"), info.get("lastTimestamp")
+    )
+    assert_resp_response(
+        client, 100, info.get("total_samples"), info.get("totalSamples")
+    )
     assert client.ts().get("with_pipeline")[1] == 99 * 1.1


-@pytest.mark.redismod
 def test_uncompressed(client):
     client.ts().create("compressed")
     client.ts().create("uncompressed", uncompressed=True)
     compressed_info = client.ts().info("compressed")
     uncompressed_info = client.ts().info("uncompressed")
-    assert compressed_info.memory_usage != uncompressed_info.memory_usage
+    if is_resp2_connection(client):
+        assert compressed_info.memory_usage != uncompressed_info.memory_usage
+    else:
+        assert compressed_info["memoryUsage"] != uncompressed_info["memoryUsage"]
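A second pattern in the hunks above: TS.INFO fields arrive under snake_case names on RESP2 (retention_msecs, duplicate_policy, last_timestamp) but camelCase on RESP3 (retentionTime, duplicatePolicy, lastTimestamp), so the tests probe both spellings via info.get. A hedged sketch of protocol-agnostic field access, with "demo" again a hypothetical key:

def info_field(client, key, resp2_name, resp3_name):
    # Try the RESP2 spelling first, then fall back to the RESP3 one.
    info = client.ts().info(key)
    value = info.get(resp2_name)
    return value if value is not None else info.get(resp3_name)

retention = info_field(client, "demo", "retention_msecs", "retentionTime")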
diff --git a/tox.ini b/tox.ini
deleted file mode 100644
index 553c77b3c6..0000000000
--- a/tox.ini
+++ /dev/null
@@ -1,379 +0,0 @@
-[pytest]
-addopts = -s
-markers =
-    redismod: run only the redis module tests
-    pipeline: pipeline tests
-    onlycluster: marks tests to be run only with cluster mode redis
-    onlynoncluster: marks tests to be run only with standalone redis
-    ssl: marker for only the ssl tests
-    asyncio: marker for async tests
-    replica: replica tests
-    experimental: run only experimental tests
-asyncio_mode = auto
-
-[tox]
-minversion = 3.2.0
-requires = tox-docker
-envlist = {standalone,cluster}-{plain,hiredis,ocsp}-{uvloop,asyncio}-{py37,py38,py39,pypy3},linters,docs
-
-[docker:master]
-name = master
-image = redisfab/redis-py:6.2.6
-ports =
-    6379:6379/tcp
-healtcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',6379)) else False"
-volumes =
-    bind:rw:{toxinidir}/docker/redis6.2/master/redis.conf:/redis.conf
-
-[docker:replica]
-name = replica
-image = redisfab/redis-py:6.2.6
-links =
-    master:master
-ports =
-    6380:6380/tcp
-healtcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',6380)) else False"
-volumes =
-    bind:rw:{toxinidir}/docker/redis6.2/replica/redis.conf:/redis.conf
-
-[docker:unstable]
-name = unstable
-image = redisfab/redis-py:unstable
-ports =
-    6378:6378/tcp
-healtcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',6378)) else False"
-volumes =
-    bind:rw:{toxinidir}/docker/unstable/redis.conf:/redis.conf
-
-[docker:unstable_cluster]
-name = unstable_cluster
-image = redisfab/redis-py-cluster:unstable
-ports =
-    6372:6372/tcp
-    6373:6373/tcp
-    6374:6374/tcp
-    6375:6375/tcp
-    6376:6376/tcp
-    6377:6377/tcp
-healtcheck_cmd = python -c "import socket;print(True) if all([0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',port)) for port in range(6372,6377)]) else False"
-volumes =
-    bind:rw:{toxinidir}/docker/unstable_cluster/redis.conf:/redis.conf
-
-[docker:sentinel_1]
-name = sentinel_1
-image = redisfab/redis-py-sentinel:6.2.6
-links =
-    master:master
-ports =
-    26379:26379/tcp
-healtcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',26379)) else False"
-volumes =
-    bind:rw:{toxinidir}/docker/redis6.2/sentinel/sentinel_1.conf:/sentinel.conf
-
-[docker:sentinel_2]
-name = sentinel_2
-image = redisfab/redis-py-sentinel:6.2.6
-links =
-    master:master
-ports =
-    26380:26380/tcp
-healtcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',26380)) else False"
-volumes =
-    bind:rw:{toxinidir}/docker/redis6.2/sentinel/sentinel_2.conf:/sentinel.conf
-
-[docker:sentinel_3]
-name = sentinel_3
-image = redisfab/redis-py-sentinel:6.2.6
-links =
-    master:master
-ports =
-    26381:26381/tcp
-healtcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',26381)) else False"
-volumes =
-    bind:rw:{toxinidir}/docker/redis6.2/sentinel/sentinel_3.conf:/sentinel.conf
-
-[docker:redis_stack]
-name = redis_stack
-image = redis/redis-stack-server:edge
-ports =
-    36379:6379/tcp
-healtcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',36379)) else False"
-
-[docker:redis_cluster]
-name = redis_cluster
-image = redisfab/redis-py-cluster:6.2.6
-ports =
-    16379:16379/tcp
-    16380:16380/tcp
-    16381:16381/tcp
-    16382:16382/tcp
-    16383:16383/tcp
-    16384:16384/tcp
-healtcheck_cmd = python -c "import socket;print(True) if all([0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',port)) for port in range(16379,16384)]) else False"
-volumes =
-    bind:rw:{toxinidir}/docker/cluster/redis.conf:/redis.conf
-
-[docker:redismod_cluster]
-name = redismod_cluster
-image = redisfab/redis-py-modcluster:edge
-ports =
-    46379:46379/tcp
-    46380:46380/tcp
-    46381:46381/tcp
-    46382:46382/tcp
-    46383:46383/tcp
-    46384:46384/tcp
-healtcheck_cmd = python -c "import socket;print(True) if all([0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',port)) for port in range(46379,46384)]) else False"
-volumes =
-    bind:rw:{toxinidir}/docker/redismod_cluster/redis.conf:/redis.conf
-
-[docker:stunnel]
-name = stunnel
-image = redisfab/stunnel:latest
-healtcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',6666)) else False"
-links =
-    master:master
-ports =
-    6666:6666/tcp
-volumes =
-    bind:ro:{toxinidir}/docker/stunnel/conf:/etc/stunnel/conf.d
-    bind:ro:{toxinidir}/docker/stunnel/keys:/etc/stunnel/keys
-
-[docker:redis5_master]
-name = redis5_master
-image = redisfab/redis-py:5.0-buster
-ports =
-    6382:6382/tcp
-healtcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',6382)) else False"
-volumes =
-    bind:rw:{toxinidir}/docker/redis5/master/redis.conf:/redis.conf
-
-[docker:redis5_replica]
-name = redis5_replica
-image = redisfab/redis-py:5.0-buster
-links =
-    redis5_master:redis5_master
-ports =
-    6383:6383/tcp
-healtcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',6383)) else False"
-volumes =
-    bind:rw:{toxinidir}/docker/redis5/replica/redis.conf:/redis.conf
"import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',26382)) else False" -volumes = - bind:rw:{toxinidir}/docker/redis5/sentinel/sentinel_1.conf:/sentinel.conf - -[docker:redis5_sentinel_2] -name = redis5_sentinel_2 -image = redisfab/redis-py-sentinel:5.0-buster -links = - redis5_master:redis5_master -ports = - 26383:26383/tcp -healtcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',26383)) else False" -volumes = - bind:rw:{toxinidir}/docker/redis5/sentinel/sentinel_2.conf:/sentinel.conf - -[docker:redis5_sentinel_3] -name = redis5_sentinel_3 -image = redisfab/redis-py-sentinel:5.0-buster -links = - redis5_master:redis5_master -ports = - 26384:26384/tcp -healtcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',26384)) else False" -volumes = - bind:rw:{toxinidir}/docker/redis5/sentinel/sentinel_3.conf:/sentinel.conf - -[docker:redis5_cluster] -name = redis5_cluster -image = redisfab/redis-py-cluster:5.0-buster -ports = - 16385:16385/tcp - 16386:16386/tcp - 16387:16387/tcp - 16388:16388/tcp - 16389:16389/tcp - 16390:16390/tcp -healtcheck_cmd = python -c "import socket;print(True) if all([0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',port)) for port in range(16385,16390)]) else False" -volumes = - bind:rw:{toxinidir}/docker/cluster/redis.conf:/redis.conf - -[docker:redis4_master] -name = redis4_master -image = redisfab/redis-py:4.0-buster -ports = - 6381:6381/tcp -healtcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',6381)) else False" -volumes = - bind:rw:{toxinidir}/docker/redis4/master/redis.conf:/redis.conf - -[docker:redis4_sentinel_1] -name = redis4_sentinel_1 -image = redisfab/redis-py-sentinel:4.0-buster -links = - redis4_master:redis4_master -ports = - 26385:26385/tcp -healtcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',26385)) else False" -volumes = - bind:rw:{toxinidir}/docker/redis4/sentinel/sentinel_1.conf:/sentinel.conf - -[docker:redis4_sentinel_2] -name = redis4_sentinel_2 -image = redisfab/redis-py-sentinel:4.0-buster -links = - redis4_master:redis4_master -ports = - 26386:26386/tcp -healtcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',26386)) else False" -volumes = - bind:rw:{toxinidir}/docker/redis4/sentinel/sentinel_2.conf:/sentinel.conf - -[docker:redis4_sentinel_3] -name = redis4_sentinel_3 -image = redisfab/redis-py-sentinel:4.0-buster -links = - redis4_master:redis4_master -ports = - 26387:26387/tcp -healtcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',26387)) else False" -volumes = - bind:rw:{toxinidir}/docker/redis4/sentinel/sentinel_3.conf:/sentinel.conf - -[docker:redis4_cluster] -name = redis4_cluster -image = redisfab/redis-py-cluster:4.0-buster -ports = - 16391:16391/tcp - 16392:16392/tcp - 16393:16393/tcp - 16394:16394/tcp - 16395:16395/tcp - 16396:16396/tcp -healtcheck_cmd = python -c "import socket;print(True) if all([0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',port)) for port in range(16391,16396)]) else False" -volumes = - 
-
-[isort]
-profile = black
-multi_line_output = 3
-
-[testenv]
-deps =
-    -r {toxinidir}/requirements.txt
-    -r {toxinidir}/dev_requirements.txt
-docker =
-    unstable
-    unstable_cluster
-    master
-    replica
-    sentinel_1
-    sentinel_2
-    sentinel_3
-    redis_cluster
-    redis_stack
-    stunnel
-extras =
-    hiredis: hiredis
-    ocsp: cryptography, pyopenssl, requests
-setenv =
-    CLUSTER_URL = "redis://localhost:16379/0"
-    UNSTABLE_CLUSTER_URL = "redis://localhost:6372/0"
-commands =
-    standalone: pytest --cov=./ --cov-report=xml:coverage_redis.xml -W always -m 'not onlycluster' --junit-xml=standalone-results.xml {posargs}
-    standalone-uvloop: pytest --cov=./ --cov-report=xml:coverage_redis.xml -W always -m 'not onlycluster' --junit-xml=standalone-uvloop-results.xml --uvloop {posargs}
-    cluster: pytest --cov=./ --cov-report=xml:coverage_cluster.xml -W always -m 'not onlynoncluster and not redismod' --redis-url={env:CLUSTER_URL:} --redis-unstable-url={env:UNSTABLE_CLUSTER_URL:} --junit-xml=cluster-results.xml {posargs}
-    cluster-uvloop: pytest --cov=./ --cov-report=xml:coverage_cluster.xml -W always -m 'not onlynoncluster and not redismod' --redis-url={env:CLUSTER_URL:} --redis-unstable-url={env:UNSTABLE_CLUSTER_URL:} --junit-xml=cluster-uvloop-results.xml --uvloop {posargs}
-
-[testenv:redis5]
-deps =
-    -r {toxinidir}/requirements.txt
-    -r {toxinidir}/dev_requirements.txt
-docker =
-    redis5_master
-    redis5_replica
-    redis5_sentinel_1
-    redis5_sentinel_2
-    redis5_sentinel_3
-    redis5_cluster
-extras =
-    hiredis: hiredis
-    cryptography: cryptography, requests
-setenv =
-    CLUSTER_URL = "redis://localhost:16385/0"
-commands =
-    standalone: pytest --cov=./ --cov-report=xml:coverage_redis.xml -W always -m 'not onlycluster and not redismod' {posargs}
-    cluster: pytest --cov=./ --cov-report=xml:coverage_cluster.xml -W always -m 'not onlynoncluster and not redismod' --redis-url={env:CLUSTER_URL:} {posargs}
-
-[testenv:redis4]
-deps =
-    -r {toxinidir}/requirements.txt
-    -r {toxinidir}/dev_requirements.txt
-docker =
-    redis4_master
-    redis4_sentinel_1
-    redis4_sentinel_2
-    redis4_sentinel_3
-    redis4_cluster
-extras =
-    hiredis: hiredis
-    cryptography: cryptography, requests
-setenv =
-    CLUSTER_URL = "redis://localhost:16391/0"
-commands =
-    standalone: pytest --cov=./ --cov-report=xml:coverage_redis.xml -W always -m 'not onlycluster and not redismod' {posargs}
-    cluster: pytest --cov=./ --cov-report=xml:coverage_cluster.xml -W always -m 'not onlynoncluster and not redismod' --redis-url={env:CLUSTER_URL:} {posargs}
-
-[testenv:devenv]
-skipsdist = true
-skip_install = true
-deps = -r {toxinidir}/dev_requirements.txt
-docker = {[testenv]docker}
-
-[testenv:linters]
-deps_files = dev_requirements.txt
-docker =
-commands =
-    flake8
-    black --target-version py37 --check --diff .
-    isort --check-only --diff .
-    vulture redis whitelist.py --min-confidence 80
-    flynt --fail-on-change --dry-run .
-skipsdist = true
-skip_install = true
-
-[testenv:docs]
-deps = -r docs/requirements.txt
-docker =
-changedir = {toxinidir}/docs
-allowlist_externals = make
-commands = make html
-
-[flake8]
-max-line-length = 88
-exclude =
-    *.egg-info,
-    *.pyc,
-    .git,
-    .tox,
-    .venv*,
-    build,
-    docs/*,
-    dist,
-    docker,
-    venv*,
-    .venv*,
-    whitelist.py
-ignore =
-    F405
-    W503
-    E203
-    E126
diff --git a/whitelist.py b/whitelist.py
index 8c9cee3c29..29cd529e4d 100644
--- a/whitelist.py
+++ b/whitelist.py
@@ -14,6 +14,5 @@
 exc_value  # unused variable (/data/repos/redis/redis-py/redis/asyncio/utils.py:26)
 traceback  # unused variable (/data/repos/redis/redis-py/redis/asyncio/utils.py:26)
 AsyncConnectionPool  # unused import (//data/repos/redis/redis-py/redis/typing.py:9)
-AsyncEncoder  # unused import (//data/repos/redis/redis-py/redis/typing.py:10)
 AsyncRedis  # unused import (//data/repos/redis/redis-py/redis/commands/core.py:49)
 TargetNodesT  # unused import (//data/repos/redis/redis-py/redis/commands/cluster.py:46)