diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index 8166588d8e..04fbf62cf7 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -1,14 +1,9 @@ -import asyncio +import functools import random import sys from typing import Union from urllib.parse import urlparse -if sys.version_info[0:2] == (3, 6): - import pytest as pytest_asyncio -else: - import pytest_asyncio - import pytest from packaging.version import Version @@ -26,6 +21,13 @@ from .compat import mock +if sys.version_info[0:2] == (3, 6): + import pytest as pytest_asyncio + + pytestmark = pytest.mark.asyncio +else: + import pytest_asyncio + async def _get_info(redis_url): client = redis.Redis.from_url(redis_url) @@ -69,11 +71,13 @@ async def _get_info(redis_url): "pool-hiredis", ], ) -def create_redis(request, event_loop: asyncio.BaseEventLoop): +async def create_redis(request): """Wrapper around redis.create_redis.""" single_connection, parser_cls = request.param - async def f( + teardown_clients = [] + + async def client_factory( url: str = request.config.getoption("--redis-url"), cls=redis.Redis, flushdb=True, @@ -95,56 +99,50 @@ async def f( client = client.client() await client.initialize() - def teardown(): - async def ateardown(): - if not cluster_mode: - if "username" in kwargs: - return - if flushdb: - try: - await client.flushdb() - except redis.ConnectionError: - # handle cases where a test disconnected a client - # just manually retry the flushdb - await client.flushdb() - await client.close() - await client.connection_pool.disconnect() - else: - if flushdb: - try: - await client.flushdb(target_nodes="primaries") - except redis.ConnectionError: - # handle cases where a test disconnected a client - # just manually retry the flushdb - await client.flushdb(target_nodes="primaries") - await client.close() - - if event_loop.is_running(): - event_loop.create_task(ateardown()) + async def teardown(): + if not cluster_mode: + if flushdb and "username" not in kwargs: + try: + await client.flushdb() + except redis.ConnectionError: + # handle cases where a test disconnected a client + # just manually retry the flushdb + await client.flushdb() + await client.close() + await client.connection_pool.disconnect() else: - event_loop.run_until_complete(ateardown()) - - request.addfinalizer(teardown) - + if flushdb: + try: + await client.flushdb(target_nodes="primaries") + except redis.ConnectionError: + # handle cases where a test disconnected a client + # just manually retry the flushdb + await client.flushdb(target_nodes="primaries") + await client.close() + + teardown_clients.append(teardown) return client - return f + yield client_factory + + for teardown in teardown_clients: + await teardown() @pytest_asyncio.fixture() -async def r(request, create_redis): - yield await create_redis() +async def r(create_redis): + return await create_redis() @pytest_asyncio.fixture() async def r2(create_redis): """A second client for tests that need multiple""" - yield await create_redis() + return await create_redis() @pytest_asyncio.fixture() async def modclient(request, create_redis): - yield await create_redis( + return await create_redis( url=request.config.getoption("--redismod-url"), decode_responses=True ) @@ -222,7 +220,7 @@ async def mock_cluster_resp_slaves(create_redis, **kwargs): def master_host(request): url = request.config.getoption("--redis-url") parts = urlparse(url) - yield parts.hostname + return parts.hostname async def wait_for_command( @@ -246,3 +244,41 @@ async def wait_for_command( return monitor_response if key in monitor_response["command"]: return None + + +# python 3.6 doesn't have the asynccontextmanager decorator. Provide it here. +class AsyncContextManager: + def __init__(self, async_generator): + self.gen = async_generator + + async def __aenter__(self): + try: + return await self.gen.__anext__() + except StopAsyncIteration as err: + raise RuntimeError("Pickles") from err + + async def __aexit__(self, exc_type, exc_inst, tb): + if exc_type: + await self.gen.athrow(exc_type, exc_inst, tb) + return True + try: + await self.gen.__anext__() + except StopAsyncIteration: + return + raise RuntimeError("More pickles") + + +if sys.version_info[0:2] == (3, 6): + + def asynccontextmanager(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + return AsyncContextManager(func(*args, **kwargs)) + + return wrapper + +else: + from contextlib import asynccontextmanager as _asynccontextmanager + + def asynccontextmanager(func): + return _asynccontextmanager(func) diff --git a/tests/test_asyncio/test_bloom.py b/tests/test_asyncio/test_bloom.py index feb98cc41e..2bf4e030e6 100644 --- a/tests/test_asyncio/test_bloom.py +++ b/tests/test_asyncio/test_bloom.py @@ -1,10 +1,13 @@ +import sys + import pytest import redis.asyncio as redis from redis.exceptions import ModuleError, RedisError from redis.utils import HIREDIS_AVAILABLE -pytestmark = pytest.mark.asyncio +if sys.version_info[0:2] == (3, 6): + pytestmark = pytest.mark.asyncio def intlist(obj): @@ -91,7 +94,7 @@ async def do_verify(): res += rv == x assert res < 5 - do_verify() + await do_verify() cmds = [] if HIREDIS_AVAILABLE: with pytest.raises(ModuleError): @@ -120,7 +123,7 @@ async def do_verify(): cur_info = await modclient.bf().execute_command("bf.debug", "myBloom") assert prev_info == cur_info - do_verify() + await do_verify() await modclient.bf().client.delete("myBloom") await modclient.bf().create("myBloom", "0.0001", "10000000") diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 0d0ea33db2..8766cbf09b 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -11,6 +11,8 @@ if sys.version_info[0:2] == (3, 6): import pytest as pytest_asyncio + + pytestmark = pytest.mark.asyncio else: import pytest_asyncio @@ -39,8 +41,6 @@ skip_unless_arch_bits, ) -pytestmark = pytest.mark.asyncio - default_host = "127.0.0.1" default_port = 7000 default_cluster_slots = [ diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py index e128ac40b8..913f05b3fe 100644 --- a/tests/test_asyncio/test_commands.py +++ b/tests/test_asyncio/test_commands.py @@ -12,6 +12,8 @@ if sys.version_info[0:2] == (3, 6): import pytest as pytest_asyncio + + pytestmark = pytest.mark.asyncio else: import pytest_asyncio @@ -27,11 +29,24 @@ REDIS_6_VERSION = "5.9.0" -pytestmark = pytest.mark.asyncio +@pytest_asyncio.fixture() +async def r_teardown(r: redis.Redis): + """ + A special fixture which removes the provided names from the database after use + """ + usernames = [] + + def factory(username): + usernames.append(username) + return r + + yield factory + for username in usernames: + await r.acl_deluser(username) @pytest_asyncio.fixture() -async def slowlog(r: redis.Redis, event_loop): +async def slowlog(r: redis.Redis): current_config = await r.config_get() old_slower_than_value = current_config["slowlog-log-slower-than"] old_max_legnth_value = current_config["slowlog-max-len"] @@ -94,17 +109,9 @@ async def test_acl_cat_with_category(self, r: redis.Redis): assert "get" in commands @skip_if_server_version_lt(REDIS_6_VERSION) - async def test_acl_deluser(self, r: redis.Redis, request, event_loop): + async def test_acl_deluser(self, r_teardown): username = "redis-py-user" - - def teardown(): - coro = r.acl_deluser(username) - if event_loop.is_running(): - event_loop.create_task(coro) - else: - event_loop.run_until_complete(coro) - - request.addfinalizer(teardown) + r = r_teardown(username) assert await r.acl_deluser(username) == 0 assert await r.acl_setuser(username, enabled=False, reset=True) @@ -117,18 +124,9 @@ async def test_acl_genpass(self, r: redis.Redis): @skip_if_server_version_lt(REDIS_6_VERSION) @skip_if_server_version_gte("7.0.0") - async def test_acl_getuser_setuser(self, r: redis.Redis, request, event_loop): + async def test_acl_getuser_setuser(self, r_teardown): username = "redis-py-user" - - def teardown(): - coro = r.acl_deluser(username) - if event_loop.is_running(): - event_loop.create_task(coro) - else: - event_loop.run_until_complete(coro) - - request.addfinalizer(teardown) - + r = r_teardown(username) # test enabled=False assert await r.acl_setuser(username, enabled=False, reset=True) assert await r.acl_getuser(username) == { @@ -233,17 +231,9 @@ def teardown(): @skip_if_server_version_lt(REDIS_6_VERSION) @skip_if_server_version_gte("7.0.0") - async def test_acl_list(self, r: redis.Redis, request, event_loop): + async def test_acl_list(self, r_teardown): username = "redis-py-user" - - def teardown(): - coro = r.acl_deluser(username) - if event_loop.is_running(): - event_loop.create_task(coro) - else: - event_loop.run_until_complete(coro) - - request.addfinalizer(teardown) + r = r_teardown(username) assert await r.acl_setuser(username, enabled=False, reset=True) users = await r.acl_list() @@ -251,17 +241,9 @@ def teardown(): @skip_if_server_version_lt(REDIS_6_VERSION) @pytest.mark.onlynoncluster - async def test_acl_log(self, r: redis.Redis, request, event_loop, create_redis): + async def test_acl_log(self, r_teardown, create_redis): username = "redis-py-user" - - def teardown(): - coro = r.acl_deluser(username) - if event_loop.is_running(): - event_loop.create_task(coro) - else: - event_loop.run_until_complete(coro) - - request.addfinalizer(teardown) + r = r_teardown(username) await r.acl_setuser( username, enabled=True, @@ -294,55 +276,25 @@ def teardown(): assert await r.acl_log_reset() @skip_if_server_version_lt(REDIS_6_VERSION) - async def test_acl_setuser_categories_without_prefix_fails( - self, r: redis.Redis, request, event_loop - ): + async def test_acl_setuser_categories_without_prefix_fails(self, r_teardown): username = "redis-py-user" - - def teardown(): - coro = r.acl_deluser(username) - if event_loop.is_running(): - event_loop.create_task(coro) - else: - event_loop.run_until_complete(coro) - - request.addfinalizer(teardown) + r = r_teardown(username) with pytest.raises(exceptions.DataError): await r.acl_setuser(username, categories=["list"]) @skip_if_server_version_lt(REDIS_6_VERSION) - async def test_acl_setuser_commands_without_prefix_fails( - self, r: redis.Redis, request, event_loop - ): + async def test_acl_setuser_commands_without_prefix_fails(self, r_teardown): username = "redis-py-user" - - def teardown(): - coro = r.acl_deluser(username) - if event_loop.is_running(): - event_loop.create_task(coro) - else: - event_loop.run_until_complete(coro) - - request.addfinalizer(teardown) + r = r_teardown(username) with pytest.raises(exceptions.DataError): await r.acl_setuser(username, commands=["get"]) @skip_if_server_version_lt(REDIS_6_VERSION) - async def test_acl_setuser_add_passwords_and_nopass_fails( - self, r: redis.Redis, request, event_loop - ): + async def test_acl_setuser_add_passwords_and_nopass_fails(self, r_teardown): username = "redis-py-user" - - def teardown(): - coro = r.acl_deluser(username) - if event_loop.is_running(): - event_loop.create_task(coro) - else: - event_loop.run_until_complete(coro) - - request.addfinalizer(teardown) + r = r_teardown(username) with pytest.raises(exceptions.DataError): await r.acl_setuser(username, passwords="+mypass", nopass=True) diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index 78a3efd2a0..8030f7e628 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -1,5 +1,6 @@ import asyncio import socket +import sys import types from unittest.mock import patch @@ -18,7 +19,8 @@ from .compat import mock -pytestmark = pytest.mark.asyncio +if sys.version_info[0:2] == (3, 6): + pytestmark = pytest.mark.asyncio @pytest.mark.onlynoncluster diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py index 6c56558d59..c8eb918e28 100644 --- a/tests/test_asyncio/test_connection_pool.py +++ b/tests/test_asyncio/test_connection_pool.py @@ -7,6 +7,8 @@ if sys.version_info[0:2] == (3, 6): import pytest as pytest_asyncio + + pytestmark = pytest.mark.asyncio else: import pytest_asyncio @@ -15,10 +17,9 @@ from tests.conftest import skip_if_redis_enterprise, skip_if_server_version_lt from .compat import mock +from .conftest import asynccontextmanager from .test_pubsub import wait_for_message -pytestmark = pytest.mark.asyncio - @pytest.mark.onlynoncluster class TestRedisAutoReleaseConnectionPool: @@ -114,7 +115,8 @@ async def can_read(self, timeout: float = 0): class TestConnectionPool: - def get_pool( + @asynccontextmanager + async def get_pool( self, connection_kwargs=None, max_connections=None, @@ -126,71 +128,77 @@ def get_pool( max_connections=max_connections, **connection_kwargs, ) - return pool + try: + yield pool + finally: + await pool.disconnect(inuse_connections=True) async def test_connection_creation(self): connection_kwargs = {"foo": "bar", "biz": "baz"} - pool = self.get_pool( + async with self.get_pool( connection_kwargs=connection_kwargs, connection_class=DummyConnection - ) - connection = await pool.get_connection("_") - assert isinstance(connection, DummyConnection) - assert connection.kwargs == connection_kwargs + ) as pool: + connection = await pool.get_connection("_") + assert isinstance(connection, DummyConnection) + assert connection.kwargs == connection_kwargs async def test_multiple_connections(self, master_host): connection_kwargs = {"host": master_host} - pool = self.get_pool(connection_kwargs=connection_kwargs) - c1 = await pool.get_connection("_") - c2 = await pool.get_connection("_") - assert c1 != c2 + async with self.get_pool(connection_kwargs=connection_kwargs) as pool: + c1 = await pool.get_connection("_") + c2 = await pool.get_connection("_") + assert c1 != c2 async def test_max_connections(self, master_host): connection_kwargs = {"host": master_host} - pool = self.get_pool(max_connections=2, connection_kwargs=connection_kwargs) - await pool.get_connection("_") - await pool.get_connection("_") - with pytest.raises(redis.ConnectionError): + async with self.get_pool( + max_connections=2, connection_kwargs=connection_kwargs + ) as pool: + await pool.get_connection("_") await pool.get_connection("_") + with pytest.raises(redis.ConnectionError): + await pool.get_connection("_") async def test_reuse_previously_released_connection(self, master_host): connection_kwargs = {"host": master_host} - pool = self.get_pool(connection_kwargs=connection_kwargs) - c1 = await pool.get_connection("_") - await pool.release(c1) - c2 = await pool.get_connection("_") - assert c1 == c2 + async with self.get_pool(connection_kwargs=connection_kwargs) as pool: + c1 = await pool.get_connection("_") + await pool.release(c1) + c2 = await pool.get_connection("_") + assert c1 == c2 - def test_repr_contains_db_info_tcp(self): + async def test_repr_contains_db_info_tcp(self): connection_kwargs = { "host": "localhost", "port": 6379, "db": 1, "client_name": "test-client", } - pool = self.get_pool( + async with self.get_pool( connection_kwargs=connection_kwargs, connection_class=redis.Connection - ) - expected = ( - "ConnectionPool>" - ) - assert repr(pool) == expected + ) as pool: + expected = ( + "ConnectionPool>" + ) + assert repr(pool) == expected - def test_repr_contains_db_info_unix(self): + async def test_repr_contains_db_info_unix(self): connection_kwargs = {"path": "/abc", "db": 1, "client_name": "test-client"} - pool = self.get_pool( + async with self.get_pool( connection_kwargs=connection_kwargs, connection_class=redis.UnixDomainSocketConnection, - ) - expected = ( - "ConnectionPool>" - ) - assert repr(pool) == expected + ) as pool: + expected = ( + "ConnectionPool>" + ) + assert repr(pool) == expected class TestBlockingConnectionPool: - def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20): + @asynccontextmanager + async def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20): connection_kwargs = connection_kwargs or {} pool = redis.BlockingConnectionPool( connection_class=DummyConnection, @@ -198,7 +206,10 @@ def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20): timeout=timeout, **connection_kwargs, ) - return pool + try: + yield pool + finally: + await pool.disconnect(inuse_connections=True) async def test_connection_creation(self, master_host): connection_kwargs = { @@ -207,10 +218,10 @@ async def test_connection_creation(self, master_host): "host": master_host[0], "port": master_host[1], } - pool = self.get_pool(connection_kwargs=connection_kwargs) - connection = await pool.get_connection("_") - assert isinstance(connection, DummyConnection) - assert connection.kwargs == connection_kwargs + async with self.get_pool(connection_kwargs=connection_kwargs) as pool: + connection = await pool.get_connection("_") + assert isinstance(connection, DummyConnection) + assert connection.kwargs == connection_kwargs async def test_disconnect(self, master_host): """A regression test for #1047""" @@ -220,30 +231,31 @@ async def test_disconnect(self, master_host): "host": master_host[0], "port": master_host[1], } - pool = self.get_pool(connection_kwargs=connection_kwargs) - await pool.get_connection("_") - await pool.disconnect() + async with self.get_pool(connection_kwargs=connection_kwargs) as pool: + await pool.get_connection("_") + await pool.disconnect() async def test_multiple_connections(self, master_host): connection_kwargs = {"host": master_host[0], "port": master_host[1]} - pool = self.get_pool(connection_kwargs=connection_kwargs) - c1 = await pool.get_connection("_") - c2 = await pool.get_connection("_") - assert c1 != c2 + async with self.get_pool(connection_kwargs=connection_kwargs) as pool: + c1 = await pool.get_connection("_") + c2 = await pool.get_connection("_") + assert c1 != c2 async def test_connection_pool_blocks_until_timeout(self, master_host): """When out of connections, block for timeout seconds, then raise""" connection_kwargs = {"host": master_host} - pool = self.get_pool( + async with self.get_pool( max_connections=1, timeout=0.1, connection_kwargs=connection_kwargs - ) - await pool.get_connection("_") + ) as pool: + c1 = await pool.get_connection("_") - start = asyncio.get_event_loop().time() - with pytest.raises(redis.ConnectionError): - await pool.get_connection("_") - # we should have waited at least 0.1 seconds - assert asyncio.get_event_loop().time() - start >= 0.1 + start = asyncio.get_event_loop().time() + with pytest.raises(redis.ConnectionError): + await pool.get_connection("_") + # we should have waited at least 0.1 seconds + assert asyncio.get_event_loop().time() - start >= 0.1 + await c1.disconnect() async def test_connection_pool_blocks_until_conn_available(self, master_host): """ @@ -251,26 +263,26 @@ async def test_connection_pool_blocks_until_conn_available(self, master_host): to the pool """ connection_kwargs = {"host": master_host[0], "port": master_host[1]} - pool = self.get_pool( + async with self.get_pool( max_connections=1, timeout=2, connection_kwargs=connection_kwargs - ) - c1 = await pool.get_connection("_") + ) as pool: + c1 = await pool.get_connection("_") - async def target(): - await asyncio.sleep(0.1) - await pool.release(c1) + async def target(): + await asyncio.sleep(0.1) + await pool.release(c1) - start = asyncio.get_event_loop().time() - await asyncio.gather(target(), pool.get_connection("_")) - assert asyncio.get_event_loop().time() - start >= 0.1 + start = asyncio.get_event_loop().time() + await asyncio.gather(target(), pool.get_connection("_")) + assert asyncio.get_event_loop().time() - start >= 0.1 async def test_reuse_previously_released_connection(self, master_host): connection_kwargs = {"host": master_host} - pool = self.get_pool(connection_kwargs=connection_kwargs) - c1 = await pool.get_connection("_") - await pool.release(c1) - c2 = await pool.get_connection("_") - assert c1 == c2 + async with self.get_pool(connection_kwargs=connection_kwargs) as pool: + c1 = await pool.get_connection("_") + await pool.release(c1) + c2 = await pool.get_connection("_") + assert c1 == c2 def test_repr_contains_db_info_tcp(self): pool = redis.ConnectionPool( @@ -689,6 +701,8 @@ async def test_arbitrary_command_advances_next_health_check(self, r): if r.connection: await r.get("foo") next_health_check = r.connection.next_health_check + # ensure that the event loop's `time()` advances a bit + await asyncio.sleep(0.001) await r.get("foo") assert next_health_check < r.connection.next_health_check diff --git a/tests/test_asyncio/test_encoding.py b/tests/test_asyncio/test_encoding.py index 133ea3783c..5db7187c84 100644 --- a/tests/test_asyncio/test_encoding.py +++ b/tests/test_asyncio/test_encoding.py @@ -4,14 +4,14 @@ if sys.version_info[0:2] == (3, 6): import pytest as pytest_asyncio + + pytestmark = pytest.mark.asyncio else: import pytest_asyncio import redis.asyncio as redis from redis.exceptions import DataError -pytestmark = pytest.mark.asyncio - @pytest.mark.onlynoncluster class TestEncoding: diff --git a/tests/test_asyncio/test_json.py b/tests/test_asyncio/test_json.py index a045dd7c1a..416a9f4a21 100644 --- a/tests/test_asyncio/test_json.py +++ b/tests/test_asyncio/test_json.py @@ -5,8 +5,6 @@ from redis.commands.json.path import Path from tests.conftest import skip_ifmodversion_lt -pytestmark = pytest.mark.asyncio - @pytest.mark.redismod async def test_json_setbinarykey(modclient: redis.Redis): diff --git a/tests/test_asyncio/test_lock.py b/tests/test_asyncio/test_lock.py index 8ceb3bc958..86a8d62f71 100644 --- a/tests/test_asyncio/test_lock.py +++ b/tests/test_asyncio/test_lock.py @@ -5,14 +5,14 @@ if sys.version_info[0:2] == (3, 6): import pytest as pytest_asyncio + + pytestmark = pytest.mark.asyncio else: import pytest_asyncio from redis.asyncio.lock import Lock from redis.exceptions import LockError, LockNotOwnedError -pytestmark = pytest.mark.asyncio - @pytest.mark.onlynoncluster class TestLock: diff --git a/tests/test_asyncio/test_monitor.py b/tests/test_asyncio/test_monitor.py index 783ba262b0..9185bcd2ee 100644 --- a/tests/test_asyncio/test_monitor.py +++ b/tests/test_asyncio/test_monitor.py @@ -1,10 +1,13 @@ +import sys + import pytest from tests.conftest import skip_if_redis_enterprise, skip_ifnot_redis_enterprise from .conftest import wait_for_command -pytestmark = pytest.mark.asyncio +if sys.version_info[0:2] == (3, 6): + pytestmark = pytest.mark.asyncio @pytest.mark.onlynoncluster diff --git a/tests/test_asyncio/test_pipeline.py b/tests/test_asyncio/test_pipeline.py index dfeb66464c..33391d019d 100644 --- a/tests/test_asyncio/test_pipeline.py +++ b/tests/test_asyncio/test_pipeline.py @@ -1,3 +1,5 @@ +import sys + import pytest import redis @@ -5,7 +7,8 @@ from .conftest import wait_for_command -pytestmark = pytest.mark.asyncio +if sys.version_info[0:2] == (3, 6): + pytestmark = pytest.mark.asyncio class TestPipeline: diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index 6c76bf334e..d6a817a61b 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -8,6 +8,8 @@ if sys.version_info[0:2] == (3, 6): import pytest as pytest_asyncio + + pytestmark = pytest.mark.asyncio(forbid_global_loop=True) else: import pytest_asyncio @@ -18,8 +20,6 @@ from .compat import mock -pytestmark = pytest.mark.asyncio(forbid_global_loop=True) - def with_timeout(t): def wrapper(corofunc): @@ -80,6 +80,13 @@ def make_subscribe_test_data(pubsub, type): assert False, f"invalid subscribe type: {type}" +@pytest_asyncio.fixture() +async def pubsub(r: redis.Redis): + p = r.pubsub() + yield p + await p.close() + + @pytest.mark.onlynoncluster class TestPubSubSubscribeUnsubscribe: async def _test_subscribe_unsubscribe( @@ -101,12 +108,12 @@ async def _test_subscribe_unsubscribe( i = len(keys) - 1 - i assert await wait_for_message(p) == make_message(unsub_type, key, i) - async def test_channel_subscribe_unsubscribe(self, r: redis.Redis): - kwargs = make_subscribe_test_data(r.pubsub(), "channel") + async def test_channel_subscribe_unsubscribe(self, pubsub): + kwargs = make_subscribe_test_data(pubsub, "channel") await self._test_subscribe_unsubscribe(**kwargs) - async def test_pattern_subscribe_unsubscribe(self, r: redis.Redis): - kwargs = make_subscribe_test_data(r.pubsub(), "pattern") + async def test_pattern_subscribe_unsubscribe(self, pubsub): + kwargs = make_subscribe_test_data(pubsub, "pattern") await self._test_subscribe_unsubscribe(**kwargs) @pytest.mark.onlynoncluster @@ -144,12 +151,12 @@ async def _test_resubscribe_on_reconnection( for channel in unique_channels: assert channel in keys - async def test_resubscribe_to_channels_on_reconnection(self, r: redis.Redis): - kwargs = make_subscribe_test_data(r.pubsub(), "channel") + async def test_resubscribe_to_channels_on_reconnection(self, pubsub): + kwargs = make_subscribe_test_data(pubsub, "channel") await self._test_resubscribe_on_reconnection(**kwargs) - async def test_resubscribe_to_patterns_on_reconnection(self, r: redis.Redis): - kwargs = make_subscribe_test_data(r.pubsub(), "pattern") + async def test_resubscribe_to_patterns_on_reconnection(self, pubsub): + kwargs = make_subscribe_test_data(pubsub, "pattern") await self._test_resubscribe_on_reconnection(**kwargs) async def _test_subscribed_property( @@ -199,13 +206,13 @@ async def _test_subscribed_property( # now we're finally unsubscribed assert p.subscribed is False - async def test_subscribe_property_with_channels(self, r: redis.Redis): - kwargs = make_subscribe_test_data(r.pubsub(), "channel") + async def test_subscribe_property_with_channels(self, pubsub): + kwargs = make_subscribe_test_data(pubsub, "channel") await self._test_subscribed_property(**kwargs) @pytest.mark.onlynoncluster - async def test_subscribe_property_with_patterns(self, r: redis.Redis): - kwargs = make_subscribe_test_data(r.pubsub(), "pattern") + async def test_subscribe_property_with_patterns(self, pubsub): + kwargs = make_subscribe_test_data(pubsub, "pattern") await self._test_subscribed_property(**kwargs) async def test_ignore_all_subscribe_messages(self, r: redis.Redis): @@ -224,9 +231,10 @@ async def test_ignore_all_subscribe_messages(self, r: redis.Redis): assert p.subscribed is True assert await wait_for_message(p) is None assert p.subscribed is False + await p.close() - async def test_ignore_individual_subscribe_messages(self, r: redis.Redis): - p = r.pubsub() + async def test_ignore_individual_subscribe_messages(self, pubsub): + p = pubsub checks = ( (p.subscribe, "foo"), @@ -243,13 +251,13 @@ async def test_ignore_individual_subscribe_messages(self, r: redis.Redis): assert message is None assert p.subscribed is False - async def test_sub_unsub_resub_channels(self, r: redis.Redis): - kwargs = make_subscribe_test_data(r.pubsub(), "channel") + async def test_sub_unsub_resub_channels(self, pubsub): + kwargs = make_subscribe_test_data(pubsub, "channel") await self._test_sub_unsub_resub(**kwargs) @pytest.mark.onlynoncluster - async def test_sub_unsub_resub_patterns(self, r: redis.Redis): - kwargs = make_subscribe_test_data(r.pubsub(), "pattern") + async def test_sub_unsub_resub_patterns(self, pubsub): + kwargs = make_subscribe_test_data(pubsub, "pattern") await self._test_sub_unsub_resub(**kwargs) async def _test_sub_unsub_resub( @@ -266,12 +274,12 @@ async def _test_sub_unsub_resub( assert await wait_for_message(p) == make_message(sub_type, key, 1) assert p.subscribed is True - async def test_sub_unsub_all_resub_channels(self, r: redis.Redis): - kwargs = make_subscribe_test_data(r.pubsub(), "channel") + async def test_sub_unsub_all_resub_channels(self, pubsub): + kwargs = make_subscribe_test_data(pubsub, "channel") await self._test_sub_unsub_all_resub(**kwargs) - async def test_sub_unsub_all_resub_patterns(self, r: redis.Redis): - kwargs = make_subscribe_test_data(r.pubsub(), "pattern") + async def test_sub_unsub_all_resub_patterns(self, pubsub): + kwargs = make_subscribe_test_data(pubsub, "pattern") await self._test_sub_unsub_all_resub(**kwargs) async def _test_sub_unsub_all_resub( @@ -300,8 +308,8 @@ def message_handler(self, message): async def async_message_handler(self, message): self.async_message = message - async def test_published_message_to_channel(self, r: redis.Redis): - p = r.pubsub() + async def test_published_message_to_channel(self, r: redis.Redis, pubsub): + p = pubsub await p.subscribe("foo") assert await wait_for_message(p) == make_message("subscribe", "foo", 1) assert await r.publish("foo", "test message") == 1 @@ -310,8 +318,8 @@ async def test_published_message_to_channel(self, r: redis.Redis): assert isinstance(message, dict) assert message == make_message("message", "foo", "test message") - async def test_published_message_to_pattern(self, r: redis.Redis): - p = r.pubsub() + async def test_published_message_to_pattern(self, r: redis.Redis, pubsub): + p = pubsub await p.subscribe("foo") await p.psubscribe("f*") assert await wait_for_message(p) == make_message("subscribe", "foo", 1) @@ -340,6 +348,7 @@ async def test_channel_message_handler(self, r: redis.Redis): assert await r.publish("foo", "test message") == 1 assert await wait_for_message(p) is None assert self.message == make_message("message", "foo", "test message") + await p.close() async def test_channel_async_message_handler(self, r): p = r.pubsub(ignore_subscribe_messages=True) @@ -348,6 +357,7 @@ async def test_channel_async_message_handler(self, r): assert await r.publish("foo", "test message") == 1 assert await wait_for_message(p) is None assert self.async_message == make_message("message", "foo", "test message") + await p.close() async def test_channel_sync_async_message_handler(self, r): p = r.pubsub(ignore_subscribe_messages=True) @@ -359,6 +369,7 @@ async def test_channel_sync_async_message_handler(self, r): assert await wait_for_message(p) is None assert self.message == make_message("message", "foo", "test message") assert self.async_message == make_message("message", "bar", "test message 2") + await p.close() @pytest.mark.onlynoncluster async def test_pattern_message_handler(self, r: redis.Redis): @@ -370,6 +381,7 @@ async def test_pattern_message_handler(self, r: redis.Redis): assert self.message == make_message( "pmessage", "foo", "test message", pattern="f*" ) + await p.close() async def test_unicode_channel_message_handler(self, r: redis.Redis): p = r.pubsub(ignore_subscribe_messages=True) @@ -380,6 +392,7 @@ async def test_unicode_channel_message_handler(self, r: redis.Redis): assert await r.publish(channel, "test message") == 1 assert await wait_for_message(p) is None assert self.message == make_message("message", channel, "test message") + await p.close() @pytest.mark.onlynoncluster # see: https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html @@ -395,9 +408,10 @@ async def test_unicode_pattern_message_handler(self, r: redis.Redis): assert self.message == make_message( "pmessage", channel, "test message", pattern=pattern ) + await p.close() - async def test_get_message_without_subscribe(self, r: redis.Redis): - p = r.pubsub() + async def test_get_message_without_subscribe(self, r: redis.Redis, pubsub): + p = pubsub with pytest.raises(RuntimeError) as info: await p.get_message() expect = ( @@ -427,8 +441,8 @@ def message_handler(self, message): async def r(self, create_redis): return await create_redis(decode_responses=True) - async def test_channel_subscribe_unsubscribe(self, r: redis.Redis): - p = r.pubsub() + async def test_channel_subscribe_unsubscribe(self, pubsub): + p = pubsub await p.subscribe(self.channel) assert await wait_for_message(p) == self.make_message( "subscribe", self.channel, 1 @@ -439,8 +453,8 @@ async def test_channel_subscribe_unsubscribe(self, r: redis.Redis): "unsubscribe", self.channel, 0 ) - async def test_pattern_subscribe_unsubscribe(self, r: redis.Redis): - p = r.pubsub() + async def test_pattern_subscribe_unsubscribe(self, pubsub): + p = pubsub await p.psubscribe(self.pattern) assert await wait_for_message(p) == self.make_message( "psubscribe", self.pattern, 1 @@ -451,8 +465,8 @@ async def test_pattern_subscribe_unsubscribe(self, r: redis.Redis): "punsubscribe", self.pattern, 0 ) - async def test_channel_publish(self, r: redis.Redis): - p = r.pubsub() + async def test_channel_publish(self, r: redis.Redis, pubsub): + p = pubsub await p.subscribe(self.channel) assert await wait_for_message(p) == self.make_message( "subscribe", self.channel, 1 @@ -463,8 +477,8 @@ async def test_channel_publish(self, r: redis.Redis): ) @pytest.mark.onlynoncluster - async def test_pattern_publish(self, r: redis.Redis): - p = r.pubsub() + async def test_pattern_publish(self, r: redis.Redis, pubsub): + p = pubsub await p.psubscribe(self.pattern) assert await wait_for_message(p) == self.make_message( "psubscribe", self.pattern, 1 @@ -490,6 +504,7 @@ async def test_channel_message_handler(self, r: redis.Redis): await r.publish(self.channel, new_data) assert await wait_for_message(p) is None assert self.message == self.make_message("message", self.channel, new_data) + await p.close() async def test_pattern_message_handler(self, r: redis.Redis): p = r.pubsub(ignore_subscribe_messages=True) @@ -511,6 +526,7 @@ async def test_pattern_message_handler(self, r: redis.Redis): assert self.message == self.make_message( "pmessage", self.channel, new_data, pattern=self.pattern ) + await p.close() async def test_context_manager(self, r: redis.Redis): async with r.pubsub() as pubsub: @@ -520,6 +536,7 @@ async def test_context_manager(self, r: redis.Redis): assert pubsub.connection is None assert pubsub.channels == {} assert pubsub.patterns == {} + await pubsub.close() @pytest.mark.onlynoncluster @@ -535,8 +552,8 @@ async def test_channel_subscribe(self, r: redis.Redis): class TestPubSubSubcommands: @pytest.mark.onlynoncluster @skip_if_server_version_lt("2.8.0") - async def test_pubsub_channels(self, r: redis.Redis): - p = r.pubsub() + async def test_pubsub_channels(self, r: redis.Redis, pubsub): + p = pubsub await p.subscribe("foo", "bar", "baz", "quux") for i in range(4): assert (await wait_for_message(p))["type"] == "subscribe" @@ -560,6 +577,9 @@ async def test_pubsub_numsub(self, r: redis.Redis): channels = [(b"foo", 1), (b"bar", 2), (b"baz", 3)] assert await r.pubsub_numsub("foo", "bar", "baz") == channels + await p1.close() + await p2.close() + await p3.close() @skip_if_server_version_lt("2.8.0") async def test_pubsub_numpat(self, r: redis.Redis): @@ -568,6 +588,7 @@ async def test_pubsub_numpat(self, r: redis.Redis): for i in range(3): assert (await wait_for_message(p))["type"] == "psubscribe" assert await r.pubsub_numpat() == 3 + await p.close() @pytest.mark.onlynoncluster @@ -580,6 +601,7 @@ async def test_send_pubsub_ping(self, r: redis.Redis): assert await wait_for_message(p) == make_message( type="pong", channel=None, data="", pattern=None ) + await p.close() @skip_if_server_version_lt("3.0.0") async def test_send_pubsub_ping_message(self, r: redis.Redis): @@ -589,13 +611,16 @@ async def test_send_pubsub_ping_message(self, r: redis.Redis): assert await wait_for_message(p) == make_message( type="pong", channel=None, data="hello world", pattern=None ) + await p.close() @pytest.mark.onlynoncluster class TestPubSubConnectionKilled: @skip_if_server_version_lt("3.0.0") - async def test_connection_error_raised_when_connection_dies(self, r: redis.Redis): - p = r.pubsub() + async def test_connection_error_raised_when_connection_dies( + self, r: redis.Redis, pubsub + ): + p = pubsub await p.subscribe("foo") assert await wait_for_message(p) == make_message("subscribe", "foo", 1) for client in await r.client_list(): @@ -607,8 +632,8 @@ async def test_connection_error_raised_when_connection_dies(self, r: redis.Redis @pytest.mark.onlynoncluster class TestPubSubTimeouts: - async def test_get_message_with_timeout_returns_none(self, r: redis.Redis): - p = r.pubsub() + async def test_get_message_with_timeout_returns_none(self, pubsub): + p = pubsub await p.subscribe("foo") assert await wait_for_message(p) == make_message("subscribe", "foo", 1) assert await p.get_message(timeout=0.01) is None @@ -616,15 +641,13 @@ async def test_get_message_with_timeout_returns_none(self, r: redis.Redis): @pytest.mark.onlynoncluster class TestPubSubReconnect: - # @pytest.mark.xfail @with_timeout(2) - async def test_reconnect_listen(self, r: redis.Redis): + async def test_reconnect_listen(self, r: redis.Redis, pubsub): """ Test that a loop processing PubSub messages can survive a disconnect, by issuing a connect() call. """ messages = asyncio.Queue() - pubsub = r.pubsub() interrupt = False async def loop(): @@ -698,12 +721,12 @@ async def _subscribe(self, p, *args, **kwargs): ): return - async def test_callbacks(self, r: redis.Redis): + async def test_callbacks(self, r: redis.Redis, pubsub): def callback(message): messages.put_nowait(message) messages = asyncio.Queue() - p = r.pubsub() + p = pubsub await self._subscribe(p, foo=callback) task = asyncio.get_event_loop().create_task(p.run()) await r.publish("foo", "bar") @@ -720,13 +743,13 @@ def callback(message): "type": "message", } - async def test_exception_handler(self, r: redis.Redis): + async def test_exception_handler(self, r: redis.Redis, pubsub): def exception_handler_callback(e, pubsub) -> None: assert pubsub == p exceptions.put_nowait(e) exceptions = asyncio.Queue() - p = r.pubsub() + p = pubsub await self._subscribe(p, foo=lambda x: None) with mock.patch.object(p, "get_message", side_effect=Exception("error")): task = asyncio.get_event_loop().create_task( @@ -740,26 +763,25 @@ def exception_handler_callback(e, pubsub) -> None: pass assert str(e) == "error" - async def test_late_subscribe(self, r: redis.Redis): + async def test_late_subscribe(self, r: redis.Redis, pubsub): def callback(message): messages.put_nowait(message) messages = asyncio.Queue() - p = r.pubsub() + p = pubsub task = asyncio.get_event_loop().create_task(p.run()) # wait until loop gets settled. Add a subscription await asyncio.sleep(0.1) await p.subscribe(foo=callback) # wait tof the subscribe to finish. Cannot use _subscribe() because # p.run() is already accepting messages - await asyncio.sleep(0.1) - await r.publish("foo", "bar") - message = None - try: - async with async_timeout.timeout(0.1): - message = await messages.get() - except asyncio.TimeoutError: - pass + while True: + n = await r.publish("foo", "bar") + if n == 1: + break + await asyncio.sleep(0.1) + async with async_timeout.timeout(0.1): + message = await messages.get() task.cancel() # we expect a cancelled error, not the Runtime error # ("did you forget to call subscribe()"") diff --git a/tests/test_asyncio/test_scripting.py b/tests/test_asyncio/test_scripting.py index 764525fb4a..406ab208e2 100644 --- a/tests/test_asyncio/test_scripting.py +++ b/tests/test_asyncio/test_scripting.py @@ -4,6 +4,8 @@ if sys.version_info[0:2] == (3, 6): import pytest as pytest_asyncio + + pytestmark = pytest.mark.asyncio else: import pytest_asyncio diff --git a/tests/test_asyncio/test_search.py b/tests/test_asyncio/test_search.py index 5aaa56f159..bc3a212ac9 100644 --- a/tests/test_asyncio/test_search.py +++ b/tests/test_asyncio/test_search.py @@ -1,6 +1,7 @@ import bz2 import csv import os +import sys import time from io import TextIOWrapper @@ -18,7 +19,8 @@ from redis.commands.search.suggestion import Suggestion from tests.conftest import skip_ifmodversion_lt -pytestmark = pytest.mark.asyncio +if sys.version_info[0:2] == (3, 6): + pytestmark = pytest.mark.asyncio WILL_PLAY_TEXT = os.path.abspath( diff --git a/tests/test_asyncio/test_sentinel.py b/tests/test_asyncio/test_sentinel.py index 4130e67400..e77e07f98e 100644 --- a/tests/test_asyncio/test_sentinel.py +++ b/tests/test_asyncio/test_sentinel.py @@ -5,6 +5,8 @@ if sys.version_info[0:2] == (3, 6): import pytest as pytest_asyncio + + pytestmark = pytest.mark.asyncio else: import pytest_asyncio @@ -17,8 +19,6 @@ SlaveNotFoundError, ) -pytestmark = pytest.mark.asyncio - @pytest_asyncio.fixture(scope="module") def master_ip(master_host): diff --git a/tests/test_asyncio/test_timeseries.py b/tests/test_asyncio/test_timeseries.py index ac2807fe1d..0e57c4f049 100644 --- a/tests/test_asyncio/test_timeseries.py +++ b/tests/test_asyncio/test_timeseries.py @@ -1,3 +1,4 @@ +import sys import time from time import sleep @@ -6,7 +7,8 @@ import redis.asyncio as redis from tests.conftest import skip_ifmodversion_lt -pytestmark = pytest.mark.asyncio +if sys.version_info[0:2] == (3, 6): + pytestmark = pytest.mark.asyncio @pytest.mark.redismod diff --git a/tests/test_json.py b/tests/test_json.py index 1cc448c5f9..0965a93d88 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -1411,7 +1411,8 @@ def test_set_path(client): with open(jsonfile, "w+") as fp: fp.write(json.dumps({"hello": "world"})) - open(nojsonfile, "a+").write("hello") + with open(nojsonfile, "a+") as fp: + fp.write("hello") result = {jsonfile: True, nojsonfile: False} assert client.json().set_path(Path.root_path(), root) == result diff --git a/tests/test_ssl.py b/tests/test_ssl.py index d029b80dcb..ed38a3166b 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -68,8 +68,8 @@ def test_validating_self_signed_certificate(self, request): assert r.ping() def test_validating_self_signed_string_certificate(self, request): - f = open(self.SERVER_CERT) - cert_data = f.read() + with open(self.SERVER_CERT) as f: + cert_data = f.read() ssl_url = request.config.option.redis_ssl_url p = urlparse(ssl_url)[1].split(":") r = redis.Redis( diff --git a/tox.ini b/tox.ini index 0ceb008cf6..d1aeb02ade 100644 --- a/tox.ini +++ b/tox.ini @@ -9,6 +9,7 @@ markers = asyncio: marker for async tests replica: replica tests experimental: run only experimental tests +asyncio_mode = auto [tox] minversion = 3.2.0