diff --git a/distributed/client.py b/distributed/client.py index 115aa1a6be..6fbd1dc440 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -23,7 +23,7 @@ from functools import partial from numbers import Number from queue import Queue as pyQueue -from typing import Any, ClassVar, Literal +from typing import Any, ClassVar, Coroutine, Literal, Sequence, TypedDict from tlz import first, groupby, keymap, merge, partition_all, valmap @@ -117,9 +117,6 @@ "pubsub": PubSubClientExtension, } -# Placeholder used in the get_dataset function(s) -NO_DEFAULT_PLACEHOLDER = "_no_default_" - def _get_global_client() -> Client | None: L = sorted(list(_global_clients), reverse=True) @@ -661,6 +658,12 @@ def _maybe_call_security_loader(address): return None +class VersionsDict(TypedDict): + scheduler: dict[str, dict[str, Any]] + workers: dict[str, dict[str, dict[str, Any]]] + client: dict[str, dict[str, Any]] + + class Client(SyncMethodMixin): """Connect to and submit computation to a Dask cluster @@ -2554,18 +2557,18 @@ def list_datasets(self, **kwargs): """ return self.sync(self.scheduler.publish_list, **kwargs) - async def _get_dataset(self, name, default=NO_DEFAULT_PLACEHOLDER): + async def _get_dataset(self, name, default=no_default): with self.as_current(): out = await self.scheduler.publish_get(name=name, client=self.id) if out is None: - if default is NO_DEFAULT_PLACEHOLDER: + if default is no_default: raise KeyError(f"Dataset '{name}' not found") else: return default return out["data"] - def get_dataset(self, name, default=NO_DEFAULT_PLACEHOLDER, **kwargs): + def get_dataset(self, name, default=no_default, **kwargs): """ Get named dataset from the scheduler if present. Return the default or raise a KeyError if not present. @@ -4220,15 +4223,17 @@ def set_metadata(self, key, value): key = (key,) return self.sync(self.scheduler.set_metadata, keys=key, value=value) - def get_versions(self, check=False, packages=[]): + def get_versions( + self, check: bool = False, packages: Sequence[str] | None = None + ) -> VersionsDict | Coroutine[Any, Any, VersionsDict]: """Return version info for the scheduler, all workers and myself Parameters ---------- - check : boolean, default False + check raise ValueError if all required & optional packages do not match - packages : List[str] + packages Extra package names to check Examples @@ -4237,16 +4242,19 @@ def get_versions(self, check=False, packages=[]): >>> c.get_versions(packages=['sklearn', 'geopandas']) # doctest: +SKIP """ - return self.sync(self._get_versions, check=check, packages=packages) + return self.sync(self._get_versions, check=check, packages=packages or []) - async def _get_versions(self, check=False, packages=[]): + async def _get_versions( + self, check: bool = False, packages: Sequence[str] | None = None + ) -> VersionsDict: + packages = packages or [] client = version_module.get_versions(packages=packages) scheduler = await self.scheduler.versions(packages=packages) workers = await self.scheduler.broadcast( msg={"op": "versions", "packages": packages}, on_error="ignore", ) - result = {"scheduler": scheduler, "workers": workers, "client": client} + result = VersionsDict(scheduler=scheduler, workers=workers, client=client) if check: msg = version_module.error_message(scheduler, workers, client) diff --git a/distributed/dashboard/core.py b/distributed/dashboard/core.py index 121a30f898..60d20ea983 100644 --- a/distributed/dashboard/core.py +++ b/distributed/dashboard/core.py @@ -20,7 +20,8 @@ raise ImportError("Dask needs bokeh >= 
2.1.1") -def BokehApplication(applications, server, prefix="/", template_variables={}): +def BokehApplication(applications, server, prefix="/", template_variables=None): + template_variables = template_variables or {} prefix = "/" + prefix.strip("/") + "/" if prefix else "/" extra = {"prefix": prefix, **template_variables} diff --git a/distributed/deploy/ssh.py b/distributed/deploy/ssh.py index 7d6f7dc686..8c4b472355 100644 --- a/distributed/deploy/ssh.py +++ b/distributed/deploy/ssh.py @@ -298,9 +298,9 @@ async def start(self): def SSHCluster( hosts: list[str] | None = None, - connect_options: dict | list[dict] = {}, - worker_options: dict = {}, - scheduler_options: dict = {}, + connect_options: dict | list[dict] | None = None, + worker_options: dict | None = None, + scheduler_options: dict | None = None, worker_module: str = "deprecated", worker_class: str = "distributed.Nanny", remote_python: str | list[str] | None = None, @@ -327,22 +327,22 @@ def SSHCluster( Parameters ---------- - hosts : list[str] + hosts List of hostnames or addresses on which to launch our cluster. The first will be used for the scheduler and the rest for workers. - connect_options : dict or list of dict, optional + connect_options Keywords to pass through to :func:`asyncssh.connect`. This could include things such as ``port``, ``username``, ``password`` or ``known_hosts``. See docs for :func:`asyncssh.connect` and :class:`asyncssh.SSHClientConnectionOptions` for full information. If a list it must have the same length as ``hosts``. - worker_options : dict, optional + worker_options Keywords to pass on to workers. - scheduler_options : dict, optional + scheduler_options Keywords to pass on to scheduler. - worker_class: str + worker_class The python class to use to create the worker(s). - remote_python : str or list of str, optional + remote_python Path to Python on remote nodes. Examples @@ -393,6 +393,10 @@ def SSHCluster( dask.distributed.Worker asyncssh.connect """ + connect_options = connect_options or {} + worker_options = worker_options or {} + scheduler_options = scheduler_options or {} + if worker_module != "deprecated": raise ValueError( "worker_module has been deprecated in favor of worker_class. 
" diff --git a/distributed/diagnostics/plugin.py b/distributed/diagnostics/plugin.py index 8617e85a12..be9f469af6 100644 --- a/distributed/diagnostics/plugin.py +++ b/distributed/diagnostics/plugin.py @@ -276,9 +276,7 @@ class PipInstall(WorkerPlugin): def __init__(self, packages, pip_options=None, restart=False): self.packages = packages self.restart = restart - if pip_options is None: - pip_options = [] - self.pip_options = pip_options + self.pip_options = pip_options or [] async def setup(self, worker): from distributed.lock import Lock @@ -345,7 +343,8 @@ async def setup(self, worker): class Environ(NannyPlugin): restart = True - def __init__(self, environ={}): + def __init__(self, environ: dict | None = None): + environ = environ or {} self.environ = {k: str(v) for k, v in environ.items()} async def setup(self, nanny): diff --git a/distributed/http/routing.py b/distributed/http/routing.py index a9addae809..3304396437 100644 --- a/distributed/http/routing.py +++ b/distributed/http/routing.py @@ -7,7 +7,11 @@ from tornado import web -def _descend_routes(router, routers=set(), out=set()): +def _descend_routes(router, routers=None, out=None): + if routers is None: + routers = set() + if out is None: + out = set() if router in routers: return routers.add(router) diff --git a/distributed/multi_lock.py b/distributed/multi_lock.py index 533b35c3ae..e49fdb1f99 100644 --- a/distributed/multi_lock.py +++ b/distributed/multi_lock.py @@ -140,10 +140,10 @@ class MultiLock: Parameters ---------- - names: List[str] + names Names of the locks to acquire. Choosing the same name allows two disconnected processes to coordinate a lock. - client: Client (optional) + client Client to use for communication with the scheduler. If not given, the default global client will be used. @@ -155,14 +155,14 @@ class MultiLock: >>> lock.release() # doctest: +SKIP """ - def __init__(self, names=[], client=None): + def __init__(self, names: list[str] | None = None, client: Client | None = None): try: self.client = client or Client.current() except ValueError: # Initialise new client self.client = get_worker().client - self.names = names + self.names = names or [] self.id = uuid.uuid4().hex self._locked = False diff --git a/distributed/process.py b/distributed/process.py index c73c9bfe17..7f43b5beca 100644 --- a/distributed/process.py +++ b/distributed/process.py @@ -61,7 +61,9 @@ class AsyncProcess: _process: multiprocessing.Process - def __init__(self, loop=None, target=None, name=None, args=(), kwargs={}): + def __init__(self, loop=None, target=None, name=None, args=(), kwargs=None): + kwargs = kwargs or {} + if not callable(target): raise TypeError(f"`target` needs to be callable, not {type(target)!r}") self._state = _ProcessState() diff --git a/distributed/protocol/compression.py b/distributed/protocol/compression.py index 5cbc21445c..a975e881ec 100644 --- a/distributed/protocol/compression.py +++ b/distributed/protocol/compression.py @@ -16,7 +16,7 @@ import dask -from distributed.utils import ensure_memoryview, nbytes +from distributed.utils import ensure_memoryview, nbytes, no_default compressions: dict[ str | None | Literal[False], @@ -152,7 +152,7 @@ def maybe_compress( min_size=10_000, sample_size=10_000, nsamples=5, - compression=dask.config.get("distributed.comm.compression"), + compression=no_default, ): """ Maybe compress payload @@ -164,6 +164,8 @@ def maybe_compress( return the original 4. 
We return the compressed result """ + if compression is no_default: + compression = dask.config.get("distributed.comm.compression") if not compression: return None, payload if not (min_size <= nbytes(payload) <= 2**31): diff --git a/distributed/protocol/tests/test_protocol.py b/distributed/protocol/tests/test_protocol.py index eded23ffc5..da14895eaa 100644 --- a/distributed/protocol/tests/test_protocol.py +++ b/distributed/protocol/tests/test_protocol.py @@ -121,6 +121,27 @@ def test_maybe_compress(lib, compression): assert compressions[rc]["decompress"](rd) == payload +@pytest.mark.parametrize( + "lib,compression", + [(None, None), ("zlib", "zlib"), ("lz4", "lz4"), ("zstandard", "zstd")], +) +def test_maybe_compress_config_default(lib, compression): + if lib: + pytest.importorskip(lib) + + try_converters = [bytes, memoryview] + + with dask.config.set({"distributed.comm.compression": compression}): + for f in try_converters: + payload = b"123" + assert maybe_compress(f(payload)) == (None, payload) + + payload = b"0" * 10000 + rc, rd = maybe_compress(f(payload)) + assert rc == compression + assert compressions[rc]["decompress"](rd) == payload + + def test_maybe_compress_sample(): np = pytest.importorskip("numpy") lz4 = pytest.importorskip("lz4") diff --git a/distributed/utils_test.py b/distributed/utils_test.py index fadc600237..42b52e53e4 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -688,12 +688,16 @@ def _close_queue(q): def cluster( nworkers=2, nanny=False, - worker_kwargs={}, + worker_kwargs=None, active_rpc_timeout=10, disconnect_timeout=20, - scheduler_kwargs={}, - config={}, + scheduler_kwargs=None, + config=None, ): + worker_kwargs = worker_kwargs or {} + scheduler_kwargs = scheduler_kwargs or {} + config = config or {} + ws = weakref.WeakSet() enable_proctitle_on_children() @@ -830,7 +834,7 @@ async def disconnect_all(addresses, timeout=3, rpc_kwargs=None): def gen_test( timeout: float = _TEST_TIMEOUT, - clean_kwargs: dict[str, Any] = {}, + clean_kwargs: dict[str, Any] | None = None, ) -> Callable[[Callable], Callable]: """Coroutine test @@ -844,6 +848,7 @@ async def test_foo(param) async def test_foo(): await ... 
# use tornado coroutines """ + clean_kwargs = clean_kwargs or {} assert timeout, ( "timeout should always be set and it should be smaller than the global one from" "pytest-timeout" @@ -878,9 +883,12 @@ async def start_cluster( scheduler_addr: str, security: Security | dict[str, Any] | None = None, Worker: type[ServerNode] = Worker, - scheduler_kwargs: dict[str, Any] = {}, - worker_kwargs: dict[str, Any] = {}, + scheduler_kwargs: dict[str, Any] | None = None, + worker_kwargs: dict[str, Any] | None = None, ) -> tuple[Scheduler, list[ServerNode]]: + scheduler_kwargs = scheduler_kwargs or {} + worker_kwargs = worker_kwargs or {} + s = await Scheduler( validate=True, security=security, @@ -983,21 +991,18 @@ async def end_worker(w): def gen_cluster( - nthreads: list[tuple[str, int] | tuple[str, int, dict]] = [ - ("127.0.0.1", 1), - ("127.0.0.1", 2), - ], + nthreads: list[tuple[str, int] | tuple[str, int, dict]] | None = None, scheduler: str = "127.0.0.1", timeout: float = _TEST_TIMEOUT, security: Security | dict[str, Any] | None = None, Worker: type[ServerNode] = Worker, client: bool = False, - scheduler_kwargs: dict[str, Any] = {}, - worker_kwargs: dict[str, Any] = {}, - client_kwargs: dict[str, Any] = {}, + scheduler_kwargs: dict[str, Any] | None = None, + worker_kwargs: dict[str, Any] | None = None, + client_kwargs: dict[str, Any] | None = None, active_rpc_timeout: float = 1, - config: dict[str, Any] = {}, - clean_kwargs: dict[str, Any] = {}, + config: dict[str, Any] | None = None, + clean_kwargs: dict[str, Any] | None = None, allow_unclosed: bool = False, cluster_dump_directory: str | Literal[False] = "test_cluster_dump", ) -> Callable[[Callable], Callable]: @@ -1022,6 +1027,17 @@ async def test_foo(scheduler, worker1, worker2, pytest_fixture_a, pytest_fixture start end """ + if nthreads is None: + nthreads = [ + ("127.0.0.1", 1), + ("127.0.0.1", 2), + ] + scheduler_kwargs = scheduler_kwargs or {} + worker_kwargs = worker_kwargs or {} + client_kwargs = client_kwargs or {} + config = config or {} + clean_kwargs = clean_kwargs or {} + assert timeout, ( "timeout should always be set and it should be smaller than the global one from" "pytest-timeout"
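
An aside for reviewers, not part of the diff: the changes above repeatedly apply two idioms. First, mutable `{}`/`[]` defaults are replaced with `None` and coalesced inside the function body, because a literal container default is created once at definition time and then shared by every call that relies on it. Second, defaults that evaluated `dask.config.get(...)` at import time (as `maybe_compress` previously did) are replaced with a sentinel so the active configuration is read at call time. Below is a minimal, self-contained sketch of both patterns; the function names and the local `_no_default` object are invented for illustration only (the diff itself imports the real `no_default` sentinel from `distributed.utils`).

# Illustrative sketch only -- not part of the diff above.
import dask


# Pattern 1: avoid mutable default arguments.
# A literal [] or {} default is evaluated once, when the function is defined,
# and the same object is reused by every call that omits the argument.
def bad_append(item, bucket=[]):
    bucket.append(item)
    return bucket


def good_append(item, bucket=None):
    # Coalesce inside the function so each call gets a fresh list.
    bucket = bucket or []
    bucket.append(item)
    return bucket


assert bad_append(1) == [1]
assert bad_append(2) == [1, 2]   # state leaks between calls
assert good_append(1) == [1]
assert good_append(2) == [2]     # fresh list on every call


# Pattern 2: defer config lookups from definition time to call time.
# Using dask.config.get(...) as a default freezes whatever value was
# configured when the module was imported; a sentinel default lets the
# function read the configuration that is current at call time.
_no_default = object()  # local stand-in for distributed.utils.no_default


def compression_to_use(compression=_no_default):
    if compression is _no_default:
        compression = dask.config.get("distributed.comm.compression")
    return compression


with dask.config.set({"distributed.comm.compression": "zlib"}):
    assert compression_to_use() == "zlib"  # picks up the active config

One design note on the first pattern: the diff mixes `value or {}` with explicit `if value is None` checks (e.g. in `_descend_routes`). The `is None` form is the stricter of the two, since `or` would also replace an explicitly passed empty container, which matters when that container is meant to be accumulated into or shared by the caller.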