Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix warnings and resource usage problems in asyncio unittests #2258

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 80 additions & 44 deletions tests/test_asyncio/conftest.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
import asyncio
import functools
import random
import sys
from typing import Union
from urllib.parse import urlparse

if sys.version_info[0:2] == (3, 6):
import pytest as pytest_asyncio
else:
import pytest_asyncio

import pytest
from packaging.version import Version

Expand All @@ -26,6 +21,13 @@

from .compat import mock

if sys.version_info[0:2] == (3, 6):
import pytest as pytest_asyncio

pytestmark = pytest.mark.asyncio
else:
import pytest_asyncio


async def _get_info(redis_url):
client = redis.Redis.from_url(redis_url)
Expand Down Expand Up @@ -69,11 +71,13 @@ async def _get_info(redis_url):
"pool-hiredis",
],
)
def create_redis(request, event_loop: asyncio.BaseEventLoop):
async def create_redis(request):
"""Wrapper around redis.create_redis."""
single_connection, parser_cls = request.param

async def f(
teardown_clients = []

async def client_factory(
url: str = request.config.getoption("--redis-url"),
cls=redis.Redis,
flushdb=True,
Expand All @@ -95,56 +99,50 @@ async def f(
client = client.client()
await client.initialize()

def teardown():
async def ateardown():
if not cluster_mode:
if "username" in kwargs:
return
if flushdb:
try:
await client.flushdb()
except redis.ConnectionError:
# handle cases where a test disconnected a client
# just manually retry the flushdb
await client.flushdb()
await client.close()
await client.connection_pool.disconnect()
else:
if flushdb:
try:
await client.flushdb(target_nodes="primaries")
except redis.ConnectionError:
# handle cases where a test disconnected a client
# just manually retry the flushdb
await client.flushdb(target_nodes="primaries")
await client.close()

if event_loop.is_running():
event_loop.create_task(ateardown())
async def teardown():
if not cluster_mode:
if flushdb and "username" not in kwargs:
try:
await client.flushdb()
except redis.ConnectionError:
# handle cases where a test disconnected a client
# just manually retry the flushdb
await client.flushdb()
await client.close()
await client.connection_pool.disconnect()
else:
event_loop.run_until_complete(ateardown())

request.addfinalizer(teardown)

if flushdb:
try:
await client.flushdb(target_nodes="primaries")
except redis.ConnectionError:
# handle cases where a test disconnected a client
# just manually retry the flushdb
await client.flushdb(target_nodes="primaries")
await client.close()

teardown_clients.append(teardown)
return client

return f
yield client_factory

for teardown in teardown_clients:
await teardown()


@pytest_asyncio.fixture()
async def r(request, create_redis):
yield await create_redis()
async def r(create_redis):
return await create_redis()


@pytest_asyncio.fixture()
async def r2(create_redis):
"""A second client for tests that need multiple"""
yield await create_redis()
return await create_redis()


@pytest_asyncio.fixture()
async def modclient(request, create_redis):
yield await create_redis(
return await create_redis(
url=request.config.getoption("--redismod-url"), decode_responses=True
)

Expand Down Expand Up @@ -222,7 +220,7 @@ async def mock_cluster_resp_slaves(create_redis, **kwargs):
def master_host(request):
url = request.config.getoption("--redis-url")
parts = urlparse(url)
yield parts.hostname
return parts.hostname


async def wait_for_command(
Expand All @@ -246,3 +244,41 @@ async def wait_for_command(
return monitor_response
if key in monitor_response["command"]:
return None


# python 3.6 doesn't have the asynccontextmanager decorator. Provide it here.
class AsyncContextManager:
def __init__(self, async_generator):
self.gen = async_generator

async def __aenter__(self):
try:
return await self.gen.__anext__()
except StopAsyncIteration as err:
raise RuntimeError("Pickles") from err

async def __aexit__(self, exc_type, exc_inst, tb):
if exc_type:
await self.gen.athrow(exc_type, exc_inst, tb)
return True
try:
await self.gen.__anext__()
except StopAsyncIteration:
return
raise RuntimeError("More pickles")


if sys.version_info[0:2] == (3, 6):

def asynccontextmanager(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
return AsyncContextManager(func(*args, **kwargs))

return wrapper

else:
from contextlib import asynccontextmanager as _asynccontextmanager

def asynccontextmanager(func):
return _asynccontextmanager(func)
9 changes: 6 additions & 3 deletions tests/test_asyncio/test_bloom.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import sys

import pytest

import redis.asyncio as redis
from redis.exceptions import ModuleError, RedisError
from redis.utils import HIREDIS_AVAILABLE

pytestmark = pytest.mark.asyncio
if sys.version_info[0:2] == (3, 6):
pytestmark = pytest.mark.asyncio


def intlist(obj):
Expand Down Expand Up @@ -91,7 +94,7 @@ async def do_verify():
res += rv == x
assert res < 5

do_verify()
await do_verify()
cmds = []
if HIREDIS_AVAILABLE:
with pytest.raises(ModuleError):
Expand Down Expand Up @@ -120,7 +123,7 @@ async def do_verify():

cur_info = await modclient.bf().execute_command("bf.debug", "myBloom")
assert prev_info == cur_info
do_verify()
await do_verify()

await modclient.bf().client.delete("myBloom")
await modclient.bf().create("myBloom", "0.0001", "10000000")
Expand Down
4 changes: 2 additions & 2 deletions tests/test_asyncio/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

if sys.version_info[0:2] == (3, 6):
import pytest as pytest_asyncio

pytestmark = pytest.mark.asyncio
else:
import pytest_asyncio

Expand Down Expand Up @@ -39,8 +41,6 @@
skip_unless_arch_bits,
)

pytestmark = pytest.mark.asyncio

default_host = "127.0.0.1"
default_port = 7000
default_cluster_slots = [
Expand Down
Loading