From 37522e991a32ee3c0ad1a5ff8afe8e3eb1885550 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 5 Mar 2021 09:24:14 +0000 Subject: [PATCH] Support for dask.graph_manipulation (#4965) * Support dask.graph_manipulation * fix * What's New * [test-upstream] --- doc/whats-new.rst | 4 +- xarray/core/dataarray.py | 8 +-- xarray/core/dataset.py | 107 +++++++++++++++++++++----------------- xarray/core/variable.py | 17 ++---- xarray/tests/test_dask.py | 35 +++++++++++++ 5 files changed, 106 insertions(+), 65 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 7c3d57f4fe8..9e59fdc5b35 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -22,7 +22,9 @@ v0.17.1 (unreleased) New Features ~~~~~~~~~~~~ - +- Support for `dask.graph_manipulation + `_ (requires dask >=2021.3) + By `Guido Imperiale `_ Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index e6209b0604b..dd871eb21bc 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -839,15 +839,15 @@ def __dask_scheduler__(self): def __dask_postcompute__(self): func, args = self._to_temp_dataset().__dask_postcompute__() - return self._dask_finalize, (func, args, self.name) + return self._dask_finalize, (self.name, func) + args def __dask_postpersist__(self): func, args = self._to_temp_dataset().__dask_postpersist__() - return self._dask_finalize, (func, args, self.name) + return self._dask_finalize, (self.name, func) + args @staticmethod - def _dask_finalize(results, func, args, name): - ds = func(results, *args) + def _dask_finalize(results, name, func, *args, **kwargs): + ds = func(results, *args, **kwargs) variable = ds._variables.pop(_THIS_ARRAY) coords = ds._variables return DataArray(variable, coords, name=name, fastpath=True) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index cbc30dddda9..a4001c747bb 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -863,15 +863,25 @@ def __dask_scheduler__(self): return da.Array.__dask_scheduler__ def __dask_postcompute__(self): + return self._dask_postcompute, () + + def __dask_postpersist__(self): + return self._dask_postpersist, () + + def _dask_postcompute(self, results: "Iterable[Variable]") -> "Dataset": import dask - info = [ - (k, None) + v.__dask_postcompute__() - if dask.is_dask_collection(v) - else (k, v, None, None) - for k, v in self._variables.items() - ] - construct_direct_args = ( + variables = {} + results_iter = iter(results) + + for k, v in self._variables.items(): + if dask.is_dask_collection(v): + rebuild, args = v.__dask_postcompute__() + v = rebuild(next(results_iter), *args) + variables[k] = v + + return Dataset._construct_direct( + variables, self._coord_names, self._dims, self._attrs, @@ -879,18 +889,50 @@ def __dask_postcompute__(self): self._encoding, self._close, ) - return self._dask_postcompute, (info, construct_direct_args) - def __dask_postpersist__(self): - import dask + def _dask_postpersist( + self, dsk: Mapping, *, rename: Mapping[str, str] = None + ) -> "Dataset": + from dask import is_dask_collection + from dask.highlevelgraph import HighLevelGraph + from dask.optimization import cull - info = [ - (k, None, v.__dask_keys__()) + v.__dask_postpersist__() - if dask.is_dask_collection(v) - else (k, v, None, None, None) - for k, v in self._variables.items() - ] - construct_direct_args = ( + variables = {} + + for k, v in self._variables.items(): + if not is_dask_collection(v): + variables[k] = v + continue + + if isinstance(dsk, HighLevelGraph): + # dask >= 2021.3 + # __dask_postpersist__() was called by dask.highlevelgraph. + # Don't use dsk.cull(), as we need to prevent partial layers: + # https://github.com/dask/dask/issues/7137 + layers = v.__dask_layers__() + if rename: + layers = [rename.get(k, k) for k in layers] + dsk2 = dsk.cull_layers(layers) + elif rename: # pragma: nocover + # At the moment of writing, this is only for forward compatibility. + # replace_name_in_key requires dask >= 2021.3. + from dask.base import flatten, replace_name_in_key + + keys = [ + replace_name_in_key(k, rename) for k in flatten(v.__dask_keys__()) + ] + dsk2, _ = cull(dsk, keys) + else: + # __dask_postpersist__() was called by dask.optimize or dask.persist + dsk2, _ = cull(dsk, v.__dask_keys__()) + + rebuild, args = v.__dask_postpersist__() + # rename was added in dask 2021.3 + kwargs = {"rename": rename} if rename else {} + variables[k] = rebuild(dsk2, *args, **kwargs) + + return Dataset._construct_direct( + variables, self._coord_names, self._dims, self._attrs, @@ -898,37 +940,6 @@ def __dask_postpersist__(self): self._encoding, self._close, ) - return self._dask_postpersist, (info, construct_direct_args) - - @staticmethod - def _dask_postcompute(results, info, construct_direct_args): - variables = {} - results_iter = iter(results) - for k, v, rebuild, rebuild_args in info: - if v is None: - variables[k] = rebuild(next(results_iter), *rebuild_args) - else: - variables[k] = v - - final = Dataset._construct_direct(variables, *construct_direct_args) - return final - - @staticmethod - def _dask_postpersist(dsk, info, construct_direct_args): - from dask.optimization import cull - - variables = {} - # postpersist is called in both dask.optimize and dask.persist - # When persisting, we want to filter out unrelated keys for - # each Variable's task graph. - for k, v, dask_keys, rebuild, rebuild_args in info: - if v is None: - dsk2, _ = cull(dsk, dask_keys) - variables[k] = rebuild(dsk2, *rebuild_args) - else: - variables[k] = v - - return Dataset._construct_direct(variables, *construct_direct_args) def compute(self, **kwargs) -> "Dataset": """Manually trigger loading and/or computation of this dataset's data diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 9b70f721689..c59cbf1f3e4 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -531,22 +531,15 @@ def __dask_scheduler__(self): def __dask_postcompute__(self): array_func, array_args = self._data.__dask_postcompute__() - return ( - self._dask_finalize, - (array_func, array_args, self._dims, self._attrs, self._encoding), - ) + return self._dask_finalize, (array_func,) + array_args def __dask_postpersist__(self): array_func, array_args = self._data.__dask_postpersist__() - return ( - self._dask_finalize, - (array_func, array_args, self._dims, self._attrs, self._encoding), - ) + return self._dask_finalize, (array_func,) + array_args - @staticmethod - def _dask_finalize(results, array_func, array_args, dims, attrs, encoding): - data = array_func(results, *array_args) - return Variable(dims, data, attrs=attrs, encoding=encoding) + def _dask_finalize(self, results, array_func, *args, **kwargs): + data = array_func(results, *args, **kwargs) + return Variable(self._dims, data, attrs=self._attrs, encoding=self._encoding) @property def values(self): diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 8220c8b83dc..908a959db45 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1599,3 +1599,38 @@ def test_optimize(): arr = xr.DataArray(a).chunk(5) (arr2,) = dask.optimize(arr) arr2.compute() + + +# The graph_manipulation module is in dask since 2021.2 but it became usable with +# xarray only since 2021.3 +@pytest.mark.skipif(LooseVersion(dask.__version__) <= "2021.02.0", reason="new module") +def test_graph_manipulation(): + """dask.graph_manipulation passes an optional parameter, "rename", to the rebuilder + function returned by __dask_postperist__; also, the dsk passed to the rebuilder is + a HighLevelGraph whereas with dask.persist() and dask.optimize() it's a plain dict. + """ + import dask.graph_manipulation as gm + + v = Variable(["x"], [1, 2]).chunk(-1).chunk(1) * 2 + da = DataArray(v) + ds = Dataset({"d1": v[0], "d2": v[1], "d3": ("x", [3, 4])}) + + v2, da2, ds2 = gm.clone(v, da, ds) + + assert_equal(v2, v) + assert_equal(da2, da) + assert_equal(ds2, ds) + + for a, b in ((v, v2), (da, da2), (ds, ds2)): + assert a.__dask_layers__() != b.__dask_layers__() + assert len(a.__dask_layers__()) == len(b.__dask_layers__()) + assert a.__dask_graph__().keys() != b.__dask_graph__().keys() + assert len(a.__dask_graph__()) == len(b.__dask_graph__()) + assert a.__dask_graph__().layers.keys() != b.__dask_graph__().layers.keys() + assert len(a.__dask_graph__().layers) == len(b.__dask_graph__().layers) + + # Above we performed a slice operation; adding the two slices back together creates + # a diamond-shaped dependency graph, which in turn will trigger a collision in layer + # names if we were to use HighLevelGraph.cull() instead of + # HighLevelGraph.cull_layers() in Dataset.__dask_postpersist__(). + assert_equal(ds2.d1 + ds2.d2, ds.d1 + ds.d2)