Implement dask-specific methods #196
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1345,3 +1345,223 @@ def to_zarr( | |
|
||
def plot(self): | ||
raise NotImplementedError | ||
|
||
def load(self: DataTree, **kwargs) -> DataTree: | ||
"""Manually trigger loading of the data referenced by this collection. | ||
|
||
|
||
End-users generally shouldn't need to call this method directly, since | ||
most operations should dispatch to the underlying xarray objects which | ||
this collection contains. There may be use cases where a user wants to | ||
eagerly load data from disk into memory. | ||
|
||
Parameters | ||
---------- | ||
**kwargs : dict | ||
Additional keyword arguments passed on to ``dask.compute``. | ||
|
||
See Also | ||
-------- | ||
dask.compute | ||
""" | ||
new_datatree_dict = {node.path: node.ds.load(**kwargs) for node in self.subtree} | ||
return DataTree.from_dict(new_datatree_dict) | ||
|
||
def compute(self: DataTree, **kwargs) -> DataTree: | ||
"""Manually trigger loading of the data referenced by this collection | ||
and return a new DataTree. The original is left unaltered. | ||
|
||
|
||
End-users generally shouldn't need to call this method directly, since | ||
most operations should dispatch to the underlying xarray objects which | ||
this collection contains. There may be use cases where a user needs to | ||
work with many file objects on disk. | ||
|
||
Parameters | ||
---------- | ||
**kwargs : dict | ||
Additional keyword arguments passed on to ``dask.compute``. | ||
|
||
See Also | ||
-------- | ||
dask.compute | ||
""" | ||
new = self.copy(deep=False) | ||
return new.load(**kwargs) | ||
|
||
def persist(self: DataTree, **kwargs) -> DataTree: | ||
"""Trigger computation in constituent dask arrays. | ||
|
||
|
||
Force any data contained in dask arrays to be loaded into memory, where | ||
possible, but keep the data as dask arrays. This is useful when | ||
operating on data with a distributed cluster; if you're using a single | ||
machine with a single pool of memory, consider using ``.compute()`` | ||
instead. | ||
|
||
Parameters | ||
---------- | ||
**kwargs : dict | ||
Additional keyword arguments passed to ``dask.persist``. | ||
|
||
See Also | ||
-------- | ||
dask.persist | ||
""" | ||
new_datatree_dict = { | ||
node.path: node.ds.persist(**kwargs) for node in self.subtree | ||
} | ||
return DataTree.from_dict(new_datatree_dict) | ||
Comment on lines +1411 to +1414
The comment on |
||
|
||
def __dask_tokenize__(self): | ||
from dask.base import normalize_token | ||
|
||
# This method should return a value fully representative of the object | ||
# here. ``xarray.Dataset`` implements a method that accomplishes this, | ||
# and ``DataTree`` is just fundamentally defining relationships between | ||
# these ``Dataset``s. So here we re-use the ``Dataset`` tokenization and | ||
# incorporate the ancestry as an additional component (encoded in the | ||
# incorporate the ancestry as an additional component (encoded in the | ||
# names of the datasets). | ||
|
||
ds_tokens = {node.path: node.ds.__dask_tokenize__() for node in self.subtree} | ||
ds_tokens = {node.path: node.ds.__dask_tokenize__() for node in self.subtree} | ||
|
||
return normalize_token((type(self), ds_tokens)) | ||
|
||
return normalize_token((type(self), ds_tokens)) | ||
Comment on lines +1427 to +1432
Unintentional repetition of lines? The double |
||
|
||
def __dask_graph__(self): | ||
graphs = {node.path: node.ds.__dask_graph__() for node in self.subtree} | ||
graphs = {node.path: node.ds.__dask_graph__() for node in self.subtree} | ||
Comment on lines +1435 to +1436
More unintentional repetition? |
||
graphs = {k: v for k, v in graphs.items() if v is not None} | ||
|
||
if not graphs: | ||
return None | ||
else: | ||
try: | ||
from dask.highlevelgraph import HighLevelGraph | ||
|
||
return HighLevelGraph.merge(*graphs.values()) | ||
except ImportError: | ||
from dask import sharedict | ||
|
||
return sharedict.merge(*graphs.values()) | ||
|
||
def __dask_keys__(self): | ||
return [node.ds.__dask_keys__() for node in self.subtree] | ||
return [node.ds.__dask_keys__() for node in self.subtree] | ||
Comment on lines +1452 to +1453
And here |
||
|
||
def __dask_layers__(self): | ||
all_keys = self.__dask_keys__() | ||
return sum((all_keys), ()) | ||
|
||
@property | ||
def __dask_optimize__(self): | ||
import dask.array as da | ||
|
||
return da.Array.__dask_optimize__ | ||
|
||
@property | ||
def __dask_scheduler__(self): | ||
import dask.array as da | ||
|
||
return da.Array.__dask_scheduler__ | ||
|
||
def __dask_postcompute__(self): | ||
return self._dask_postcompute, () | ||
|
||
def _dask_postcompute(self: DataTree, results: Iterable[DatasetView]) -> DataTree: | ||
from dask import is_dask_collection | ||
|
||
datatree_nodes = {} | ||
results_iter = iter(results) | ||
|
||
for node in self.subtree: | ||
if is_dask_collection(node.ds): | ||
finalize, args = node.ds.__dask_postcompute__() | ||
# darothen: Are we sure that results_iter is ordered the same as | ||
# self.subtree? | ||
# self.subtree? | ||
Comment on lines +1483 to +1485
Where does the iterable of |
||
ds = finalize(next(results_iter), *args) | ||
else: | ||
ds = node.ds | ||
datatree_nodes[node.path] = ds | ||
|
||
# We use this to avoid validation at time of object creation | ||
new_root = datatree_nodes[self.path] | ||
return type(self)._construct_direct( | ||
new_root.ds._variables, | ||
new_root.ds._coord_names, | ||
new_root.ds._dims, | ||
new_root.ds._attrs, | ||
new_root.ds._indexes, | ||
new_root.ds._encoding, | ||
new_root._name, | ||
new_root._parent, | ||
new_root._children, | ||
new_root._close, | ||
) | ||
|
||
def __dask_postpersist__(self): | ||
return self._dask_postpersist, () | ||
|
||
def _dask_postpersist( | ||
self: DataTree, dsk: Mapping, *, rename: Mapping[str, str] | None = None | ||
) -> DataTree: | ||
from dask import is_dask_collection | ||
from dask.highlevelgraph import HighLevelGraph | ||
from dask.optimization import cull | ||
|
||
datatree_nodes = {} | ||
|
||
for node in self.subtree: | ||
if not is_dask_collection(node): | ||
datatree_nodes[node.path] = node.ds | ||
continue | ||
|
||
if isinstance(dsk, HighLevelGraph): | ||
# NOTE(darothen): Implementation based on xarray.Dataset._dask_postpersist(), | ||
# so we preserve the implementation note for future refinement | ||
# TODO: Pin minimum dask version and ensure that can remove this | ||
# note. | ||
# dask >= 2021.3 | ||
# __dask_postpersist__() was called by dask.highlevelgraph. | ||
# Don't use dsk.cull(), as we need to prevent partial layers: | ||
# https://github.com/dask/dask/issues/7137 | ||
layers = node.__dask_layers__() | ||
if rename: | ||
layers = [rename.get(k, k) for k in layers] | ||
dsk2 = dsk.cull_layers(layers) | ||
elif rename: # pragma: nocover | ||
# NOTE(darothen): Similar to above we preserve the implementation | ||
# note. | ||
# replace_name_in_key requires dask >= 2021.3. | ||
from dask.base import flatten, replace_name_in_key | ||
|
||
keys = [ | ||
replace_name_in_key(k, rename) | ||
for k in flatten(node.__dask_keys__()) | ||
] | ||
dsk2, _ = cull(dsk, keys) | ||
else: | ||
# __dask_postpersist__() was called by dask.{optimize,persist} | ||
dsk2, _ = cull(dsk, node.__dask_keys__()) | ||
|
||
finalize, args = node.__dask_postpersist__() | ||
kwargs = {"rename": rename} if rename else {} | ||
datatree_nodes[node.path] = finalize(dsk2, *args, **kwargs) | ||
|
||
new_root = datatree_nodes[self.path] | ||
return type(self)._construct_direct( | ||
new_root.ds._variables, | ||
new_root.ds._coord_names, | ||
new_root.ds._dims, | ||
new_root.ds._attrs, | ||
new_root.ds._indexes, | ||
new_root.ds._encoding, | ||
new_root._name, | ||
new_root._parent, | ||
new_root._children, | ||
new_root._close, | ||
) |
I'm not sure if this will have the behavior you intend: `DataTree.from_dict` will construct a completely new tree object, and then you are inserting whatever you get when you call `DatasetView.load()`. It's not altering `self` in-place.

Also I think it would be worth double-checking that `DatasetView.load()` does what you expect too with regard to new objects / copying - I never really thought about that case when I wrote `DatasetView`.

If you want to return the same tree but with all the data loaded I think you need to alter the current tree in-place instead of creating a new one, i.e. something like
though this might not fail gracefully...
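
The distinction raised in this comment (building a new tree versus altering the existing one in-place) can be illustrated with a minimal, self-contained sketch. Everything here is an assumption made for illustration: the `Node` class, its `loaded` flag, and the `load_as_new_tree` / `load_in_place` methods are hypothetical stand-ins and are not the `DataTree` or `DatasetView` API, nor the snippet the reviewer originally suggested.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical stand-in for a tree node (NOT the real DataTree class)."""

    name: str
    data: list  # pretend this is lazily loaded data
    loaded: bool = False
    children: dict[str, "Node"] = field(default_factory=dict)

    def subtree(self):
        """Yield this node and all of its descendants, depth-first."""
        yield self
        for child in self.children.values():
            yield from child.subtree()

    def load_as_new_tree(self) -> "Node":
        """Rebuild the tree from scratch; ``self`` is left untouched."""
        new = Node(self.name, list(self.data), loaded=True)
        for key, child in self.children.items():
            new.children[key] = child.load_as_new_tree()
        return new

    def load_in_place(self) -> "Node":
        """Mark every node of the existing tree as loaded and return ``self``."""
        for node in self.subtree():
            node.loaded = True
        return self


root = Node("root", [1, 2], children={"a": Node("a", [3])})

# Rebuilding returns a different object and does not touch the original.
rebuilt = root.load_as_new_tree()
root_loaded_after_rebuild = root.loaded

# In-place loading mutates the original tree and returns that same object.
same = root.load_in_place()
```

In this sketch, `rebuilt` is a separate object and `root` stays unloaded until `load_in_place` is called, which mirrors the concern that a `load` built on `DataTree.from_dict` would behave like `compute` rather than modifying `self`.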