From 68eabcf24203e2cb3e69ba9c042ce247cf99a578 Mon Sep 17 00:00:00 2001 From: darothen Date: Fri, 13 Jan 2023 14:54:32 -0700 Subject: [PATCH 1/8] Adds basic implementation of load, compute, and persist. --- datatree/datatree.py | 85 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) diff --git a/datatree/datatree.py b/datatree/datatree.py index 9a416d8f..b5fde591 100644 --- a/datatree/datatree.py +++ b/datatree/datatree.py @@ -18,6 +18,7 @@ Optional, Set, Tuple, + TypeVar, Union, overload, ) @@ -62,6 +63,7 @@ T_Path = Union[str, NodePath] +T_DataTree = TypeVar("T_DataTree", bound="DataTree") def _coerce_to_dataset(data: Dataset | DataArray | None) -> Dataset: @@ -1345,3 +1347,86 @@ def to_zarr( def plot(self): raise NotImplementedError + + + def load(self: T_DataTree, **kwargs) -> T_DataTree: + """Manually trigger loading of the data referenced by this collection. + + End-users generally shouldn't need to call this method directly, since + most operations should dispatch to the underlying xarray objects which + this collection contains. There may be use cases where a user wants to + eagerly load data from disk into memory. + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments passed on to ``dask.compute``. + + See Also + -------- + dask.compute + """ + # new_tree = self._copy_node(deep=deep) + # for node in self.subtree: + # new_tree[node.path] = op(node) + + # return new_tree + + # d = {node.path: op(node) for node in self.subtree} + # return DataTree.from_dict(d, name=self.root.name) + + new_datatree_dict = { + node.path: node.ds.load(**kwargs) + for node in self.subtree + } + return DataTree.from_dict(new_datatree_dict) + + + def compute(self: T_DataTree, **kwargs) -> T_DataTree: + """Manually trigger loading of the data referenced by this collection + and return a new DataTree. The original is left unaltered. + + End-users generally shouldn't need to call this method directly, since + most operations should dispatch to the underlying xarray objects which + this collection contains. There may be use cases where a user needs to + work with many file objects on disk. + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments passed on to ``dask.compute``. + + See Also + -------- + dask.compute + """ + new = self.copy(deep=False) + return new.load(**kwargs) + + + def persist(self: T_DataTree, **kwargs) -> T_DataTree: + """Trigger computation in constituent dask arrays. + + Force any data contained in dask arrays to be loaded into memory, where + possible, but keep the data as dask arrays. This is useful when + operating on data with a distributed cluster; if you're using a single + machine with a single pool of memory, consider using ``.compute()`` + instead. + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments passed to ``dask.persist``. + + See Also + -------- + dask.persist + """ + new_datatree_dict = { + node.path: node.ds.persist(**kwargs) + for node in self.subtree + } + return DataTree.from_dict(new_datatree_dict) + + + \ No newline at end of file From c8a16c594e65b53cd0374cc4ec1ca705682d652a Mon Sep 17 00:00:00 2001 From: darothen Date: Fri, 13 Jan 2023 14:57:12 -0700 Subject: [PATCH 2/8] Adds stubs for dask collections API methods. 
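These dunder methods are the hooks dask uses to treat an arbitrary object as a
collection. As a rough sketch of the contract being stubbed out here (the
driver function below is illustrative, not part of this patch), ``dask.compute``
consumes them roughly like so:

```python
# Rough sketch of how dask.compute() drives the collection protocol;
# `tree` is any object implementing these dunder methods.
def compute_collection(tree):
    graph = tree.__dask_graph__()              # merged task graph for every node
    keys = tree.__dask_keys__()                # the outputs to materialize
    graph = tree.__dask_optimize__(graph, keys)  # optimization hook
    schedule = tree.__dask_scheduler__         # default scheduler for this type
    results = schedule(graph, keys)            # execute the graph
    finalize, extra = tree.__dask_postcompute__()
    return finalize(results, *extra)           # rebuild an in-memory object
```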
--- datatree/datatree.py | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/datatree/datatree.py b/datatree/datatree.py index b5fde591..2f268ef8 100644 --- a/datatree/datatree.py +++ b/datatree/datatree.py @@ -1348,7 +1348,6 @@ def to_zarr( def plot(self): raise NotImplementedError - def load(self: T_DataTree, **kwargs) -> T_DataTree: """Manually trigger loading of the data referenced by this collection. @@ -1381,7 +1380,6 @@ def load(self: T_DataTree, **kwargs) -> T_DataTree: } return DataTree.from_dict(new_datatree_dict) - def compute(self: T_DataTree, **kwargs) -> T_DataTree: """Manually trigger loading of the data referenced by this collection and return a new DataTree. The original is left unaltered. @@ -1403,7 +1401,6 @@ def compute(self: T_DataTree, **kwargs) -> T_DataTree: new = self.copy(deep=False) return new.load(**kwargs) - def persist(self: T_DataTree, **kwargs) -> T_DataTree: """Trigger computation in constituent dask arrays. @@ -1428,5 +1425,32 @@ def persist(self: T_DataTree, **kwargs) -> T_DataTree: } return DataTree.from_dict(new_datatree_dict) + def __dask_tokenize__(self): + raise NotImplementedError + + def __dask_graph__(self): + raise NotImplementedError + + def __dask_keys__(self): + raise NotImplementedError + + def __dask_layers__(self): + raise NotImplementedError + + @property + def __dask_optimize__(self): + raise NotImplementedError + + @property + def __dask_scheduler__(self): + raise NotImplementedError + + def __dask_postcompute__(self): + raise NotImplementedError + + def __dask_postpersist__(self): + raise NotImplementedError + + \ No newline at end of file From 71154daba2e14ef006ba4934f16f7931fe703244 Mon Sep 17 00:00:00 2001 From: darothen Date: Fri, 13 Jan 2023 15:19:25 -0700 Subject: [PATCH 3/8] Implements dask dunder methods except for post* --- datatree/datatree.py | 52 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 46 insertions(+), 6 deletions(-) diff --git a/datatree/datatree.py b/datatree/datatree.py index 2f268ef8..8710e087 100644 --- a/datatree/datatree.py +++ b/datatree/datatree.py @@ -1426,24 +1426,64 @@ def persist(self: T_DataTree, **kwargs) -> T_DataTree: return DataTree.from_dict(new_datatree_dict) def __dask_tokenize__(self): - raise NotImplementedError + from dask.base import normalize_token + + # This method should return a value fully representative of the object + # here. ``xarray.Dataset`` implements a method that accomplishes this, + # and ``DataTree`` is just fundamentally defining relationships between + # these ``Dataset``s. So here we re-use the ``Dataset`` tokenization and + # incorporate the ancestry as an additional component (encoded in the + # names of the datasets). 
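+        # NOTE: dask expects tokens to be deterministic (equal trees should
+        # tokenize identically), so each node's token is keyed by its path
+        # rather than by iteration order alone.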
+ + ds_tokens = { + node.path: node.ds.__dask_tokenize__() + for node in self.subtree + } + return normalize_token( + (type(self), ds_tokens) + ) + def __dask_graph__(self): - raise NotImplementedError + graphs = { + node.path: node.ds.__dask_graph__() + for node in self.subtree + } + graphs = {k: v for k, v in graphs.items() if v is not None} + + if not graphs: + return None + else: + try: + from dask.highlevelgraph import HighLevelGraph + + return HighLevelGraph.merge(*graphs.values()) + except ImportError: + from dask import sharedict + + return sharedict.merge(*graphs.values()) def __dask_keys__(self): - raise NotImplementedError + return [ + node.ds.__dask_keys__() + for node in self.subtree + ] def __dask_layers__(self): - raise NotImplementedError + all_keys = self.__dask_keys__() + return sum((all_keys), ()) @property def __dask_optimize__(self): - raise NotImplementedError + import dask.array as da + + return da.Array.__dask_optimize__ @property def __dask_scheduler__(self): - raise NotImplementedError + import dask.array as da + + return da.Array.__dask_scheduler__ def __dask_postcompute__(self): raise NotImplementedError From 721fa8d9cd80751098bf7db7e9940d97dedb8021 Mon Sep 17 00:00:00 2001 From: darothen Date: Fri, 13 Jan 2023 16:06:02 -0700 Subject: [PATCH 4/8] Implements the postcompute/persist dask ops. --- datatree/datatree.py | 95 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 91 insertions(+), 4 deletions(-) diff --git a/datatree/datatree.py b/datatree/datatree.py index 8710e087..6b99448e 100644 --- a/datatree/datatree.py +++ b/datatree/datatree.py @@ -1486,11 +1486,98 @@ def __dask_scheduler__(self): return da.Array.__dask_scheduler__ def __dask_postcompute__(self): - raise NotImplementedError + return self._dask_postcompute, () - def __dask_postpersist__(self): - raise NotImplementedError + def _dask_postcompute( + self: T_DataTree, results: Iterable[DatasetView]) -> T_DataTree: + from dask import is_dask_collection + + datatree_nodes = {} + results_iter = iter(results) + for node in self.subtree: + if is_dask_collection(node.ds): + finalize, args = node.ds.__dask_postcompute__() + # darothen: Are we sure that results_iter is ordered the same as + # self.subtree? + ds = finalize(next(results_iter), *args) + else: + ds = node.ds + datatree_nodes[node.path] = ds + + # We use this to avoid validation at time of object creation + new_root = datatree_nodes[self.path] + return type(self)._construct_direct( + new_root.ds._variables, + new_root.ds._coord_names, + new_root.ds._dims, + new_root.ds._attrs, + new_root.ds._indexes, + new_root.ds._encoding, + new_root._name, + new_root._parent, + new_root._children, + new_root._close + ) + + def __dask_postpersist__(self): + return self._dask_postpersist, () + + def _dask_postpersist( + self: T_DataTree, dsk: Mapping, *, rename: Mapping[str, str] | None = None + ) -> T_DataTree: + from dask import is_dask_collection + from dask.highlevelgraph import HighLevelGraph + from dask.optimization import cull + datatree_nodes = {} - \ No newline at end of file + for node in self.subtree: + if not is_dask_collection(node): + datatree_nodes[node.path] = node.ds + continue + + if isinstance(dsk, HighLevelGraph): + # NOTE(darothen): Implementation based on xarray.Dataset._dask_postpersist(), + # so we preserve the implementation note for future refinement + # TODO: Pin minimum dask version and ensure that can remove this + # note. + # dask >= 2021.3 + # __dask_postpersist__() was called by dask.highlevelgraph. 
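+                # (A HighLevelGraph organizes tasks into named layers, which
+                # is why whole layers can be culled by name below.)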
+ # Don't use dsk.cull(), as we need to prevent partial layers: + # https://github.com/dask/dask/issues/7137 + layers = node.__dask_layers__() + if rename: + layers = [rename.get(k, k) for k in layers] + dsk2 = dsk.cull_layers(layers) + elif rename: # pragma: nocover + # NOTE(darothen): Similar to above we preserve the implementation + # note. + # replace_name_in_key requires dask >= 2021.3. + from dask.base import flatten, replace_name_in_key + + keys = [ + replace_name_in_key(k, rename) for k in flatten(node.__dask_keys__()) + ] + dsk2, _ = cull(dsk, keys) + else: + # __dask_postpersist__() was called by dask.{optimize,persist} + dsk2, _ = cull(dsk, node.__dask_keys__()) + + finalize, args = node.__dask_postpersist__() + kwargs = {"rename": rename} if rename else {} + datatree_nodes[node.path] = finalize(dsk2, *args, **kwargs) + + new_root = datatree_nodes[self.path] + return type(self)._construct_direct( + new_root.ds._variables, + new_root.ds._coord_names, + new_root.ds._dims, + new_root.ds._attrs, + new_root.ds._indexes, + new_root.ds._encoding, + new_root._name, + new_root._parent, + new_root._children, + new_root._close + ) \ No newline at end of file From 29acc55216a2e9fd2d0209690bdc3d9c6779e14a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Jan 2023 23:11:37 +0000 Subject: [PATCH 5/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- datatree/datatree.py | 85 +++++++++++++++++++------------------------- 1 file changed, 36 insertions(+), 49 deletions(-) diff --git a/datatree/datatree.py b/datatree/datatree.py index 6b99448e..b67b81e6 100644 --- a/datatree/datatree.py +++ b/datatree/datatree.py @@ -1350,7 +1350,7 @@ def plot(self): def load(self: T_DataTree, **kwargs) -> T_DataTree: """Manually trigger loading of the data referenced by this collection. - + End-users generally shouldn't need to call this method directly, since most operations should dispatch to the underlying xarray objects which this collection contains. There may be use cases where a user wants to @@ -1366,24 +1366,21 @@ def load(self: T_DataTree, **kwargs) -> T_DataTree: dask.compute """ # new_tree = self._copy_node(deep=deep) - # for node in self.subtree: + # for node in self.subtree: # new_tree[node.path] = op(node) - + # return new_tree # d = {node.path: op(node) for node in self.subtree} # return DataTree.from_dict(d, name=self.root.name) - new_datatree_dict = { - node.path: node.ds.load(**kwargs) - for node in self.subtree - } + new_datatree_dict = {node.path: node.ds.load(**kwargs) for node in self.subtree} return DataTree.from_dict(new_datatree_dict) def compute(self: T_DataTree, **kwargs) -> T_DataTree: """Manually trigger loading of the data referenced by this collection and return a new DataTree. The original is left unaltered. - + End-users generally shouldn't need to call this method directly, since most operations should dispatch to the underlying xarray objects which this collection contains. There may be use cases where a user needs to @@ -1403,7 +1400,7 @@ def compute(self: T_DataTree, **kwargs) -> T_DataTree: def persist(self: T_DataTree, **kwargs) -> T_DataTree: """Trigger computation in constituent dask arrays. - + Force any data contained in dask arrays to be loaded into memory, where possible, but keep the data as dask arrays. 
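Unlike ``compute``, the result stays chunked, so downstream operations can
build on the evaluated blocks instead of re-running the graph from its source.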
This is useful when operating on data with a distributed cluster; if you're using a single @@ -1420,8 +1417,7 @@ def persist(self: T_DataTree, **kwargs) -> T_DataTree: dask.persist """ new_datatree_dict = { - node.path: node.ds.persist(**kwargs) - for node in self.subtree + node.path: node.ds.persist(**kwargs) for node in self.subtree } return DataTree.from_dict(new_datatree_dict) @@ -1432,23 +1428,15 @@ def __dask_tokenize__(self): # here. ``xarray.Dataset`` implements a method that accomplishes this, # and ``DataTree`` is just fundamentally defining relationships between # these ``Dataset``s. So here we re-use the ``Dataset`` tokenization and - # incorporate the ancestry as an additional component (encoded in the + # incorporate the ancestry as an additional component (encoded in the # names of the datasets). - ds_tokens = { - node.path: node.ds.__dask_tokenize__() - for node in self.subtree - } + ds_tokens = {node.path: node.ds.__dask_tokenize__() for node in self.subtree} + + return normalize_token((type(self), ds_tokens)) - return normalize_token( - (type(self), ds_tokens) - ) - def __dask_graph__(self): - graphs = { - node.path: node.ds.__dask_graph__() - for node in self.subtree - } + graphs = {node.path: node.ds.__dask_graph__() for node in self.subtree} graphs = {k: v for k, v in graphs.items() if v is not None} if not graphs: @@ -1464,10 +1452,7 @@ def __dask_graph__(self): return sharedict.merge(*graphs.values()) def __dask_keys__(self): - return [ - node.ds.__dask_keys__() - for node in self.subtree - ] + return [node.ds.__dask_keys__() for node in self.subtree] def __dask_layers__(self): all_keys = self.__dask_keys__() @@ -1489,9 +1474,10 @@ def __dask_postcompute__(self): return self._dask_postcompute, () def _dask_postcompute( - self: T_DataTree, results: Iterable[DatasetView]) -> T_DataTree: + self: T_DataTree, results: Iterable[DatasetView] + ) -> T_DataTree: from dask import is_dask_collection - + datatree_nodes = {} results_iter = iter(results) @@ -1499,12 +1485,12 @@ def _dask_postcompute( if is_dask_collection(node.ds): finalize, args = node.ds.__dask_postcompute__() # darothen: Are we sure that results_iter is ordered the same as - # self.subtree? + # self.subtree? ds = finalize(next(results_iter), *args) else: ds = node.ds datatree_nodes[node.path] = ds - + # We use this to avoid validation at time of object creation new_root = datatree_nodes[self.path] return type(self)._construct_direct( @@ -1517,12 +1503,12 @@ def _dask_postcompute( new_root._name, new_root._parent, new_root._children, - new_root._close + new_root._close, ) def __dask_postpersist__(self): return self._dask_postpersist, () - + def _dask_postpersist( self: T_DataTree, dsk: Mapping, *, rename: Mapping[str, str] | None = None ) -> T_DataTree: @@ -1536,7 +1522,7 @@ def _dask_postpersist( if not is_dask_collection(node): datatree_nodes[node.path] = node.ds continue - + if isinstance(dsk, HighLevelGraph): # NOTE(darothen): Implementation based on xarray.Dataset._dask_postpersist(), # so we preserve the implementation note for future refinement @@ -1550,34 +1536,35 @@ def _dask_postpersist( if rename: layers = [rename.get(k, k) for k in layers] dsk2 = dsk.cull_layers(layers) - elif rename: # pragma: nocover + elif rename: # pragma: nocover # NOTE(darothen): Similar to above we preserve the implementation # note. # replace_name_in_key requires dask >= 2021.3. 
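                # ``rename`` is supplied by dask's graph-manipulation helpers
                # (e.g. ``dask.graph_manipulation.bind``) when layer keys get
                # rewritten, so this node's keys must be renamed before culling.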
from dask.base import flatten, replace_name_in_key keys = [ - replace_name_in_key(k, rename) for k in flatten(node.__dask_keys__()) + replace_name_in_key(k, rename) + for k in flatten(node.__dask_keys__()) ] dsk2, _ = cull(dsk, keys) else: # __dask_postpersist__() was called by dask.{optimize,persist} dsk2, _ = cull(dsk, node.__dask_keys__()) - + finalize, args = node.__dask_postpersist__() kwargs = {"rename": rename} if rename else {} datatree_nodes[node.path] = finalize(dsk2, *args, **kwargs) new_root = datatree_nodes[self.path] return type(self)._construct_direct( - new_root.ds._variables, - new_root.ds._coord_names, - new_root.ds._dims, - new_root.ds._attrs, - new_root.ds._indexes, - new_root.ds._encoding, - new_root._name, - new_root._parent, - new_root._children, - new_root._close - ) \ No newline at end of file + new_root.ds._variables, + new_root.ds._coord_names, + new_root.ds._dims, + new_root.ds._attrs, + new_root.ds._indexes, + new_root.ds._encoding, + new_root._name, + new_root._parent, + new_root._children, + new_root._close, + ) From c9c83d9a2bcfffa9f79a8fa6d79e38278fb18002 Mon Sep 17 00:00:00 2001 From: darothen Date: Fri, 13 Jan 2023 16:11:38 -0700 Subject: [PATCH 6/8] Removes stub / sample code. --- datatree/datatree.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/datatree/datatree.py b/datatree/datatree.py index b67b81e6..d43b3a92 100644 --- a/datatree/datatree.py +++ b/datatree/datatree.py @@ -1365,16 +1365,10 @@ def load(self: T_DataTree, **kwargs) -> T_DataTree: -------- dask.compute """ - # new_tree = self._copy_node(deep=deep) - # for node in self.subtree: - # new_tree[node.path] = op(node) - - # return new_tree - - # d = {node.path: op(node) for node in self.subtree} - # return DataTree.from_dict(d, name=self.root.name) - - new_datatree_dict = {node.path: node.ds.load(**kwargs) for node in self.subtree} + new_datatree_dict = { + node.path: node.ds.load(**kwargs) + for node in self.subtree + } return DataTree.from_dict(new_datatree_dict) def compute(self: T_DataTree, **kwargs) -> T_DataTree: From a2145dd76614deb3964081e8c98f87695eb95791 Mon Sep 17 00:00:00 2001 From: darothen Date: Fri, 13 Jan 2023 16:15:46 -0700 Subject: [PATCH 7/8] Applies fixes from pre-commit --- datatree/datatree.py | 53 ++++++++++++++++++++++++++++++++------------ 1 file changed, 39 insertions(+), 14 deletions(-) diff --git a/datatree/datatree.py b/datatree/datatree.py index d43b3a92..ef2d4024 100644 --- a/datatree/datatree.py +++ b/datatree/datatree.py @@ -18,7 +18,6 @@ Optional, Set, Tuple, - TypeVar, Union, overload, ) @@ -63,7 +62,6 @@ T_Path = Union[str, NodePath] -T_DataTree = TypeVar("T_DataTree", bound="DataTree") def _coerce_to_dataset(data: Dataset | DataArray | None) -> Dataset: @@ -1348,9 +1346,10 @@ def to_zarr( def plot(self): raise NotImplementedError - def load(self: T_DataTree, **kwargs) -> T_DataTree: + def load(self: DataTree, **kwargs) -> DataTree: """Manually trigger loading of the data referenced by this collection. + End-users generally shouldn't need to call this method directly, since most operations should dispatch to the underlying xarray objects which this collection contains. 
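Each node's dataset is loaded via ``xarray.Dataset.load``, so lazy (e.g.
dask-backed) variables become in-memory arrays.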
There may be use cases where a user wants to @@ -1365,16 +1364,14 @@ def load(self: T_DataTree, **kwargs) -> T_DataTree: -------- dask.compute """ - new_datatree_dict = { - node.path: node.ds.load(**kwargs) - for node in self.subtree - } + new_datatree_dict = {node.path: node.ds.load(**kwargs) for node in self.subtree} return DataTree.from_dict(new_datatree_dict) - def compute(self: T_DataTree, **kwargs) -> T_DataTree: + def compute(self: DataTree, **kwargs) -> DataTree: """Manually trigger loading of the data referenced by this collection and return a new DataTree. The original is left unaltered. + End-users generally shouldn't need to call this method directly, since most operations should dispatch to the underlying xarray objects which this collection contains. There may be use cases where a user needs to @@ -1392,9 +1389,10 @@ def compute(self: T_DataTree, **kwargs) -> T_DataTree: new = self.copy(deep=False) return new.load(**kwargs) - def persist(self: T_DataTree, **kwargs) -> T_DataTree: + def persist(self: DataTree, **kwargs) -> DataTree: """Trigger computation in constituent dask arrays. + Force any data contained in dask arrays to be loaded into memory, where possible, but keep the data as dask arrays. This is useful when operating on data with a distributed cluster; if you're using a single @@ -1412,6 +1410,7 @@ def persist(self: T_DataTree, **kwargs) -> T_DataTree: """ new_datatree_dict = { node.path: node.ds.persist(**kwargs) for node in self.subtree + node.path: node.ds.persist(**kwargs) for node in self.subtree } return DataTree.from_dict(new_datatree_dict) @@ -1423,13 +1422,18 @@ def __dask_tokenize__(self): # and ``DataTree`` is just fundamentally defining relationships between # these ``Dataset``s. So here we re-use the ``Dataset`` tokenization and # incorporate the ancestry as an additional component (encoded in the + # incorporate the ancestry as an additional component (encoded in the # names of the datasets). ds_tokens = {node.path: node.ds.__dask_tokenize__() for node in self.subtree} + ds_tokens = {node.path: node.ds.__dask_tokenize__() for node in self.subtree} + + return normalize_token((type(self), ds_tokens)) return normalize_token((type(self), ds_tokens)) def __dask_graph__(self): + graphs = {node.path: node.ds.__dask_graph__() for node in self.subtree} graphs = {node.path: node.ds.__dask_graph__() for node in self.subtree} graphs = {k: v for k, v in graphs.items() if v is not None} @@ -1447,6 +1451,7 @@ def __dask_graph__(self): def __dask_keys__(self): return [node.ds.__dask_keys__() for node in self.subtree] + return [node.ds.__dask_keys__() for node in self.subtree] def __dask_layers__(self): all_keys = self.__dask_keys__() @@ -1467,11 +1472,10 @@ def __dask_scheduler__(self): def __dask_postcompute__(self): return self._dask_postcompute, () - def _dask_postcompute( - self: T_DataTree, results: Iterable[DatasetView] - ) -> T_DataTree: + def _dask_postcompute(self: DataTree, results: Iterable[DatasetView]) -> DataTree: from dask import is_dask_collection + datatree_nodes = {} results_iter = iter(results) @@ -1480,11 +1484,13 @@ def _dask_postcompute( finalize, args = node.ds.__dask_postcompute__() # darothen: Are we sure that results_iter is ordered the same as # self.subtree? + # self.subtree? 
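+                # (Answering the above: ``__dask_keys__`` is built from the
+                # same ``self.subtree`` traversal, and dask hands results back
+                # in key order, so the ordering should line up.)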
ds = finalize(next(results_iter), *args) else: ds = node.ds datatree_nodes[node.path] = ds + # We use this to avoid validation at time of object creation new_root = datatree_nodes[self.path] return type(self)._construct_direct( @@ -1498,14 +1504,16 @@ def _dask_postcompute( new_root._parent, new_root._children, new_root._close, + new_root._close, ) def __dask_postpersist__(self): return self._dask_postpersist, () + def _dask_postpersist( - self: T_DataTree, dsk: Mapping, *, rename: Mapping[str, str] | None = None - ) -> T_DataTree: + self: DataTree, dsk: Mapping, *, rename: Mapping[str, str] | None = None + ) -> DataTree: from dask import is_dask_collection from dask.highlevelgraph import HighLevelGraph from dask.optimization import cull @@ -1517,6 +1525,7 @@ def _dask_postpersist( datatree_nodes[node.path] = node.ds continue + if isinstance(dsk, HighLevelGraph): # NOTE(darothen): Implementation based on xarray.Dataset._dask_postpersist(), # so we preserve the implementation note for future refinement @@ -1530,6 +1539,7 @@ def _dask_postpersist( if rename: layers = [rename.get(k, k) for k in layers] dsk2 = dsk.cull_layers(layers) + elif rename: # pragma: nocover elif rename: # pragma: nocover # NOTE(darothen): Similar to above we preserve the implementation # note. @@ -1539,12 +1549,15 @@ def _dask_postpersist( keys = [ replace_name_in_key(k, rename) for k in flatten(node.__dask_keys__()) + replace_name_in_key(k, rename) + for k in flatten(node.__dask_keys__()) ] dsk2, _ = cull(dsk, keys) else: # __dask_postpersist__() was called by dask.{optimize,persist} dsk2, _ = cull(dsk, node.__dask_keys__()) + finalize, args = node.__dask_postpersist__() kwargs = {"rename": rename} if rename else {} datatree_nodes[node.path] = finalize(dsk2, *args, **kwargs) @@ -1562,3 +1575,15 @@ def _dask_postpersist( new_root._children, new_root._close, ) + + new_root.ds._variables, + new_root.ds._coord_names, + new_root.ds._dims, + new_root.ds._attrs, + new_root.ds._indexes, + new_root.ds._encoding, + new_root._name, + new_root._parent, + new_root._children, + new_root._close, + ) From cf84ed1e5c902e0903d6a13e708957dc833d7a1f Mon Sep 17 00:00:00 2001 From: darothen Date: Fri, 13 Jan 2023 16:22:49 -0700 Subject: [PATCH 8/8] Applies additional pre-commit fixes --- datatree/datatree.py | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/datatree/datatree.py b/datatree/datatree.py index ef2d4024..6dbcfb60 100644 --- a/datatree/datatree.py +++ b/datatree/datatree.py @@ -1410,7 +1410,6 @@ def persist(self: DataTree, **kwargs) -> DataTree: """ new_datatree_dict = { node.path: node.ds.persist(**kwargs) for node in self.subtree - node.path: node.ds.persist(**kwargs) for node in self.subtree } return DataTree.from_dict(new_datatree_dict) @@ -1475,7 +1474,6 @@ def __dask_postcompute__(self): def _dask_postcompute(self: DataTree, results: Iterable[DatasetView]) -> DataTree: from dask import is_dask_collection - datatree_nodes = {} results_iter = iter(results) @@ -1490,7 +1488,6 @@ def _dask_postcompute(self: DataTree, results: Iterable[DatasetView]) -> DataTre ds = node.ds datatree_nodes[node.path] = ds - # We use this to avoid validation at time of object creation new_root = datatree_nodes[self.path] return type(self)._construct_direct( @@ -1504,13 +1501,11 @@ def _dask_postcompute(self: DataTree, results: Iterable[DatasetView]) -> DataTre new_root._parent, new_root._children, new_root._close, - new_root._close, ) def __dask_postpersist__(self): return self._dask_postpersist, () - def _dask_postpersist( 
self: DataTree, dsk: Mapping, *, rename: Mapping[str, str] | None = None ) -> DataTree: @@ -1525,7 +1520,6 @@ def _dask_postpersist( datatree_nodes[node.path] = node.ds continue - if isinstance(dsk, HighLevelGraph): # NOTE(darothen): Implementation based on xarray.Dataset._dask_postpersist(), # so we preserve the implementation note for future refinement @@ -1539,7 +1533,6 @@ def _dask_postpersist( if rename: layers = [rename.get(k, k) for k in layers] dsk2 = dsk.cull_layers(layers) - elif rename: # pragma: nocover elif rename: # pragma: nocover # NOTE(darothen): Similar to above we preserve the implementation # note. @@ -1549,15 +1542,12 @@ def _dask_postpersist( keys = [ replace_name_in_key(k, rename) for k in flatten(node.__dask_keys__()) - replace_name_in_key(k, rename) - for k in flatten(node.__dask_keys__()) ] dsk2, _ = cull(dsk, keys) else: # __dask_postpersist__() was called by dask.{optimize,persist} dsk2, _ = cull(dsk, node.__dask_keys__()) - finalize, args = node.__dask_postpersist__() kwargs = {"rename": rename} if rename else {} datatree_nodes[node.path] = finalize(dsk2, *args, **kwargs) @@ -1575,15 +1565,3 @@ def _dask_postpersist( new_root._children, new_root._close, ) - - new_root.ds._variables, - new_root.ds._coord_names, - new_root.ds._dims, - new_root.ds._attrs, - new_root.ds._indexes, - new_root.ds._encoding, - new_root._name, - new_root._parent, - new_root._children, - new_root._close, - )
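With the postcompute and postpersist hooks in place, a ``DataTree`` should now
round-trip through dask as a first-class collection. A minimal end-to-end
sketch (the chunked dataset and tree layout are illustrative, not taken from
these patches):

```python
# Build a tiny two-node tree of dask-backed Datasets, then exercise the
# collection protocol through dask.compute and the new methods.
import dask
import numpy as np
import xarray as xr
from datatree import DataTree

ds = xr.Dataset({"a": ("x", np.arange(10))}).chunk({"x": 5})
tree = DataTree.from_dict({"/": ds, "/child": ds * 2})

(computed,) = dask.compute(tree)  # __dask_graph__ -> scheduler -> postcompute
persisted = tree.persist()        # chunks evaluated but still dask-backed
in_memory = tree.compute()        # returns a fully loaded DataTree
```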