Support for dask.graph_manipulation (#4965)
* Support dask.graph_manipulation

* fix

* What's New

* [test-upstream]
crusaderky committed Mar 5, 2021
1 parent 66acafa commit 37522e9
Showing 5 changed files with 106 additions and 65 deletions.
4 changes: 3 additions & 1 deletion doc/whats-new.rst
@@ -22,7 +22,9 @@ v0.17.1 (unreleased)

New Features
~~~~~~~~~~~~

- Support for `dask.graph_manipulation
<https://docs.dask.org/en/latest/graph_manipulation.html>`_ (requires dask >=2021.3)
By `Guido Imperiale <https://github.com/crusaderky>`_

Breaking changes
~~~~~~~~~~~~~~~~
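For context, here is a minimal usage sketch of the feature the entry above announces. It is illustrative only (not part of the commit) and assumes dask >= 2021.3 together with an xarray build that contains this change:

```python
import dask.graph_manipulation as gm
import xarray as xr

ds = xr.Dataset({"a": ("x", [1, 2, 3, 4])}).chunk(2)

# clone() rebuilds the collection with brand-new task keys and graph layers
# while leaving the computed values untouched (the new test below checks both).
ds_clone = gm.clone(ds)

assert ds_clone.__dask_graph__().keys() != ds.__dask_graph__().keys()
xr.testing.assert_equal(ds_clone, ds)
```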
8 changes: 4 additions & 4 deletions xarray/core/dataarray.py
@@ -839,15 +839,15 @@ def __dask_scheduler__(self):

def __dask_postcompute__(self):
func, args = self._to_temp_dataset().__dask_postcompute__()
return self._dask_finalize, (func, args, self.name)
return self._dask_finalize, (self.name, func) + args

def __dask_postpersist__(self):
func, args = self._to_temp_dataset().__dask_postpersist__()
return self._dask_finalize, (func, args, self.name)
return self._dask_finalize, (self.name, func) + args

@staticmethod
def _dask_finalize(results, func, args, name):
ds = func(results, *args)
def _dask_finalize(results, name, func, *args, **kwargs):
ds = func(results, *args, **kwargs)
variable = ds._variables.pop(_THIS_ARRAY)
coords = ds._variables
return DataArray(variable, coords, name=name, fastpath=True)
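A rough sketch of the protocol behind this change (not code from the commit): dask calls the finalizer returned by `__dask_postcompute__` as `finalize(results, *extra_args)`, and the `__dask_postpersist__` rebuilder may additionally receive `rename=...`, so packing the extras as a flat `(name, func) + args` tuple lets extra positional and keyword arguments flow straight through to the Dataset layer. Using only the protocol attributes touched in this diff:

```python
import xarray as xr

arr = xr.DataArray([1, 2, 3, 4], dims="x", name="a").chunk(2)

finalize, extra_args = arr.__dask_postcompute__()

# The scheduler exposed by the collection itself (dask's threaded scheduler
# here) computes the chunk values for the collection's keys...
chunks = arr.__dask_scheduler__(dict(arr.__dask_graph__()), arr.__dask_keys__())

# ...and dask then hands them to the finalizer exactly like this:
rebuilt = finalize(chunks, *extra_args)
assert rebuilt.equals(arr.compute())
```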
107 changes: 59 additions & 48 deletions xarray/core/dataset.py
@@ -863,72 +863,83 @@ def __dask_scheduler__(self):
return da.Array.__dask_scheduler__

def __dask_postcompute__(self):
return self._dask_postcompute, ()

def __dask_postpersist__(self):
return self._dask_postpersist, ()

def _dask_postcompute(self, results: "Iterable[Variable]") -> "Dataset":
import dask

info = [
(k, None) + v.__dask_postcompute__()
if dask.is_dask_collection(v)
else (k, v, None, None)
for k, v in self._variables.items()
]
construct_direct_args = (
variables = {}
results_iter = iter(results)

for k, v in self._variables.items():
if dask.is_dask_collection(v):
rebuild, args = v.__dask_postcompute__()
v = rebuild(next(results_iter), *args)
variables[k] = v

return Dataset._construct_direct(
variables,
self._coord_names,
self._dims,
self._attrs,
self._indexes,
self._encoding,
self._close,
)
return self._dask_postcompute, (info, construct_direct_args)

def __dask_postpersist__(self):
import dask
def _dask_postpersist(
self, dsk: Mapping, *, rename: Mapping[str, str] = None
) -> "Dataset":
from dask import is_dask_collection
from dask.highlevelgraph import HighLevelGraph
from dask.optimization import cull

info = [
(k, None, v.__dask_keys__()) + v.__dask_postpersist__()
if dask.is_dask_collection(v)
else (k, v, None, None, None)
for k, v in self._variables.items()
]
construct_direct_args = (
variables = {}

for k, v in self._variables.items():
if not is_dask_collection(v):
variables[k] = v
continue

if isinstance(dsk, HighLevelGraph):
# dask >= 2021.3
# __dask_postpersist__() was called by dask.highlevelgraph.
# Don't use dsk.cull(), as we need to prevent partial layers:
# https://github.com/dask/dask/issues/7137
layers = v.__dask_layers__()
if rename:
layers = [rename.get(k, k) for k in layers]
dsk2 = dsk.cull_layers(layers)
elif rename: # pragma: nocover
# At the moment of writing, this is only for forward compatibility.
# replace_name_in_key requires dask >= 2021.3.
from dask.base import flatten, replace_name_in_key

keys = [
replace_name_in_key(k, rename) for k in flatten(v.__dask_keys__())
]
dsk2, _ = cull(dsk, keys)
else:
# __dask_postpersist__() was called by dask.optimize or dask.persist
dsk2, _ = cull(dsk, v.__dask_keys__())

rebuild, args = v.__dask_postpersist__()
# rename was added in dask 2021.3
kwargs = {"rename": rename} if rename else {}
variables[k] = rebuild(dsk2, *args, **kwargs)

return Dataset._construct_direct(
variables,
self._coord_names,
self._dims,
self._attrs,
self._indexes,
self._encoding,
self._close,
)
return self._dask_postpersist, (info, construct_direct_args)

@staticmethod
def _dask_postcompute(results, info, construct_direct_args):
variables = {}
results_iter = iter(results)
for k, v, rebuild, rebuild_args in info:
if v is None:
variables[k] = rebuild(next(results_iter), *rebuild_args)
else:
variables[k] = v

final = Dataset._construct_direct(variables, *construct_direct_args)
return final

@staticmethod
def _dask_postpersist(dsk, info, construct_direct_args):
from dask.optimization import cull

variables = {}
# postpersist is called in both dask.optimize and dask.persist
# When persisting, we want to filter out unrelated keys for
# each Variable's task graph.
for k, v, dask_keys, rebuild, rebuild_args in info:
if v is None:
dsk2, _ = cull(dsk, dask_keys)
variables[k] = rebuild(dsk2, *rebuild_args)
else:
variables[k] = v

return Dataset._construct_direct(variables, *construct_direct_args)

def compute(self, **kwargs) -> "Dataset":
"""Manually trigger loading and/or computation of this dataset's data
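To make the two code paths above concrete, here is a hedged sketch (not from the commit) of how the rebuilder returned by `__dask_postpersist__` gets called; the layer names in the final comment are made up:

```python
import xarray as xr

ds = xr.Dataset({"a": ("x", [1, 2, 3, 4])}).chunk(2)

rebuild, extra_args = ds.__dask_postpersist__()

# dask.persist() / dask.optimize() pass a plain mapping (normally one whose
# values have already been computed or rewritten); round-tripping the original
# graph is enough to show the signature:
ds2 = rebuild(dict(ds.__dask_graph__()), *extra_args)
xr.testing.assert_equal(ds2, ds)

# dask.graph_manipulation (dask >= 2021.3) instead passes a HighLevelGraph plus
# a key-renaming map, roughly:
#     rebuild(new_hlg, *extra_args, rename={"old-layer-name": "new-layer-name"})
# which is why the rewritten method takes a keyword-only ``rename`` argument.
```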
17 changes: 5 additions & 12 deletions xarray/core/variable.py
@@ -531,22 +531,15 @@ def __dask_scheduler__(self):

def __dask_postcompute__(self):
array_func, array_args = self._data.__dask_postcompute__()
return (
self._dask_finalize,
(array_func, array_args, self._dims, self._attrs, self._encoding),
)
return self._dask_finalize, (array_func,) + array_args

def __dask_postpersist__(self):
array_func, array_args = self._data.__dask_postpersist__()
return (
self._dask_finalize,
(array_func, array_args, self._dims, self._attrs, self._encoding),
)
return self._dask_finalize, (array_func,) + array_args

@staticmethod
def _dask_finalize(results, array_func, array_args, dims, attrs, encoding):
data = array_func(results, *array_args)
return Variable(dims, data, attrs=attrs, encoding=encoding)
def _dask_finalize(self, results, array_func, *args, **kwargs):
data = array_func(results, *args, **kwargs)
return Variable(self._dims, data, attrs=self._attrs, encoding=self._encoding)

@property
def values(self):
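One consequence of dropping `@staticmethod` above is that the finalizer dask receives is now a bound method, so `dims`, `attrs`, and `encoding` travel on `self` instead of in the extra-args tuple. A small sketch, illustrative rather than taken from the commit:

```python
import xarray as xr

v = xr.Variable("x", [1, 2, 3, 4]).chunk(2)

finalize, extra_args = v.__dask_postpersist__()
assert finalize.__self__ is v  # bound to this Variable

# dask.persist() would call it with the persisted graph; the original graph
# works just as well for illustration:
v2 = finalize(dict(v.__dask_graph__()), *extra_args)
assert isinstance(v2, xr.Variable) and v2.chunks == v.chunks
```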
35 changes: 35 additions & 0 deletions xarray/tests/test_dask.py
@@ -1599,3 +1599,38 @@ def test_optimize():
arr = xr.DataArray(a).chunk(5)
(arr2,) = dask.optimize(arr)
arr2.compute()


# The graph_manipulation module has been in dask since 2021.2, but it only
# became usable with xarray in 2021.3
@pytest.mark.skipif(LooseVersion(dask.__version__) <= "2021.02.0", reason="new module")
def test_graph_manipulation():
"""dask.graph_manipulation passes an optional parameter, "rename", to the rebuilder
function returned by __dask_postperist__; also, the dsk passed to the rebuilder is
a HighLevelGraph whereas with dask.persist() and dask.optimize() it's a plain dict.
"""
import dask.graph_manipulation as gm

v = Variable(["x"], [1, 2]).chunk(-1).chunk(1) * 2
da = DataArray(v)
ds = Dataset({"d1": v[0], "d2": v[1], "d3": ("x", [3, 4])})

v2, da2, ds2 = gm.clone(v, da, ds)

assert_equal(v2, v)
assert_equal(da2, da)
assert_equal(ds2, ds)

for a, b in ((v, v2), (da, da2), (ds, ds2)):
assert a.__dask_layers__() != b.__dask_layers__()
assert len(a.__dask_layers__()) == len(b.__dask_layers__())
assert a.__dask_graph__().keys() != b.__dask_graph__().keys()
assert len(a.__dask_graph__()) == len(b.__dask_graph__())
assert a.__dask_graph__().layers.keys() != b.__dask_graph__().layers.keys()
assert len(a.__dask_graph__().layers) == len(b.__dask_graph__().layers)

# Above we performed a slice operation; adding the two slices back together creates
# a diamond-shaped dependency graph, which in turn will trigger a collision in layer
# names if we were to use HighLevelGraph.cull() instead of
# HighLevelGraph.cull_layers() in Dataset.__dask_postpersist__().
assert_equal(ds2.d1 + ds2.d2, ds.d1 + ds.d2)
