Client method to dump cluster state #5470

Merged Nov 10, 2021 (13 commits)
9 changes: 9 additions & 0 deletions .github/workflows/tests.yaml
@@ -134,3 +134,12 @@ jobs:
        with:
          name: ${{ env.TEST_ID }}
          path: reports
      - name: Upload timeout reports
        # ensure this runs even if pytest fails
        if: >
          always() &&
          (steps.run_tests.outcome == 'success' || steps.run_tests.outcome == 'failure')
        uses: actions/upload-artifact@v2
        with:
          name: ${{ env.TEST_ID }}-timeouts
          path: test_timeout_dump
5 changes: 4 additions & 1 deletion .gitignore
@@ -29,4 +29,7 @@ dask-worker-space/
.ycm_extra_conf.py
tags
.ipynb_checkpoints
.venv/
.venv/

# Test timeouts will dump the cluster state in here
test_timeout_dump/
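
The workflow change above uploads whatever lands in `test_timeout_dump/`. As a hedged sketch of how a test-timeout handler might produce such dumps with the new client method (the hook wiring in the test suite is not shown in this diff; the helper name and paths are illustrative):

    import os


    def dump_cluster_state_on_timeout(client, test_name):
        """Write a cluster dump for a timed-out test into test_timeout_dump/.

        Illustrative only; not part of this diff.
        """
        os.makedirs("test_timeout_dump", exist_ok=True)
        client.dump_cluster_state(
            filename=os.path.join("test_timeout_dump", test_name),
            format="msgpack",
        )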
112 changes: 111 additions & 1 deletion distributed/client.py
@@ -24,10 +24,13 @@
from functools import partial
from numbers import Number
from queue import Queue as pyQueue
from typing import ClassVar
from typing import TYPE_CHECKING, Awaitable, ClassVar, Sequence

from tlz import first, groupby, keymap, merge, partition_all, valmap

if TYPE_CHECKING:
    from typing_extensions import Literal

import dask
from dask.base import collections_to_dsk, normalize_token, tokenize
from dask.core import flatten
@@ -3509,6 +3512,113 @@ def scheduler_info(self, **kwargs):
        self.sync(self._update_scheduler_info)
        return self._scheduler_identity

    async def _dump_cluster_state(
        self,
        filename: str,
        exclude: Sequence[str] = None,
        format: Literal["msgpack"] | Literal["yaml"] = "msgpack",
    ) -> None:

        scheduler_info = self.scheduler.dump_state()

        worker_info = self.scheduler.broadcast(
            msg=dict(
                op="dump_state",
                exclude=exclude,
            ),
        )
        versions = self._get_versions()
        scheduler_info, worker_info, versions_info = await asyncio.gather(
            scheduler_info, worker_info, versions
        )

        state = {
            "scheduler": scheduler_info,
            "workers": worker_info,
            "versions": versions_info,
        }
        filename = str(filename)
        if format == "msgpack":
            suffix = ".msgpack.gz"
            if not filename.endswith(suffix):
                filename += suffix
            import gzip

            import msgpack

            with gzip.open(filename, "wb") as fdg:
                msgpack.pack(state, fdg)
        elif format == "yaml":
            suffix = ".yaml"
            if not filename.endswith(suffix):
                filename += suffix
            import yaml

            with open(filename, "w") as fd:
                yaml.dump(state, fd)
        else:
            raise ValueError(
                f"Unsupported format {format}. Possible values are `msgpack` or `yaml`"
            )
    def dump_cluster_state(
        self,
        filename: str = "dask-cluster-dump",
        exclude: Sequence[str] = None,
        format: Literal["msgpack"] | Literal["yaml"] = "msgpack",
    ) -> Awaitable | None:
        """Extract a dump of the entire cluster state and persist it to disk.
        This is intended for debugging purposes only.

        Warning: Memory usage on the client side can be large.

        Results will be stored in a dict::

            {
                "scheduler": {...},  # scheduler state
                "workers": {
                    worker_addr: {...},  # worker attributes
                    ...
                },
                "versions": {...},  # version info for scheduler and workers
            }

        Parameters
        ----------
        filename:
            The output filename. The appropriate file suffix (`.msgpack.gz` or
            `.yaml`) will be appended automatically.
        exclude:
            A sequence of attribute names to exclude from the dump, e.g. to
            exclude code, tracebacks, logs, etc.
        format:
            Either msgpack or yaml. If msgpack is used (default), the output
            will be stored in a gzipped file as msgpack.

        To read::

            import gzip, msgpack

            with gzip.open("filename") as fd:
                state = msgpack.unpack(fd)

        or::

            import yaml

            try:
                from yaml import CLoader as Loader
            except ImportError:
                from yaml import Loader

            with open("filename") as fd:
                state = yaml.load(fd, Loader=Loader)
        """
        return self.sync(
            self._dump_cluster_state,
            filename=filename,
            format=format,
            exclude=exclude,
        )
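
A minimal usage sketch of the new method (assumes a locally running cluster; `LocalCluster`, the filename, and whether `"logs"` is a meaningful attribute to exclude are all illustrative):

    from distributed import Client, LocalCluster

    cluster = LocalCluster(n_workers=1)
    client = Client(cluster)

    # Blocks until the dump is written; the ".msgpack.gz" suffix is
    # appended automatically because format defaults to "msgpack".
    client.dump_cluster_state("debug-dump", exclude=["logs"])

    client.close()
    cluster.close()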

    def write_scheduler_file(self, scheduler_file):
        """Write the scheduler information to a json file.

36 changes: 34 additions & 2 deletions distributed/core.py
@@ -12,7 +12,7 @@
from contextlib import suppress
from enum import Enum
from functools import partial
from typing import ClassVar
from typing import ClassVar, Container

import tblib
from tlz import merge
@@ -22,8 +22,11 @@
import dask
from dask.utils import parse_timedelta

from distributed.utils import recursive_to_dict

from . import profile, protocol
from .comm import (
    Comm,
    CommClosedError,
    connect,
    get_address_host_port,
@@ -147,6 +150,7 @@ def __init__(
"identity": self.identity,
"echo": self.echo,
"connection_stream": self.handle_stream,
"dump_state": self._to_dict,
}
self.handlers.update(handlers)
if blocked_handlers is None:
@@ -378,9 +382,37 @@ def port(self):
        _, self._port = get_address_host_port(self.address)
        return self._port

    def identity(self, comm=None):
    def identity(self, comm=None) -> dict[str, str]:
        return {"type": type(self).__name__, "id": self.id}

    def _to_dict(
        self, comm: Comm = None, *, exclude: Container[str] = None
    ) -> dict[str, str]:
        """
        A very verbose dictionary representation for debugging purposes.
        Not type stable and not intended for roundtrips.

        Parameters
        ----------
        comm:
        exclude:
            A list of attributes which must not be present in the output.

        See also
        --------
        Server.identity
        Client.dump_cluster_state
        """
        info = self.identity()
        extra = {
            "address": self.address,
            "status": self.status.name,
            "thread_id": self.thread_id,
        }
        info.update(extra)
        return recursive_to_dict(info, exclude=exclude)
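
For intuition, a rough sketch of what a helper like `recursive_to_dict` might do. This is an assumption for illustration only; the real helper lives in `distributed/utils.py` and additionally handles details such as cycle detection:

    from collections import deque


    def recursive_to_dict_sketch(obj, *, exclude=()):
        """Hypothetical: recursively convert obj into plain, serializable data."""
        if hasattr(obj, "_to_dict"):  # objects may customize their own dump
            return obj._to_dict(exclude=exclude)
        if isinstance(obj, dict):
            return {
                k: recursive_to_dict_sketch(v, exclude=exclude)
                for k, v in obj.items()
                if k not in exclude
            }
        if isinstance(obj, (list, tuple, set, deque)):
            return [recursive_to_dict_sketch(el, exclude=exclude) for el in obj]
        if isinstance(obj, (str, int, float, bool, type(None))):
            return obj
        return repr(obj)  # fall back to a string representation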

    def echo(self, comm=None, data=None):
        return data

62 changes: 61 additions & 1 deletion distributed/scheduler.py
@@ -26,7 +26,7 @@
from datetime import timedelta
from functools import partial
from numbers import Number
from typing import ClassVar
from typing import Any, ClassVar, Container
from typing import cast as pep484_cast

import psutil
@@ -49,11 +49,14 @@
from dask.utils import format_bytes, format_time, parse_bytes, parse_timedelta, tmpfile
from dask.widgets import get_template

from distributed.utils import recursive_to_dict

from . import preloading, profile
from . import versions as version_module
from .active_memory_manager import ActiveMemoryManagerExtension
from .batched import BatchedSend
from .comm import (
    Comm,
    get_address_host,
    normalize_address,
    resolve_address,
@@ -1726,6 +1729,30 @@ def get_nbytes_deps(self):
            nbytes += ts.get_nbytes()
        return nbytes

    @ccall
    def _to_dict(self, *, exclude: Container[str] = None):
        """
        A very verbose dictionary representation for debugging purposes.
        Not type stable and not intended for roundtrips.

        Parameters
        ----------
        exclude:
            A list of attributes which must not be present in the output.

        See also
        --------
        Client.dump_cluster_state
        """
        if not exclude:
            exclude = set()
        members = inspect.getmembers(self)
        return recursive_to_dict(
            {k: v for k, v in members if k not in exclude and not callable(v)},
            exclude=exclude,
        )


class _StateLegacyMapping(Mapping):
"""
@@ -3947,6 +3974,39 @@ def identity(self, comm=None):
        }
        return d

    def _to_dict(
        self, comm: Comm = None, *, exclude: Container[str] = None
    ) -> "dict[str, Any]":
        """
        A very verbose dictionary representation for debugging purposes.
        Not type stable and not intended for roundtrips.

        Parameters
        ----------
        comm:
        exclude:
            A list of attributes which must not be present in the output.

        See also
        --------
        Server.identity
        Client.dump_cluster_state
        """
        info = super()._to_dict(exclude=exclude)
        extra = {
            "transition_log": self.transition_log,
            "log": self.log,
            "tasks": self.tasks,
            "events": self.events,
        }
        info.update(extra)
        extensions = {}
        for name, ex in self.extensions.items():
            if hasattr(ex, "_to_dict"):
                extensions[name] = ex._to_dict()
        info["extensions"] = extensions  # include extension dumps in the result
        return recursive_to_dict(info, exclude=exclude)
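
Extensions opt into the dump by defining `_to_dict`, as the loop above shows. A hedged sketch of a custom scheduler extension (class and attribute names are illustrative, not part of this PR):

    class MyDebugExtension:
        """Illustrative scheduler extension that participates in cluster dumps."""

        def __init__(self, scheduler):
            self.scheduler = scheduler
            self.counters = {"events_seen": 0}

        def _to_dict(self):
            # Plain data returned here is picked up by Scheduler._to_dict
            # via its hasattr(ex, "_to_dict") check.
            return {"counters": dict(self.counters)}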

    def get_worker_service_addr(self, worker, service_name, protocol=False):
        """
        Get the (host, port) address of the named service on the *worker*.
31 changes: 30 additions & 1 deletion distributed/stealing.py
@@ -1,7 +1,10 @@
from __future__ import annotations

import logging
from collections import defaultdict, deque
from math import log2
from time import time
from typing import Any, Container

from tlz import topk
from tornado.ioloop import PeriodicCallback
@@ -12,7 +15,7 @@
from .comm.addressing import get_address_host
from .core import CommClosedError
from .diagnostics.plugin import SchedulerPlugin
from .utils import log_errors
from .utils import log_errors, recursive_to_dict

# Stealing requires multiple network bounces and if successful also task
# submission which may include code serialization. Therefore, be very
@@ -79,6 +82,32 @@ def __init__(self, scheduler):

        self.scheduler.stream_handlers["steal-response"] = self.move_task_confirm

    def _to_dict(self, *, exclude: Container[str] = None) -> dict[str, Any]:
        """
        A very verbose dictionary representation for debugging purposes.
        Not type stable and not intended for roundtrips.

        Parameters
        ----------
        exclude:
            A list of attributes which must not be present in the output.

        See also
        --------
        Client.dump_cluster_state
        """
        return recursive_to_dict(
            {
                "stealable_all": self.stealable_all,
                "stealable": self.stealable,
                "key_stealable": self.key_stealable,
                "in_flight": self.in_flight,
                "in_flight_occupancy": self.in_flight_occupancy,
            },
            exclude=exclude,
        )

    def log(self, msg):
        return self.scheduler.log_event("stealing", msg)
