Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose dataset reduce operations #10

Merged
merged 9 commits into from
Aug 25, 2021
117 changes: 74 additions & 43 deletions datatree/datatree.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations
import functools
import textwrap
import inspect

from typing import Mapping, Hashable, Union, List, Any, Callable, Iterable, Dict

Expand All @@ -12,6 +11,9 @@
from xarray.core.variable import Variable
from xarray.core.combine import merge
from xarray.core import dtypes, utils
from xarray.core.common import DataWithCoords
from xarray.core.arithmetic import DatasetArithmetic
from xarray.core.ops import REDUCE_METHODS, NAN_REDUCE_METHODS, NAN_CUM_METHODS

from .treenode import TreeNode, PathType, _init_single_treenode

Expand Down Expand Up @@ -46,7 +48,8 @@ def map_over_subtree(func):
The function will be applied to any dataset stored in this node, as well as any dataset stored in any of the
descendant nodes. The returned tree will have the same structure as the original subtree.

func needs to return a Dataset in order to rebuild the subtree.
func needs to return a Dataset, DataArray, or None in order to be able to rebuild the subtree after mapping, as each
result will be assigned to its respective node of new tree via `DataTree.__setitem__`.

Parameters
----------
Expand Down Expand Up @@ -132,7 +135,6 @@ def attrs(self):
else:
raise AttributeError("property is not defined for a node with no data")


@property
def nbytes(self) -> int:
return sum(node.ds.nbytes for node in self.subtree_nodes)
Expand Down Expand Up @@ -203,10 +205,33 @@ def imag(self):

_MAPPED_DOCSTRING_ADDENDUM = textwrap.fill("This method was copied from xarray.Dataset, but has been altered to "
"call the method on the Datasets stored in every node of the subtree. "
"See the `map_over_subtree` decorator for more details.", width=117)


def _wrap_then_attach_to_cls(cls_dict, methods_to_expose, wrap_func=None):
"See the `map_over_subtree` function for more details.", width=117)

# TODO equals, broadcast_equals etc.
# TODO do dask-related private methods need to be exposed?
_DATASET_DASK_METHODS_TO_MAP = ['load', 'compute', 'persist', 'unify_chunks', 'chunk', 'map_blocks']
_DATASET_METHODS_TO_MAP = ['copy', 'as_numpy', '__copy__', '__deepcopy__', 'set_coords', 'reset_coords', 'info',
'isel', 'sel', 'head', 'tail', 'thin', 'broadcast_like', 'reindex_like',
'reindex', 'interp', 'interp_like', 'rename', 'rename_dims', 'rename_vars',
'swap_dims', 'expand_dims', 'set_index', 'reset_index', 'reorder_levels', 'stack',
'unstack', 'update', 'merge', 'drop_vars', 'drop_sel', 'drop_isel', 'drop_dims',
'transpose', 'dropna', 'fillna', 'interpolate_na', 'ffill', 'bfill', 'combine_first',
'reduce', 'map', 'assign', 'diff', 'shift', 'roll', 'sortby', 'quantile', 'rank',
'differentiate', 'integrate', 'cumulative_integrate', 'filter_by_attrs', 'polyfit',
'pad', 'idxmin', 'idxmax', 'argmin', 'argmax', 'query', 'curvefit']
# TODO unsure if these are called by external functions or not?
_DATASET_OPS_TO_MAP = ['_unary_op', '_binary_op', '_inplace_binary_op']
_ALL_DATASET_METHODS_TO_MAP = _DATASET_DASK_METHODS_TO_MAP + _DATASET_METHODS_TO_MAP + _DATASET_OPS_TO_MAP

_DATA_WITH_COORDS_METHODS_TO_MAP = ['squeeze', 'clip', 'assign_coords', 'where', 'close', 'isnull', 'notnull',
'isin', 'astype']

# TODO NUM_BINARY_OPS apparently aren't defined on DatasetArithmetic, and don't appear to be injected anywhere...
#['__array_ufunc__'] \
_ARITHMETIC_METHODS_TO_WRAP = REDUCE_METHODS + NAN_REDUCE_METHODS + NAN_CUM_METHODS


def _wrap_then_attach_to_cls(target_cls_dict, source_cls, methods_to_set, wrap_func=None):
"""
Attach given methods on a class, and optionally wrap each method first. (i.e. with map_over_subtree)

Expand All @@ -219,23 +244,32 @@ def method_name(self, *args, **kwargs):

Parameters
----------
cls_dict
The __dict__ attribute of a class, which can also be accessed by calling vars() from within that classes'
definition.
methods_to_expose : Iterable[Tuple[str, callable]]
The method names and definitions supplied as a list of (method_name_string, method) pairs.\
target_cls_dict : MappingProxy
The __dict__ attribute of the class which we want the methods to be added to. (The __dict__ attribute can also
be accessed by calling vars() from within that classes' definition.) This will be updated by this function.
source_cls : class
Class object from which we want to copy methods (and optionally wrap them). Should be the actual class object
(or instance), not just the __dict__.
methods_to_set : Iterable[str]
    The names of the methods to copy, supplied as an iterable of method-name strings.
    Each name is looked up on source_cls via getattr.
wrap_func : callable, optional
Function to decorate each method with. Must have the same return type as the method.
"""
for method_name, method in methods_to_expose:
wrapped_method = wrap_func(method) if wrap_func is not None else method
cls_dict[method_name] = wrapped_method

