Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement DataTree.isel and DataTree.sel #9588

Merged
merged 7 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -761,16 +761,17 @@ Compare one ``DataTree`` object to another.
DataTree.equals
DataTree.identical

.. Indexing
.. --------
Indexing
--------

.. Index into all nodes in the subtree simultaneously.
Index into all nodes in the subtree simultaneously.

.. .. autosummary::
.. :toctree: generated/
.. autosummary::
:toctree: generated/

DataTree.isel
DataTree.sel

.. DataTree.isel
.. DataTree.sel
.. DataTree.drop_sel
.. DataTree.drop_isel
.. DataTree.head
Expand Down
188 changes: 186 additions & 2 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@
from xarray.core.merge import dataset_update_method
from xarray.core.options import OPTIONS as XR_OPTS
from xarray.core.treenode import NamedNode, NodePath
from xarray.core.types import Self
from xarray.core.utils import (
Default,
FilteredMapping,
Frozen,
_default,
drop_dims_from_indexers,
either_dict_or_kwargs,
maybe_wrap_array,
)
Expand All @@ -54,7 +56,12 @@

from xarray.core.datatree_io import T_DataTreeNetcdfEngine, T_DataTreeNetcdfTypes
from xarray.core.merge import CoercibleMapping, CoercibleValue
from xarray.core.types import ErrorOptions, NetcdfWriteModes, ZarrWriteModes
from xarray.core.types import (
ErrorOptions,
ErrorOptionsWithWarn,
NetcdfWriteModes,
ZarrWriteModes,
)

