Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support fixtures and pytest.mark.parametrize with gen_cluster #4958

Merged
merged 4 commits into from
Jun 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions distributed/tests/test_utils_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import pathlib
import socket
import threading
from contextlib import contextmanager
Expand Down Expand Up @@ -45,6 +46,47 @@ async def test_gen_cluster(c, s, a, b):
assert await c.submit(lambda: 123) == 123


@gen_cluster(client=True)
async def test_gen_cluster_pytest_fixture(c, s, a, b, tmp_path):
assert isinstance(tmp_path, pathlib.Path)
assert isinstance(c, Client)
assert isinstance(s, Scheduler)
for w in [a, b]:
assert isinstance(w, Worker)


@pytest.mark.parametrize("foo", [True])
@gen_cluster(client=True)
async def test_gen_cluster_parametrized(c, s, a, b, foo):
assert foo is True
assert isinstance(c, Client)
assert isinstance(s, Scheduler)
for w in [a, b]:
assert isinstance(w, Worker)


@pytest.mark.parametrize("foo", [True])
@pytest.mark.parametrize("bar", ["a", "b"])
@gen_cluster(client=True)
async def test_gen_cluster_multi_parametrized(c, s, a, b, foo, bar):
assert foo is True
assert bar in ("a", "b")
assert isinstance(c, Client)
assert isinstance(s, Scheduler)
for w in [a, b]:
assert isinstance(w, Worker)


@pytest.mark.parametrize("foo", [True])
@gen_cluster(client=True)
async def test_gen_cluster_parametrized_variadic_workers(c, s, *workers, foo):
assert foo is True
assert isinstance(c, Client)
assert isinstance(s, Scheduler)
for w in workers:
assert isinstance(w, Worker)


@gen_cluster(
client=True,
Worker=Nanny,
Expand Down
29 changes: 27 additions & 2 deletions distributed/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import copy
import functools
import gc
import inspect
import io
import itertools
import logging
Expand Down Expand Up @@ -861,6 +862,15 @@ def gen_cluster(
async def test_foo(scheduler, worker1, worker2):
await ... # use tornado coroutines

@pytest.mark.parametrize("param", [1, 2, 3])
@gen_cluster()
async def test_foo(scheduler, worker1, worker2, param):
await ... # use tornado coroutines

@gen_cluster()
async def test_foo(scheduler, worker1, worker2, pytest_fixture_a, pytest_fixture_b):
await ... # use tornado coroutines

See also:
start
end
Expand All @@ -877,7 +887,7 @@ def _(func):
if not iscoroutinefunction(func):
func = gen.coroutine(func)

def test_func():
def test_func(*outer_args, **kwargs):
result = None
workers = []
with clean(timeout=active_rpc_timeout, **clean_kwargs) as loop:
Expand Down Expand Up @@ -919,7 +929,7 @@ async def coro():
)
args = [c] + args
try:
future = func(*args)
future = func(*args, *outer_args, **kwargs)
if timeout:
future = asyncio.wait_for(future, timeout)
result = await future
Expand Down Expand Up @@ -979,6 +989,21 @@ def get_unclosed():

return result

# Patch the signature so pytest can inject fixtures
orig_sig = inspect.signature(func)
args = [None] * (1 + len(nthreads)) # scheduler, *workers
if client:
args.insert(0, None)

bound = orig_sig.bind_partial(*args)
test_func.__signature__ = orig_sig.replace(
parameters=[
p
for name, p in orig_sig.parameters.items()
if name not in bound.arguments
]
)

return test_func

return _
Expand Down