Avoid mutable argument defaults outside of tests (#6665)
hendrikmakait authored Jul 28, 2022
1 parent 13f5a0c commit 9267a21
Showing 10 changed files with 107 additions and 50 deletions.
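Every change in this commit targets the same Python pitfall: default argument values are evaluated once, when the def statement runs, so a mutable default such as [] or {} is a single object shared by every call. The idiomatic fix, applied throughout the diff below, is to default to None (or a sentinel) and create a fresh container inside the function. A minimal sketch of the bug and the fix, with illustrative names that are not part of this repository:

    def append_bad(item, items=[]):
        # the same list object persists across every call
        items.append(item)
        return items

    def append_good(item, items=None):
        # a fresh list is created per call
        if items is None:
            items = []
        items.append(item)
        return items

    append_bad(1)   # [1]
    append_bad(2)   # [1, 2]  state leaked from the first call
    append_good(1)  # [1]
    append_good(2)  # [2]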
34 changes: 21 additions & 13 deletions distributed/client.py
@@ -23,7 +23,7 @@
from functools import partial
from numbers import Number
from queue import Queue as pyQueue
-from typing import Any, ClassVar, Literal
+from typing import Any, ClassVar, Coroutine, Literal, Sequence, TypedDict

from tlz import first, groupby, keymap, merge, partition_all, valmap

@@ -117,9 +117,6 @@
"pubsub": PubSubClientExtension,
}

-# Placeholder used in the get_dataset function(s)
-NO_DEFAULT_PLACEHOLDER = "_no_default_"


def _get_global_client() -> Client | None:
L = sorted(list(_global_clients), reverse=True)
@@ -661,6 +658,12 @@ def _maybe_call_security_loader(address):
return None


+class VersionsDict(TypedDict):
+    scheduler: dict[str, dict[str, Any]]
+    workers: dict[str, dict[str, dict[str, Any]]]
+    client: dict[str, dict[str, Any]]


class Client(SyncMethodMixin):
"""Connect to and submit computation to a Dask cluster
@@ -2554,18 +2557,18 @@ def list_datasets(self, **kwargs):
"""
return self.sync(self.scheduler.publish_list, **kwargs)

-async def _get_dataset(self, name, default=NO_DEFAULT_PLACEHOLDER):
+async def _get_dataset(self, name, default=no_default):
with self.as_current():
out = await self.scheduler.publish_get(name=name, client=self.id)

if out is None:
-if default is NO_DEFAULT_PLACEHOLDER:
+if default is no_default:
raise KeyError(f"Dataset '{name}' not found")
else:
return default
return out["data"]

-def get_dataset(self, name, default=NO_DEFAULT_PLACEHOLDER, **kwargs):
+def get_dataset(self, name, default=no_default, **kwargs):
"""
Get named dataset from the scheduler if present.
Return the default or raise a KeyError if not present.
@@ -4220,15 +4223,17 @@ def set_metadata(self, key, value):
key = (key,)
return self.sync(self.scheduler.set_metadata, keys=key, value=value)

-def get_versions(self, check=False, packages=[]):
+def get_versions(
+    self, check: bool = False, packages: Sequence[str] | None = None
+) -> VersionsDict | Coroutine[Any, Any, VersionsDict]:
"""Return version info for the scheduler, all workers and myself
Parameters
----------
-check : boolean, default False
+check
raise ValueError if all required & optional packages
do not match
-packages : List[str]
+packages
Extra package names to check
Examples
@@ -4237,16 +4242,19 @@ def get_versions(self, check=False, packages=[]):
>>> c.get_versions(packages=['sklearn', 'geopandas']) # doctest: +SKIP
"""
-return self.sync(self._get_versions, check=check, packages=packages)
+return self.sync(self._get_versions, check=check, packages=packages or [])

-async def _get_versions(self, check=False, packages=[]):
+async def _get_versions(
+    self, check: bool = False, packages: Sequence[str] | None = None
+) -> VersionsDict:
+    packages = packages or []
client = version_module.get_versions(packages=packages)
scheduler = await self.scheduler.versions(packages=packages)
workers = await self.scheduler.broadcast(
msg={"op": "versions", "packages": packages},
on_error="ignore",
)
result = {"scheduler": scheduler, "workers": workers, "client": client}
result = VersionsDict(scheduler=scheduler, workers=workers, client=client)

if check:
msg = version_module.error_message(scheduler, workers, client)
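Two supporting patterns show up in the client.py changes above. First, the no_default sentinel from distributed.utils replaces the ad-hoc NO_DEFAULT_PLACEHOLDER string; a module-level sentinel lets get_dataset distinguish "no fallback supplied" from "None explicitly supplied as the fallback". Second, VersionsDict is a typing.TypedDict, so constructing one yields a plain dict at runtime while giving type checkers the key schema. A simplified sketch of both ideas; this is not the actual distributed implementation:

    from typing import Any, TypedDict

    no_default = "__no_default__"  # stand-in sentinel; distributed defines its own

    def get_dataset(store: dict[str, Any], name: str, default: Any = no_default) -> Any:
        if name not in store:
            if default is no_default:
                # no fallback supplied, surface the missing key
                raise KeyError(f"Dataset {name!r} not found")
            return default  # an explicit None fallback is honored
        return store[name]

    class VersionsDict(TypedDict):
        scheduler: dict[str, Any]
        workers: dict[str, Any]
        client: dict[str, Any]

    versions = VersionsDict(scheduler={}, workers={}, client={})
    assert isinstance(versions, dict)  # a plain dict at runtime; typing is static only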
3 changes: 2 additions & 1 deletion distributed/dashboard/core.py
@@ -20,7 +20,8 @@
raise ImportError("Dask needs bokeh >= 2.1.1")


-def BokehApplication(applications, server, prefix="/", template_variables={}):
+def BokehApplication(applications, server, prefix="/", template_variables=None):
+    template_variables = template_variables or {}
prefix = "/" + prefix.strip("/") + "/" if prefix else "/"

extra = {"prefix": prefix, **template_variables}
22 changes: 13 additions & 9 deletions distributed/deploy/ssh.py
@@ -298,9 +298,9 @@ async def start(self):

def SSHCluster(
hosts: list[str] | None = None,
-connect_options: dict | list[dict] = {},
-worker_options: dict = {},
-scheduler_options: dict = {},
+connect_options: dict | list[dict] | None = None,
+worker_options: dict | None = None,
+scheduler_options: dict | None = None,
worker_module: str = "deprecated",
worker_class: str = "distributed.Nanny",
remote_python: str | list[str] | None = None,
@@ -327,22 +327,22 @@ def SSHCluster(
Parameters
----------
-hosts : list[str]
+hosts
List of hostnames or addresses on which to launch our cluster.
The first will be used for the scheduler and the rest for workers.
-connect_options : dict or list of dict, optional
+connect_options
Keywords to pass through to :func:`asyncssh.connect`.
This could include things such as ``port``, ``username``, ``password``
or ``known_hosts``. See docs for :func:`asyncssh.connect` and
:class:`asyncssh.SSHClientConnectionOptions` for full information.
If a list it must have the same length as ``hosts``.
-worker_options : dict, optional
+worker_options
Keywords to pass on to workers.
-scheduler_options : dict, optional
+scheduler_options
Keywords to pass on to scheduler.
-worker_class: str
+worker_class
The python class to use to create the worker(s).
-remote_python : str or list of str, optional
+remote_python
Path to Python on remote nodes.
Examples
@@ -393,6 +393,10 @@
dask.distributed.Worker
asyncssh.connect
"""
+connect_options = connect_options or {}
+worker_options = worker_options or {}
+scheduler_options = scheduler_options or {}

if worker_module != "deprecated":
raise ValueError(
"worker_module has been deprecated in favor of worker_class. "
7 changes: 3 additions & 4 deletions distributed/diagnostics/plugin.py
@@ -276,9 +276,7 @@ class PipInstall(WorkerPlugin):
def __init__(self, packages, pip_options=None, restart=False):
self.packages = packages
self.restart = restart
-if pip_options is None:
-    pip_options = []
-self.pip_options = pip_options
+self.pip_options = pip_options or []

async def setup(self, worker):
from distributed.lock import Lock
@@ -345,7 +343,8 @@ async def setup(self, worker):
class Environ(NannyPlugin):
restart = True

-def __init__(self, environ={}):
+def __init__(self, environ: dict | None = None):
+    environ = environ or {}
self.environ = {k: str(v) for k, v in environ.items()}

async def setup(self, nanny):
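Most of these fixes use the compact arg = arg or {} fallback rather than an explicit None check. The two are not identical: with or, any falsy argument, including an explicitly passed empty container, also triggers the fallback. That is harmless when the fallback is itself an empty container, as in PipInstall and Environ above, but gen_cluster further down uses if nthreads is None, presumably because its fallback is a non-empty list that should not override an explicitly passed empty one. A small sketch of the difference:

    DEFAULT_NTHREADS = [("127.0.0.1", 1), ("127.0.0.1", 2)]

    def with_or(nthreads=None):
        return nthreads or DEFAULT_NTHREADS  # an explicit [] is silently replaced

    def with_is_none(nthreads=None):
        if nthreads is None:
            nthreads = DEFAULT_NTHREADS
        return nthreads  # an explicit [] is preserved

    assert with_or([]) == DEFAULT_NTHREADS
    assert with_is_none([]) == []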
6 changes: 5 additions & 1 deletion distributed/http/routing.py
@@ -7,7 +7,11 @@
from tornado import web


-def _descend_routes(router, routers=set(), out=set()):
+def _descend_routes(router, routers=None, out=None):
+    if routers is None:
+        routers = set()
+    if out is None:
+        out = set()
if router in routers:
return
routers.add(router)
8 changes: 4 additions & 4 deletions distributed/multi_lock.py
@@ -140,10 +140,10 @@ class MultiLock:
Parameters
----------
-names: List[str]
+names
Names of the locks to acquire. Choosing the same name allows two
disconnected processes to coordinate a lock.
-client: Client (optional)
+client
Client to use for communication with the scheduler. If not given, the
default global client will be used.
@@ -155,14 +155,14 @@
>>> lock.release() # doctest: +SKIP
"""

-def __init__(self, names=[], client=None):
+def __init__(self, names: list[str] | None = None, client: Client | None = None):
try:
self.client = client or Client.current()
except ValueError:
# Initialise new client
self.client = get_worker().client

-self.names = names
+self.names = names or []
self.id = uuid.uuid4().hex
self._locked = False

4 changes: 3 additions & 1 deletion distributed/process.py
@@ -61,7 +61,9 @@ class AsyncProcess:

_process: multiprocessing.Process

-def __init__(self, loop=None, target=None, name=None, args=(), kwargs={}):
+def __init__(self, loop=None, target=None, name=None, args=(), kwargs=None):
+    kwargs = kwargs or {}

if not callable(target):
raise TypeError(f"`target` needs to be callable, not {type(target)!r}")
self._state = _ProcessState()
6 changes: 4 additions & 2 deletions distributed/protocol/compression.py
@@ -16,7 +16,7 @@

import dask

-from distributed.utils import ensure_memoryview, nbytes
+from distributed.utils import ensure_memoryview, nbytes, no_default

compressions: dict[
str | None | Literal[False],
@@ -152,7 +152,7 @@ def maybe_compress(
min_size=10_000,
sample_size=10_000,
nsamples=5,
-compression=dask.config.get("distributed.comm.compression"),
+compression=no_default,
):
"""
Maybe compress payload
@@ -164,6 +164,8 @@
return the original
4. We return the compressed result
"""
+if compression is no_default:
+    compression = dask.config.get("distributed.comm.compression")
if not compression:
return None, payload
if not (min_size <= nbytes(payload) <= 2**31):
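The maybe_compress change fixes a subtler variant of the same bug: a default of dask.config.get(...) is evaluated once, at import time, so configuration changed afterwards, including via the dask.config.set context manager exercised by the new test below, is never seen. Moving the lookup behind the no_default sentinel defers it to call time. A minimal sketch of the before-and-after behavior, using a plain dict as a stand-in for dask's config:

    config = {"compression": None}  # stand-in for dask's global config

    def frozen(compression=config.get("compression")):
        # the default was captured when the def statement ran
        return compression

    _unset = object()  # sentinel, mirroring distributed.utils.no_default

    def deferred(compression=_unset):
        if compression is _unset:
            compression = config["compression"]  # looked up on every call
        return compression

    config["compression"] = "lz4"  # comparable to dask.config.set(...)
    assert frozen() is None        # stale: still the definition-time value
    assert deferred() == "lz4"     # sees the updated configuration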
21 changes: 21 additions & 0 deletions distributed/protocol/tests/test_protocol.py
@@ -121,6 +121,27 @@ def test_maybe_compress(lib, compression):
assert compressions[rc]["decompress"](rd) == payload


+@pytest.mark.parametrize(
+    "lib,compression",
+    [(None, None), ("zlib", "zlib"), ("lz4", "lz4"), ("zstandard", "zstd")],
+)
+def test_maybe_compress_config_default(lib, compression):
+    if lib:
+        pytest.importorskip(lib)
+
+    try_converters = [bytes, memoryview]
+
+    with dask.config.set({"distributed.comm.compression": compression}):
+        for f in try_converters:
+            payload = b"123"
+            assert maybe_compress(f(payload)) == (None, payload)
+
+            payload = b"0" * 10000
+            rc, rd = maybe_compress(f(payload))
+            assert rc == compression
+            assert compressions[rc]["decompress"](rd) == payload


def test_maybe_compress_sample():
np = pytest.importorskip("numpy")
lz4 = pytest.importorskip("lz4")
46 changes: 31 additions & 15 deletions distributed/utils_test.py
@@ -688,12 +688,16 @@ def _close_queue(q):
def cluster(
nworkers=2,
nanny=False,
-worker_kwargs={},
+worker_kwargs=None,
active_rpc_timeout=10,
disconnect_timeout=20,
-scheduler_kwargs={},
-config={},
+scheduler_kwargs=None,
+config=None,
):
+worker_kwargs = worker_kwargs or {}
+scheduler_kwargs = scheduler_kwargs or {}
+config = config or {}

ws = weakref.WeakSet()
enable_proctitle_on_children()

@@ -830,7 +834,7 @@ async def disconnect_all(addresses, timeout=3, rpc_kwargs=None):

def gen_test(
timeout: float = _TEST_TIMEOUT,
-clean_kwargs: dict[str, Any] = {},
+clean_kwargs: dict[str, Any] | None = None,
) -> Callable[[Callable], Callable]:
"""Coroutine test
@@ -844,6 +848,7 @@ async def test_foo(param)
async def test_foo():
await ... # use tornado coroutines
"""
+clean_kwargs = clean_kwargs or {}
assert timeout, (
"timeout should always be set and it should be smaller than the global one from"
"pytest-timeout"
@@ -878,9 +883,12 @@ async def start_cluster(
scheduler_addr: str,
security: Security | dict[str, Any] | None = None,
Worker: type[ServerNode] = Worker,
-scheduler_kwargs: dict[str, Any] = {},
-worker_kwargs: dict[str, Any] = {},
+scheduler_kwargs: dict[str, Any] | None = None,
+worker_kwargs: dict[str, Any] | None = None,
) -> tuple[Scheduler, list[ServerNode]]:
+scheduler_kwargs = scheduler_kwargs or {}
+worker_kwargs = worker_kwargs or {}

s = await Scheduler(
validate=True,
security=security,
@@ -983,21 +991,18 @@ async def end_worker(w):


def gen_cluster(
-nthreads: list[tuple[str, int] | tuple[str, int, dict]] = [
-    ("127.0.0.1", 1),
-    ("127.0.0.1", 2),
-],
+nthreads: list[tuple[str, int] | tuple[str, int, dict]] | None = None,
scheduler: str = "127.0.0.1",
timeout: float = _TEST_TIMEOUT,
security: Security | dict[str, Any] | None = None,
Worker: type[ServerNode] = Worker,
client: bool = False,
-scheduler_kwargs: dict[str, Any] = {},
-worker_kwargs: dict[str, Any] = {},
-client_kwargs: dict[str, Any] = {},
+scheduler_kwargs: dict[str, Any] | None = None,
+worker_kwargs: dict[str, Any] | None = None,
+client_kwargs: dict[str, Any] | None = None,
active_rpc_timeout: float = 1,
-config: dict[str, Any] = {},
-clean_kwargs: dict[str, Any] = {},
+config: dict[str, Any] | None = None,
+clean_kwargs: dict[str, Any] | None = None,
allow_unclosed: bool = False,
cluster_dump_directory: str | Literal[False] = "test_cluster_dump",
) -> Callable[[Callable], Callable]:
@@ -1022,6 +1027,17 @@ async def test_foo(scheduler, worker1, worker2, pytest_fixture_a, pytest_fixture
start
end
"""
+if nthreads is None:
+    nthreads = [
+        ("127.0.0.1", 1),
+        ("127.0.0.1", 2),
+    ]
+scheduler_kwargs = scheduler_kwargs or {}
+worker_kwargs = worker_kwargs or {}
+client_kwargs = client_kwargs or {}
+config = config or {}
+clean_kwargs = clean_kwargs or {}

assert timeout, (
"timeout should always be set and it should be smaller than the global one from"
"pytest-timeout"