# """
# DEVELOPERS' NOTE
Expand Down Expand Up @@ -1087,7 +1094,7 @@ def from_dict(
d: Mapping[str, Dataset | DataTree | None],
/,
name: str | None = None,
) -> DataTree:
) -> Self:
"""
Create a datatree from a dictionary of data objects, organised by paths into the tree.

Expand Down Expand Up @@ -1607,3 +1614,180 @@ def to_zarr(
compute=compute,
**kwargs,
)

def _selective_indexing(
    self,
    func: Callable[[Dataset, Mapping[Any, Any]], Dataset],
    indexers: Mapping[Any, Any],
    missing_dims: ErrorOptionsWithWarn = "raise",
) -> Self:
    """Apply an indexing operation to every node of the subtree.

    Each node only receives the indexers whose keys are among that node's
    dimensions, so dimensions missing on some nodes and coordinates
    inherited from a parent are handled gracefully.

    Parameters
    ----------
    func : callable
        Called as ``func(node.dataset, node_indexers)`` for every node in
        ``self.subtree``; must return the indexed dataset for that node.
    indexers : mapping
        Indexers keyed by dimension name, shared by the whole subtree.
    missing_dims : {"raise", "warn", "ignore"}, default: "raise"
        Forwarded to ``drop_dims_from_indexers`` together with the union
        of the dimensions defined on the nodes of the subtree.

    Returns
    -------
    A new tree of the same type and name as ``self``, built with
    ``from_dict`` from the per-node results.
    """
    all_dims = set()
    for node in self.subtree:
        all_dims.update(node._node_dims)
    indexers = drop_dims_from_indexers(indexers, all_dims, missing_dims)

    # NOTE(review): ``node.path`` is documented as the path from the root
    # of the *whole* tree. When ``self`` is not that root, using it as the
    # ``from_dict`` key would nest the result under the names of the
    # ancestors of ``self``. Strip the path of ``self`` so that keys are
    # relative to the node this method was called on.
    root_prefix = "" if self.path == "/" else self.path

    result = {}
    for node in self.subtree:
        node_indexers = {k: v for k, v in indexers.items() if k in node.dims}
        node_result = func(node.dataset, node_indexers)
        # Indexing datasets corresponding to each node results in redundant
        # coordinates when indexes from a parent node are inherited.
        # Ideally, we would avoid creating such coordinates in the first
        # place, but that would require implementing indexing operations at
        # the Variable instead of the Dataset level.
        for k in node_indexers:
            if k not in node._node_coord_variables and k in node_result.coords:
                # We remove all inherited coordinates. Coordinates
                # corresponding to an index would be de-duplicated by
                # _deduplicate_inherited_coordinates(), but indexing (e.g.,
                # with a scalar) can also create scalar coordinates, which
                # need to be explicitly removed.
                del node_result.coords[k]
        # Every node of the subtree is ``self`` or one of its descendants,
        # so its path starts with ``root_prefix``; ``self`` itself maps to
        # the root key "/".
        relative_path = node.path[len(root_prefix) :] or "/"
        result[relative_path] = node_result
    return type(self).from_dict(result, name=self.name)

def isel(
    self,
    indexers: Mapping[Any, Any] | None = None,
    drop: bool = False,
    missing_dims: ErrorOptionsWithWarn = "raise",
    **indexers_kwargs: Any,
) -> Self:
    """Return a new data tree with each array indexed by integer position
    along the specified dimension(s).

    Values are selected from each array as with its ``__getitem__``
    method, except that the order of each array's dimensions does not
    need to be known.

    Parameters
    ----------
    indexers : dict, optional
        A dict with keys matching dimensions and values given
        by integers, slice objects or arrays.
        An indexer can be an integer, slice, array-like or DataArray.
        If DataArrays are passed as indexers, xarray-style indexing will be
        carried out. See :ref:`indexing` for the details.
        One of indexers or indexers_kwargs must be provided.
    drop : bool, default: False
        If ``drop=True``, drop coordinate variables indexed by integers
        instead of making them scalar.
    missing_dims : {"raise", "warn", "ignore"}, default: "raise"
        What to do if dimensions that should be selected from are not
        present in any node of the tree:

        - "raise": raise an exception
        - "warn": raise a warning, and ignore the missing dimensions
        - "ignore": ignore the missing dimensions

    **indexers_kwargs : {dim: indexer, ...}, optional
        The keyword arguments form of ``indexers``.
        One of indexers or indexers_kwargs must be provided.

    Returns
    -------
    obj : DataTree
        A new DataTree with the same contents as this data tree, except each
        array and dimension is indexed by the appropriate indexers.
        If indexer DataArrays have coordinates that do not conflict with
        this object, then these coordinates will be attached.
        In general, each array's data will be a view of the array's data
        in this tree, unless vectorized indexing was triggered by using
        an array indexer, in which case the data will be a copy.

    See Also
    --------
    DataTree.sel
    Dataset.isel
    """
    resolved_indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")
    # Each node is indexed with Dataset.isel, using only the indexers that
    # _selective_indexing selects for that node.
    return self._selective_indexing(
        lambda node_ds, node_indexers: node_ds.isel(node_indexers, drop=drop),
        resolved_indexers,
        missing_dims=missing_dims,
    )

def sel(
    self,
    indexers: Mapping[Any, Any] | None = None,
    method: str | None = None,
    tolerance: int | float | Iterable[int | float] | None = None,
    drop: bool = False,
    **indexers_kwargs: Any,
) -> Self:
    """Return a new data tree with each array indexed by tick labels
    along the specified dimension(s).

    Unlike `DataTree.isel`, the indexers passed to this method are labels
    rather than integer positions.

    Label look-ups are carried out by pandas Index objects, which makes
    label based indexing essentially as fast as integer indexing.

    It also means this method follows pandas's (well documented) indexing
    logic: string shortcuts work for datetime indexes (e.g., '2000-01'
    selects all values in January 2000), and slices include both the
    start and the stop value, unlike normal Python indexing.

    Parameters
    ----------
    indexers : dict, optional
        A dict with keys matching dimensions and values given
        by scalars, slices or arrays of tick labels. For dimensions with
        multi-index, the indexer may also be a dict-like object with keys
        matching index level names.
        If DataArrays are passed as indexers, xarray-style indexing will be
        carried out. See :ref:`indexing` for the details.
        One of indexers or indexers_kwargs must be provided.
    method : {None, "nearest", "pad", "ffill", "backfill", "bfill"}, optional
        Method to use for inexact matches:

        * None (default): only exact matches
        * pad / ffill: propagate last valid index value forward
        * backfill / bfill: propagate next valid index value backward
        * nearest: use nearest valid index value
    tolerance : optional
        Maximum distance between original and new labels for inexact
        matches. The values of the index at the matching locations must
        satisfy the equation ``abs(index[indexer] - target) <= tolerance``.
    drop : bool, optional
        If ``drop=True``, drop coordinate variables in `indexers` instead
        of making them scalar.
    **indexers_kwargs : {dim: indexer, ...}, optional
        The keyword arguments form of ``indexers``.
        One of indexers or indexers_kwargs must be provided.

    Returns
    -------
    obj : DataTree
        A new DataTree with the same contents as this data tree, except each
        variable and dimension is indexed by the appropriate indexers.
        If indexer DataArrays have coordinates that do not conflict with
        this object, then these coordinates will be attached.
        In general, each array's data will be a view of the array's data
        in this tree, unless vectorized indexing was triggered by using
        an array indexer, in which case the data will be a copy.

    See Also
    --------
    DataTree.isel
    Dataset.sel
    """
    resolved_indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "sel")
    # Options forwarded unchanged to Dataset.sel for every node.
    sel_options = {"method": method, "tolerance": tolerance, "drop": drop}

    def select_by_label(node_ds, node_indexers):
        # TODO: reimplement in terms of map_index_queries(), to avoid
        # redundant look-ups of integer positions from labels (via indexes)
        # on child nodes.
        return node_ds.sel(node_indexers, **sel_options)

    return self._selective_indexing(select_by_label, resolved_indexers)
92 changes: 80 additions & 12 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,7 +971,6 @@ def test_ipython_key_completions(self, create_test_datatree):
var_keys = list(dt.variables.keys())
assert all(var_key in key_completions for var_key in var_keys)

@pytest.mark.xfail(reason="sel not implemented yet")
def test_operation_with_attrs_but_no_data(self):
# tests bug from xarray-datatree GH262
xs = xr.Dataset({"testvar": xr.DataArray(np.ones((2, 3)))})
Expand Down Expand Up @@ -1561,26 +1560,95 @@ def test_filter(self):
assert_identical(elders, expected)


class TestDSMethodInheritance:
@pytest.mark.xfail(reason="isel not implemented yet")
def test_dataset_method(self):
ds = xr.Dataset({"a": ("x", [1, 2, 3])})
dt = DataTree.from_dict(
class TestIndexing:

def test_isel_siblings(self):
tree = DataTree.from_dict(
{
"/": ds,
"/results": ds,
"/first": xr.Dataset({"a": ("x", [1, 2])}),
"/second": xr.Dataset({"b": ("x", [1, 2, 3])}),
}
)

expected = DataTree.from_dict(
{
"/": ds.isel(x=1),
"/results": ds.isel(x=1),
"/first": xr.Dataset({"a": 2}),
"/second": xr.Dataset({"b": 3}),
}
)
actual = tree.isel(x=-1)
assert_equal(actual, expected)

result = dt.isel(x=1)
assert_equal(result, expected)
expected = DataTree.from_dict(
{
"/first": xr.Dataset({"a": ("x", [1])}),
"/second": xr.Dataset({"b": ("x", [1])}),
}
)
actual = tree.isel(x=slice(1))
assert_equal(actual, expected)

actual = tree.isel(x=[0])
assert_equal(actual, expected)

actual = tree.isel(x=slice(None))
assert_equal(actual, tree)

def test_isel_inherited(self):
tree = DataTree.from_dict(
{
"/": xr.Dataset(coords={"x": [1, 2]}),
"/child": xr.Dataset({"foo": ("x", [3, 4])}),
}
)

expected = DataTree.from_dict(
{
"/": xr.Dataset(coords={"x": 2}),
"/child": xr.Dataset({"foo": 4}),
}
)
actual = tree.isel(x=-1)
assert_equal(actual, expected)
Comment on lines +1605 to +1612
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When this passes, the x coordinate gets duplicated, doesn't it? As the result is a non-indexed scalar. Does that mean this would have failed if you had used the definition of assert_identical from #9473 to check for duplication?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC this should also pass

actual = tree.isel(x=[0])
assert_identical(actual, expected)

because the result will have inherited coordinates instead of duplicated ones.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wrote special logic for removing duplicated scalar coordinates to handle just this!

There is no longer a need to implement a different version of inheritance for assert_identical than for assert_equal, because we no longer allow any cases in the DataTree data model where inheritance is ambiguous:

  • Deduplication of coordinates with an index removes them from child nodes.
  • Other coordinates are no longer inherited.


expected = DataTree.from_dict(
{
"/child": xr.Dataset({"foo": 4}),
}
)
actual = tree.isel(x=-1, drop=True)
assert_equal(actual, expected)

expected = DataTree.from_dict(
{
"/": xr.Dataset(coords={"x": [1]}),
"/child": xr.Dataset({"foo": ("x", [3])}),
}
)
actual = tree.isel(x=[0])
assert_equal(actual, expected)

actual = tree.isel(x=slice(None))
assert_equal(actual, tree)

def test_sel(self):
    """Label-based selection is applied to each sibling node using that
    node's own ``x`` coordinate."""
    first = xr.Dataset({"a": ("x", [1, 2, 3])}, coords={"x": [1, 2, 3]})
    second = xr.Dataset({"b": ("x", [4, 5])}, coords={"x": [2, 3]})
    tree = DataTree.from_dict({"/first": first, "/second": second})

    # The label x=2 sits at position 1 in "first" and position 0 in
    # "second", so each node must resolve it against its own index.
    expected_first = xr.Dataset({"a": 2}, coords={"x": 2})
    expected_second = xr.Dataset({"b": 4}, coords={"x": 2})
    expected = DataTree.from_dict(
        {"/first": expected_first, "/second": expected_second}
    )

    actual = tree.sel(x=2)
    assert_equal(actual, expected)


class TestDSMethodInheritance:

@pytest.mark.xfail(reason="reduce methods not implemented yet")
def test_reduce_method(self):
Expand Down
Loading