# TODO do we really need this for ops like __add__?
# Add a line to the method's docstring explaining how it's been mapped
method_docstring = method.__doc__
if method_docstring is not None:
updated_method_docstring = method_docstring.replace('\n', _MAPPED_DOCSTRING_ADDENDUM, 1)
setattr(cls_dict[method_name], '__doc__', updated_method_docstring)
for method_name in methods_to_set:
orig_method = getattr(source_cls, method_name)
wrapped_method = wrap_func(orig_method) if wrap_func is not None else orig_method
target_cls_dict[method_name] = wrapped_method

if wrap_func is map_over_subtree:
# Add a paragraph to the method's docstring explaining how it's been mapped
orig_method_docstring = orig_method.__doc__
if orig_method_docstring is not None:
if '\n' in orig_method_docstring:
new_method_docstring = orig_method_docstring.replace('\n', _MAPPED_DOCSTRING_ADDENDUM, 1)
else:
new_method_docstring = orig_method_docstring + f"\n\n{_MAPPED_DOCSTRING_ADDENDUM}"
setattr(target_cls_dict[method_name], '__doc__', new_method_docstring)


class MappedDatasetMethodsMixin:
Expand All @@ -244,33 +278,28 @@ class MappedDatasetMethodsMixin:

Every method wrapped here needs to have a return value of Dataset or DataArray in order to construct a new tree.
"""
# TODO equals, broadcast_equals etc.
# TODO do dask-related private methods need to be exposed?
_DATASET_DASK_METHODS_TO_EXPOSE = ['load', 'compute', 'persist', 'unify_chunks', 'chunk', 'map_blocks']
_DATASET_METHODS_TO_EXPOSE = ['copy', 'as_numpy', '__copy__', '__deepcopy__', 'set_coords', 'reset_coords', 'info',
'isel', 'sel', 'head', 'tail', 'thin', 'broadcast_like', 'reindex_like',
'reindex', 'interp', 'interp_like', 'rename', 'rename_dims', 'rename_vars',
'swap_dims', 'expand_dims', 'set_index', 'reset_index', 'reorder_levels', 'stack',
'unstack', 'update', 'merge', 'drop_vars', 'drop_sel', 'drop_isel', 'drop_dims',
'transpose', 'dropna', 'fillna', 'interpolate_na', 'ffill', 'bfill', 'combine_first',
'reduce', 'map', 'assign', 'diff', 'shift', 'roll', 'sortby', 'quantile', 'rank',
'differentiate', 'integrate', 'cumulative_integrate', 'filter_by_attrs', 'polyfit',
'pad', 'idxmin', 'idxmax', 'argmin', 'argmax', 'query', 'curvefit']
# TODO unsure if these are called by external functions or not?
_DATASET_OPS_TO_EXPOSE = ['_unary_op', '_binary_op', '_inplace_binary_op']
_ALL_DATASET_METHODS_TO_EXPOSE = _DATASET_DASK_METHODS_TO_EXPOSE + _DATASET_METHODS_TO_EXPOSE + _DATASET_OPS_TO_EXPOSE
__slots__ = ()
_wrap_then_attach_to_cls(vars(), Dataset, _ALL_DATASET_METHODS_TO_MAP, wrap_func=map_over_subtree)

# TODO methods which should not or cannot act over the whole tree, such as .to_array

methods_to_expose = [(method_name, getattr(Dataset, method_name))
for method_name in _ALL_DATASET_METHODS_TO_EXPOSE]
_wrap_then_attach_to_cls(vars(), methods_to_expose, wrap_func=map_over_subtree)
class MappedDataWithCoords(DataWithCoords):
    """
    Mixin which copies selected methods from xarray.core.common.DataWithCoords
    (e.g. squeeze, clip, where — see _DATA_WITH_COORDS_METHODS_TO_MAP) onto this
    class, each wrapped with map_over_subtree so that calling it on a tree node
    applies it to the dataset stored in every node of the subtree.
    """
    # TODO add mapped versions of groupby, weighted, rolling, rolling_exp, coarsen, resample
    # TODO re-implement AttrsAccessMixin stuff so that it includes access to child nodes
    # NOTE: executed at class-definition time; vars() here is this class body's namespace,
    # so the wrapped methods become ordinary attributes of MappedDataWithCoords.
    _wrap_then_attach_to_cls(vars(), DataWithCoords, _DATA_WITH_COORDS_METHODS_TO_MAP, wrap_func=map_over_subtree)


# TODO implement ArrayReduce type methods
class DataTreeArithmetic(DatasetArithmetic):
    """
    Mixin to add Dataset methods like __add__ and .mean().

    Some of these methods must be wrapped to map over all nodes in the subtree. Others are fine unaltered (normally
    because they (a) only call dataset properties and (b) don't return a dataset that should be nested into a new
    tree) and some will get overridden by the class definition of DataTree.
    """
    # Copies the reduce / nan-reduce / nan-cumulative methods (REDUCE_METHODS +
    # NAN_REDUCE_METHODS + NAN_CUM_METHODS, e.g. any, mean, cumsum) from
    # DatasetArithmetic, wrapping each with map_over_subtree so it acts on every node.
    _wrap_then_attach_to_cls(vars(), DatasetArithmetic, _ARITHMETIC_METHODS_TO_WRAP, wrap_func=map_over_subtree)


class DataTree(TreeNode, DatasetPropertiesMixin, MappedDatasetMethodsMixin):
class DataTree(TreeNode, DatasetPropertiesMixin, MappedDatasetMethodsMixin, MappedDataWithCoords, DataTreeArithmetic):
"""
A tree-like hierarchical collection of xarray objects.

