Re-implement map_over_datasets using group_subtrees #9636

Merged
merged 25 commits into from
Oct 21, 2024
Changes from 20 commits
4 changes: 3 additions & 1 deletion DATATREE_MIGRATION_GUIDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ This guide is for previous users of the prototype `datatree.DataTree` class in t
> [!IMPORTANT]
> There are breaking changes! You should not expect that code written with `xarray-contrib/datatree` will work without any modifications. At the absolute minimum you will need to change the top-level import statement, but there are other changes too.

We have made various changes compared to the prototype version. These can be split into three categories: data model changes, which affect the hierarchal structure itself, integration with xarray's IO backends; and minor API changes, which mostly consist of renaming methods to be more self-consistent.
We have made various changes compared to the prototype version. These can be split into three categories: data model changes, which affect the hierarchal structure itself; integration with xarray's IO backends; and minor API changes, which mostly consist of renaming methods to be more self-consistent.

### Data model changes

Expand All @@ -17,6 +17,8 @@ These alignment checks happen at tree construction time, meaning there are some

The alignment checks allowed us to add "Coordinate Inheritance", a much-requested feature where indexed coordinate variables are now "inherited" down to child nodes. This allows you to define common coordinates in a parent group that are then automatically available on every child node. The distinction between a locally-defined coordinate variable and an inherited coordinate that was defined on a parent node is reflected in the `DataTree.__repr__`. Generally if you prefer not to have these variables be inherited you can get more similar behaviour to the old `datatree` package by removing indexes from coordinates, as this prevents inheritance.

Tree structure checks between multiple trees (i.e., `DataTree.isomorphic`) and pairing of nodes in arithmetic have also changed. Nodes are now matched (with `xarray.group_subtrees`) based on their relative paths, without regard to the order in which child nodes are defined.

For further documentation see the page in the user guide on Hierarchical Data.

### Integrated backends
Expand Down
12 changes: 11 additions & 1 deletion doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,17 @@ Manipulate the contents of a single ``DataTree`` node.
DataTree.assign
DataTree.drop_nodes

DataTree Operations
-------------------

Apply operations over multiple ``DataTree`` objects.

.. autosummary::
:toctree: generated/

map_over_datasets
group_subtrees

Comparisons
-----------

Expand Down Expand Up @@ -954,7 +965,6 @@ DataTree methods

open_datatree
open_groups
map_over_datasets
DataTree.to_dict
DataTree.to_netcdf
DataTree.to_zarr
Expand Down
88 changes: 70 additions & 18 deletions doc/user-guide/hierarchical-data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -362,21 +362,26 @@ This returns an iterable of nodes, which yields them in depth-first order.
for node in vertebrates.subtree:
print(node.path)

A very useful pattern is to use :py:class:`~xarray.DataTree.subtree` conjunction with the :py:class:`~xarray.DataTree.path` property to manipulate the nodes however you wish,
then rebuild a new tree using :py:meth:`xarray.DataTree.from_dict()`.
Similarly, :py:class:`~xarray.DataTree.subtree_with_keys` returns an iterable of
relative paths and corresponding nodes.

A very useful pattern is to iterate over :py:class:`~xarray.DataTree.subtree_with_keys`
to manipulate nodes however you wish, then rebuild a new tree using
:py:meth:`xarray.DataTree.from_dict()`.
Comment on lines +368 to +370

So we have a new property:

assert DataTree.from_dict(dt.subtree_with_keys()) == dt

Is there a good place we could document these properties?

For example, we could keep only the nodes containing data by looping over all nodes,
checking if they contain any data using :py:class:`~xarray.DataTree.has_data`,
then rebuilding a new tree using only the paths of those nodes:

.. ipython:: python

non_empty_nodes = {node.path: node.dataset for node in dt.subtree if node.has_data}
non_empty_nodes = {
path: node.dataset for path, node in dt.subtree_with_keys if node.has_data
}
xr.DataTree.from_dict(non_empty_nodes)

You can see this tree is similar to the ``dt`` object above, except that it is missing the empty nodes ``a/c`` and ``a/c/d``.

(If you want to keep the name of the root node, you will need to add the ``name`` kwarg to :py:class:`~xarray.DataTree.from_dict`, i.e. ``DataTree.from_dict(non_empty_nodes, name=dt.root.name)``.)
(If you want to keep the name of the root node, you will need to add the ``name`` kwarg to :py:class:`~xarray.DataTree.from_dict`, i.e. ``DataTree.from_dict(non_empty_nodes, name=dt.name)``.)

Good catch!


.. _manipulating trees:

Expand Down Expand Up @@ -573,38 +578,85 @@ Then calculate the RMS value of these signals:

.. _multiple trees:

We can also use the :py:meth:`~xarray.map_over_datasets` decorator to promote a function which accepts datasets into one which
accepts datatrees.
We can also use :py:func:`~xarray.map_over_datasets` to apply a function over
trees appearing in any positional argument.
shoyer marked this conversation as resolved.

Operating on Multiple Trees
---------------------------

The examples so far have involved mapping functions or methods over the nodes of a single tree,
but we can generalize this to mapping functions over multiple trees at once.

Iterating Over Multiple Trees
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To iterate over the corresponding nodes in multiple trees, use
:py:func:`~xarray.group_subtrees` instead of
:py:class:`~xarray.DataTree.subtree_with_keys`:

.. ipython:: python

dt1 = xr.DataTree.from_dict({"a": xr.Dataset({"x": 1}), "b": xr.Dataset({"x": 2})})
dt2 = xr.DataTree.from_dict(
{"a": xr.Dataset({"x": 10}), "b": xr.Dataset({"x": 20})}
)
for path, (node1, node2) in xr.group_subtrees(dt1, dt2):
print(path, int(node1["x"]), int(node2["x"]))
shoyer marked this conversation as resolved.

To rebuild a tree after applying operations at each node, use
:py:meth:`xarray.DataTree.from_dict()`:

.. ipython:: python

result = {}
for path, (node1, node2) in xr.group_subtrees(dt1, dt2):
result[path] = node1.dataset + node2.dataset
xr.DataTree.from_dict(result)

Or apply a function directly to paired datasets at every node using
:py:func:`xarray.map_over_datasets`:

.. ipython:: python

xr.map_over_datasets(lambda x, y: x + y, dt1, dt2)
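
The pairing-by-path idea described above can be sketched without xarray by modelling each tree as a plain dict that maps relative paths to values. This is only an illustration under that assumption: `pair_by_path` and `map_over_paired` are hypothetical names, not xarray functions, and the sketch raises `ValueError` where the documentation says `group_subtrees` raises `TreeIsomorphismError`.

```python
from typing import Any, Callable, Iterator


def pair_by_path(*trees: dict[str, Any]) -> Iterator[tuple[str, tuple[Any, ...]]]:
    """Yield (path, values) for every path, pairing values across all trees.

    Hypothetical stand-in: each tree is a plain dict keyed by relative path,
    and trees whose sets of paths differ are rejected.
    """
    if not trees:
        raise TypeError("at least one tree is required")
    first, *others = trees
    for other in others:
        if set(other) != set(first):
            raise ValueError(
                f"trees have different paths: {sorted(first)} vs {sorted(other)}"
            )
    # Follow the order of the first tree; the order in which the other
    # trees define their nodes does not affect the pairing.
    for path in first:
        yield path, tuple(tree[path] for tree in trees)


def map_over_paired(func: Callable[..., Any], *trees: dict[str, Any]) -> dict[str, Any]:
    """Apply func to the values found at each path and rebuild a dict."""
    return {path: func(*values) for path, values in pair_by_path(*trees)}


tree1 = {"a": 1, "b": 2}
tree2 = {"b": 20, "a": 10}  # same paths, defined in a different order
summed = map_over_paired(lambda x, y: x + y, tree1, tree2)
print(summed)
```

The second helper mirrors the loop-then-rebuild pattern shown with `from_dict`: collect one result per path, then construct the new container from the resulting mapping.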

Comparing Trees for Isomorphism
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

For it to make sense to map a single non-unary function over the nodes of multiple trees at once,
each tree needs to have the same structure. Specifically two trees can only be considered similar, or "isomorphic",
if they have the same number of nodes, and each corresponding node has the same number of children.
We can check if any two trees are isomorphic using the :py:meth:`~xarray.DataTree.isomorphic` method.
each tree needs to have the same structure. Specifically two trees can only be considered similar,
or "isomorphic", if the full paths to all of their descendent nodes are the same.

Applying :py:func:`~xarray.group_subtrees` to trees with different structure
shoyer marked this conversation as resolved.
raises :py:class:`~xarray.TreeIsomorphismError`:

.. ipython:: python
:okexcept:

dt1 = xr.DataTree.from_dict({"a": None, "a/b": None})
dt2 = xr.DataTree.from_dict({"a": None})
dt1.isomorphic(dt2)
tree = xr.DataTree.from_dict({"a": None, "a/b": None, "a/c": None})
simple_tree = xr.DataTree.from_dict({"a": None})
for _ in xr.group_subtrees(tree, simple_tree):
...

We can also explicitly check whether any two trees are isomorphic using the :py:meth:`~xarray.DataTree.isomorphic` method:

.. ipython:: python

tree.isomorphic(simple_tree)

Corresponding tree nodes do not need to have the same data in order to be considered isomorphic:

.. ipython:: python

tree_with_data = xr.DataTree.from_dict({"a": xr.Dataset({"foo": 1})})
simple_tree.isomorphic(tree_with_data)

dt3 = xr.DataTree.from_dict({"a": None, "b": None})
dt1.isomorphic(dt3)
They also do not need to define child nodes in the same order:

dt4 = xr.DataTree.from_dict({"A": None, "A/B": xr.Dataset({"foo": 1})})
dt1.isomorphic(dt4)
.. ipython:: python

If the trees are not isomorphic a :py:class:`~xarray.TreeIsomorphismError` will be raised.
Notice that corresponding tree nodes do not need to have the same name or contain the same data in order to be considered isomorphic.
reordered_tree = xr.DataTree.from_dict({"a": None, "a/c": None, "a/b": None})
tree.isomorphic(reordered_tree)
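
As a rough mental model of the path-based definition, the check can be pictured as comparing sets of node paths. The sketch below assumes trees are represented as plain dicts keyed by path; `is_isomorphic` here is a hypothetical helper, not the `DataTree.isomorphic` method, and it ignores the values stored at each node.

```python
def is_isomorphic(tree1: dict, tree2: dict) -> bool:
    """Hypothetical check: trees match when their sets of node paths match."""
    # Dict key views compare like sets, so insertion order is irrelevant
    # and the values stored at each path are never inspected.
    return tree1.keys() == tree2.keys()


tree = {"a": None, "a/b": None, "a/c": None}
simple_tree = {"a": None}
tree_with_data = {"a": {"foo": 1}}
reordered_tree = {"a": None, "a/c": None, "a/b": None}

print(is_isomorphic(tree, simple_tree))            # different paths
print(is_isomorphic(simple_tree, tree_with_data))  # same paths, different data
print(is_isomorphic(tree, reordered_tree))         # same paths, different order
```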

Arithmetic Between Multiple Trees
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
4 changes: 2 additions & 2 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ New Features
~~~~~~~~~~~~
- ``DataTree`` related functionality is now exposed in the main ``xarray`` public
API. This includes: ``xarray.DataTree``, ``xarray.open_datatree``, ``xarray.open_groups``,
``xarray.map_over_datasets``, ``xarray.register_datatree_accessor`` and
``xarray.testing.assert_isomorphic``.
``xarray.map_over_datasets``, ``xarray.group_subtrees``,
``xarray.register_datatree_accessor`` and ``xarray.testing.assert_isomorphic``.
By `Owen Littlejohns <https://github.com/owenlittlejohns>`_,
`Eni Awowale <https://github.com/eni-awowale>`_,
`Matt Savoie <https://github.com/flamingbear>`_,
Expand Down
10 changes: 8 additions & 2 deletions xarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.datatree import DataTree
from xarray.core.datatree_mapping import TreeIsomorphismError, map_over_datasets
from xarray.core.datatree_mapping import map_over_datasets
from xarray.core.extensions import (
register_dataarray_accessor,
register_dataset_accessor,
Expand All @@ -45,7 +45,12 @@
from xarray.core.merge import Context, MergeError, merge
from xarray.core.options import get_options, set_options
from xarray.core.parallel import map_blocks
from xarray.core.treenode import InvalidTreeError, NotFoundInTreeError
from xarray.core.treenode import (
InvalidTreeError,
NotFoundInTreeError,
TreeIsomorphismError,
group_subtrees,
)
from xarray.core.variable import IndexVariable, Variable, as_variable
from xarray.namedarray.core import NamedArray
from xarray.util.print_versions import show_versions
Expand Down Expand Up @@ -82,6 +87,7 @@
"cross",
"full_like",
"get_options",
"group_subtrees",
"infer_freq",
"load_dataarray",
"load_dataset",
Expand Down
15 changes: 1 addition & 14 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from xarray.core.merge import merge_attrs, merge_coordinates_without_align
from xarray.core.options import OPTIONS, _get_keep_attrs
from xarray.core.types import Dims, T_DataArray
from xarray.core.utils import is_dict_like, is_scalar, parse_dims_as_set
from xarray.core.utils import is_dict_like, is_scalar, parse_dims_as_set, result_name
from xarray.core.variable import Variable
from xarray.namedarray.parallelcompat import get_chunked_array_type
from xarray.namedarray.pycompat import is_chunked_array
Expand All @@ -46,7 +46,6 @@
MissingCoreDimOptions = Literal["raise", "copy", "drop"]

_NO_FILL_VALUE = utils.ReprObject("<no-fill-value>")
_DEFAULT_NAME = utils.ReprObject("<default-name>")
_JOINS_WITHOUT_FILL_VALUES = frozenset({"inner", "exact"})


Expand Down Expand Up @@ -186,18 +185,6 @@ def _enumerate(dim):
return str(alt_signature)


def result_name(objects: Iterable[Any]) -> Any:
# use the same naming heuristics as pandas:
# https://github.com/blaze/blaze/issues/458#issuecomment-51936356
names = {getattr(obj, "name", _DEFAULT_NAME) for obj in objects}
names.discard(_DEFAULT_NAME)
if len(names) == 1:
(name,) = names
else:
name = None
return name


def _get_coords_list(args: Iterable[Any]) -> list[Coordinates]:
coords_list = []
for arg in args:
Expand Down
12 changes: 2 additions & 10 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
either_dict_or_kwargs,
hashable,
infix_dims,
result_name,
)
from xarray.core.variable import (
IndexVariable,
Expand Down Expand Up @@ -4726,15 +4727,6 @@ def identical(self, other: Self) -> bool:
except (TypeError, AttributeError):
return False

def _result_name(self, other: Any = None) -> Hashable | None:
# use the same naming heuristics as pandas:
# https://github.com/ContinuumIO/blaze/issues/458#issuecomment-51936356
other_name = getattr(other, "name", _default)
if other_name is _default or other_name == self.name:
return self.name
else:
return None

def __array_wrap__(self, obj, context=None) -> Self:
new_var = self.variable.__array_wrap__(obj, context)
return self._replace(new_var)
Expand Down Expand Up @@ -4782,7 +4774,7 @@ def _binary_op(
else f(other_variable_or_arraylike, self.variable)
)
coords, indexes = self.coords._merge_raw(other_coords, reflexive)
name = self._result_name(other)
name = result_name([self, other])

return self._replace(variable, coords, name, indexes=indexes)
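
For readers comparing the removed `DataArray._result_name` method with the shared `result_name` helper now imported from `xarray.core.utils`, the following standalone sketch reproduces the helper's logic from the `computation.py` hunk above. Two things are assumptions for illustration: the sentinel is a bare `object()` rather than `utils.ReprObject`, and `Named` is a hypothetical stand-in for any object with a `name` attribute.

```python
from typing import Any, Iterable

# Sentinel for "this object has no name attribute" (assumption: the real
# code uses utils.ReprObject("<default-name>") instead of a bare object()).
_DEFAULT_NAME = object()


def result_name(objects: Iterable[Any]) -> Any:
    """Return the single name shared by all named objects, else None."""
    names = {getattr(obj, "name", _DEFAULT_NAME) for obj in objects}
    names.discard(_DEFAULT_NAME)  # unnamed operands do not affect the result
    if len(names) == 1:
        (name,) = names
    else:
        name = None
    return name


class Named:
    """Hypothetical stand-in for an object that carries a name."""

    def __init__(self, name: Any) -> None:
        self.name = name


kept_with_scalar = result_name([Named("t"), 2.0])       # scalar has no name
kept_when_equal = result_name([Named("t"), Named("t")])
dropped_on_conflict = result_name([Named("t"), Named("u")])
dropped_with_none = result_name([Named(None), Named("u")])
print(kept_with_scalar, kept_when_equal, dropped_on_conflict, dropped_with_none)
```

For the two-operand case used in `_binary_op`, these outcomes match what the removed method would have returned: the name survives only when the other operand is unnamed or carries the same name.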

Expand Down