Skip to content

Commit

Permalink
Remove from_root
Browse files Browse the repository at this point in the history
  • Loading branch information
shoyer committed Sep 10, 2024
1 parent 9095930 commit 5eee5fd
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 42 deletions.
31 changes: 14 additions & 17 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1241,31 +1241,25 @@ def isomorphic(
def _matching(
self,
other: DataTree,
from_root: bool,
nodes_match: Callable[[DataTree, DataTree], bool],
) -> bool:
if not self.isomorphic(other, from_root=from_root, strict_names=True):
if not self.isomorphic(other, from_root=True, strict_names=True):
return False

return all(
nodes_match(node, other_node)
for node, other_node in zip(self.subtree, other.subtree)
for node, other_node in zip(self.descendants, other.descendants)
)

def equals(self, other: DataTree, from_root: bool = True) -> bool:
def equals(self, other: DataTree) -> bool:
"""
Two DataTrees are equal if they have isomorphic node structures, with matching node names,
and if they have matching variables and coordinates, all of which are equal.
By default this method will check the whole tree above the given node.
Parameters
----------
other : DataTree
The other tree object to compare to.
from_root : bool, optional, default is True
Whether or not to first traverse to the root of the two trees before checking for isomorphism.
If neither tree has a parent then this has no effect.
See Also
--------
Expand All @@ -1275,33 +1269,36 @@ def equals(self, other: DataTree, from_root: bool = True) -> bool:
"""
to_ds = lambda x: x._to_dataset_view(inherited=True, rebuild_dims=False)
matcher = lambda x, y: to_ds(x).equals(to_ds(y))
return self._matching(other, from_root, nodes_match=matcher)
if not matcher(self, other):
return False
return self._matching(other, nodes_match=matcher)

def identical(self, other: DataTree, from_root=True) -> bool:
def identical(self, other: DataTree) -> bool:
"""
Like equals, but will also check all dataset attributes and the attributes on
all variables and coordinates, and requires the coordinates are defined at the
exact same levels of the DataTree hierarchy.
By default this method will check the whole tree above the given node.
Parameters
----------
other : DataTree
The other tree object to compare to.
from_root : bool, optional, default is True
Whether or not to first traverse to the root of the two trees before checking for isomorphism.
If neither tree has a parent then this has no effect.
See Also
--------
Dataset.identical
DataTree.isomorphic
DataTree.equals
"""
# include inherited coordinates only at the root; otherwise require that
# everything is also defined at the same level
self_view = self._to_dataset_view(inherited=True, rebuild_dims=False)
other_view = other._to_dataset_view(inherited=True, rebuild_dims=False)
if not self_view.identical(other_view):
return False
to_ds = lambda x: x._to_dataset_view(inherited=False, rebuild_dims=False)
matcher = lambda x, y: to_ds(x).identical(to_ds(y))
return self._matching(other, from_root, nodes_match=matcher)
return self._matching(other, nodes_match=matcher)

def filter(self: DataTree, filterfunc: Callable[[DataTree], bool]) -> DataTree:
"""
Expand Down
28 changes: 5 additions & 23 deletions xarray/testing/assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def assert_equal(a: DataTree, b: DataTree, from_root: bool = True): ...


@ensure_warnings
def assert_equal(a, b, from_root=True, check_dim_order: bool = True):
def assert_equal(a, b, check_dim_order: bool = True):
"""Like :py:func:`numpy.testing.assert_array_equal`, but for xarray
objects.
Expand All @@ -135,10 +135,6 @@ def assert_equal(a, b, from_root=True, check_dim_order: bool = True):
or xarray.core.datatree.DataTree. The first object to compare.
b : xarray.Dataset, xarray.DataArray, xarray.Variable, xarray.Coordinates
or xarray.core.datatree.DataTree. The second object to compare.
from_root : bool, optional, default is True
Only used when comparing DataTree objects. Indicates whether or not to
first traverse to the root of the trees before checking for isomorphism.
If a & b have no parents then this has no effect.
check_dim_order : bool, optional, default is True
Whether dimensions must be in the same order.
Expand All @@ -159,11 +155,7 @@ def assert_equal(a, b, from_root=True, check_dim_order: bool = True):
elif isinstance(a, Coordinates):
assert a.equals(b), formatting.diff_coords_repr(a, b, "equals")
elif isinstance(a, DataTree):
if from_root:
a = a.root
b = b.root

assert a.equals(b, from_root=from_root), diff_datatree_repr(a, b, "equals")
assert a.equals(b), diff_datatree_repr(a, b, "equals")
else:
raise TypeError(f"{type(a)} not supported by assertion comparison")

Expand All @@ -173,11 +165,11 @@ def assert_identical(a, b): ...


@overload
def assert_identical(a: DataTree, b: DataTree, from_root: bool = True): ...
def assert_identical(a: DataTree, b: DataTree): ...


@ensure_warnings
def assert_identical(a, b, from_root=True):
def assert_identical(a, b):
"""Like :py:func:`xarray.testing.assert_equal`, but also matches the
objects' names and attributes.
Expand All @@ -193,10 +185,6 @@ def assert_identical(a, b, from_root=True):
The first object to compare.
b : xarray.Dataset, xarray.DataArray, xarray.Variable or xarray.Coordinates
The second object to compare.
from_root : bool, optional, default is True
Only used when comparing DataTree objects. Indicates whether or not to
first traverse to the root of the trees before checking for isomorphism.
If a & b have no parents then this has no effect.
check_dim_order : bool, optional, default is True
Whether dimensions must be in the same order.
Expand All @@ -220,13 +208,7 @@ def assert_identical(a, b, from_root=True):
elif isinstance(a, Coordinates):
assert a.identical(b), formatting.diff_coords_repr(a, b, "identical")
elif isinstance(a, DataTree):
if from_root:
a = a.root
b = b.root

assert a.identical(b, from_root=from_root), diff_datatree_repr(
a, b, "identical"
)
assert a.identical(b), diff_datatree_repr(a, b, "identical")
else:
raise TypeError(f"{type(a)} not supported by assertion comparison")

Expand Down
4 changes: 2 additions & 2 deletions xarray/tests/test_datatree_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,9 @@ def test_discard_ancestry(self, create_test_datatree):
def times_ten(ds):
return 10.0 * ds

expected = create_test_datatree(modify=lambda ds: 10.0 * ds)["set1"]
expected = create_test_datatree(modify=lambda ds: 10.0 * ds)["set1"].copy()
result_tree = times_ten(subtree)
assert_equal(result_tree, expected, from_root=False)
assert_equal(result_tree, expected)

def test_skip_empty_nodes_with_attrs(self, create_test_datatree):
# inspired by xarray-datatree GH262
Expand Down

0 comments on commit 5eee5fd

Please sign in to comment.