Expand Down Expand Up @@ -312,6 +341,8 @@ class DataTree(TreeNode, DatasetPropertiesMixin, MappedDatasetMethodsMixin):

# TODO currently allows self.ds = None, should we instead always store at least an empty Dataset?

# TODO dataset methods which should not or cannot act over the whole tree, such as .to_array

def __init__(
self,
data_objects: Dict[PathType, Union[Dataset, DataArray]] = None,
Expand Down
62 changes: 55 additions & 7 deletions datatree/tests/test_dataset_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ def test_properties(self):
assert dt.sizes == dt.ds.sizes
assert dt.variables == dt.ds.variables


def test_no_data_no_properties(self):
dt = DataNode('root', data=None)
with pytest.raises(AttributeError):
Expand All @@ -96,24 +95,73 @@ def test_no_data_no_properties(self):


class TestDSMethodInheritance:
def test_dataset_method(self):
    """Dataset methods (here .isel) should be mapped over every node of the tree."""
    # NOTE: this span contained interleaved deleted-hunk residue from the diff view
    # (duplicate `def test_root`/`def test_descendants` headers and stray statements);
    # reconstructed here as the single merged post-PR test.

    # test root
    da = xr.DataArray(name='a', data=[1, 2, 3], dims='x')
    dt = DataNode('root', data=da)
    expected_ds = da.to_dataset().isel(x=1)
    result_ds = dt.isel(x=1).ds
    assert_equal(result_ds, expected_ds)

    # test descendant
    DataNode('results', parent=dt, data=da)
    result_ds = dt.isel(x=1)['results'].ds
    assert_equal(result_ds, expected_ds)

def test_reduce_method(self):
    """Reductions such as .any() should be applied to the dataset in every node."""
    data = xr.DataArray(name='a', data=[False, True, False], dims='x')
    tree = DataNode('root', data=data)
    expected = data.to_dataset().any()

    # root node
    assert_equal(tree.any().ds, expected)

    # descendant node
    DataNode('results', parent=tree, data=data)
    assert_equal(tree.any()['results'].ds, expected)

def test_nan_reduce_method(self):
    """NaN-skipping reductions such as .mean() should also map over the whole subtree."""
    data = xr.DataArray(name='a', data=[1, 2, 3], dims='x')
    tree = DataNode('root', data=data)
    expected = data.to_dataset().mean()

    # root node
    assert_equal(tree.mean().ds, expected)

    # descendant node
    DataNode('results', parent=tree, data=data)
    assert_equal(tree.mean()['results'].ds, expected)

def test_cum_method(self):
    """Cumulative reductions such as .cumsum() should map over every node of the tree."""
    data = xr.DataArray(name='a', data=[1, 2, 3], dims='x')
    tree = DataNode('root', data=data)
    expected = data.to_dataset().cumsum()

    # root node
    assert_equal(tree.cumsum().ds, expected)

    # descendant node
    DataNode('results', parent=tree, data=data)
    assert_equal(tree.cumsum()['results'].ds, expected)


class TestOps:
    @pytest.mark.xfail
    def test_binary_op(self):
        """Binary operations (here *) should be mapped over the data in every node."""
        root_data = xr.Dataset({'a': [5], 'b': [3]})
        child_data = xr.Dataset({'x': [0.1, 0.2], 'y': [10, 20]})

        tree = DataNode('root', data=root_data)
        DataNode('subnode', data=child_data, parent=tree)

        expected_root = DataNode('root', data=root_data * root_data)
        expected_child = DataNode('subnode', data=child_data * child_data, parent=expected_root)

        result = tree * tree
        assert_equal(result.ds, expected_root.ds)
        assert_equal(result['subnode'].ds, expected_child.ds)


@pytest.mark.xfail
Expand Down