Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refine DataTree.equals and DataTree.identical #9473

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 29 additions & 27 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1238,65 +1238,67 @@ def isomorphic(
except (TypeError, TreeIsomorphismError):
return False

def equals(self, other: DataTree, from_root: bool = True) -> bool:
def _matching(
self,
other: DataTree,
nodes_match: Callable[[DataTree, DataTree], bool],
) -> bool:
if not self.isomorphic(other, from_root=True, strict_names=True):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if there is a problem with removing the from_root argument but keeping it fixed to True inside check_isomorphic...

Let's imagines I assert_equal two subtrees that have different parents. All the nodes in the subtrees could match, but if the structures above the node where the subtree starts are not isomorphic, it will return False. That doesn't seem like desired behaviour to me - it leaves no way to check that the two subtrees are equal.

return False

return all(
nodes_match(node, other_node)
for node, other_node in zip(self.descendants, other.descendants)
)

def equals(self, other: DataTree) -> bool:
"""
Two DataTrees are equal if they have isomorphic node structures, with matching node names,
and if they have matching variables and coordinates, all of which are equal.

By default this method will check the whole tree above the given node.

Parameters
----------
other : DataTree
The other tree object to compare to.
from_root : bool, optional, default is True
Whether or not to first traverse to the root of the two trees before checking for isomorphism.
If neither tree has a parent then this has no effect.

See Also
--------
Dataset.equals
DataTree.isomorphic
DataTree.identical
"""
if not self.isomorphic(other, from_root=from_root, strict_names=True):
to_ds = lambda x: x._to_dataset_view(inherited=True, rebuild_dims=False)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't it be simpler (and clearer) to use inherited=False? The nodes are checked in depth-first order anyway, so if there is a discrepancy between inherited coordinates it will be discovered at the top of the subtree before checking the children.

matcher = lambda x, y: to_ds(x).equals(to_ds(y))
if not matcher(self, other):
return False
return self._matching(other, nodes_match=matcher)

return all(
[
node.ds.equals(other_node.ds)
for node, other_node in zip(self.subtree, other.subtree)
]
)

def identical(self, other: DataTree, from_root=True) -> bool:
def identical(self, other: DataTree) -> bool:
"""
Like equals, but will also check all dataset attributes and the attributes on
all variables and coordinates.

By default this method will check the whole tree above the given node.
all variables and coordinates, and requires the coordinates are defined at the
exact same levels of the DataTree hierarchy.

Parameters
----------
other : DataTree
The other tree object to compare to.
from_root : bool, optional, default is True
Whether or not to first traverse to the root of the two trees before checking for isomorphism.
If neither tree has a parent then this has no effect.

See Also
--------
Dataset.identical
DataTree.isomorphic
DataTree.equals
"""
if not self.isomorphic(other, from_root=from_root, strict_names=True):
# include inherited coordinates only at the root; otherwise require that
# everything is also defined at the same level
Comment on lines +1293 to +1294
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do this for identical but not for equals?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh you did talk about this in your comment above

self_view = self._to_dataset_view(inherited=True, rebuild_dims=False)
other_view = other._to_dataset_view(inherited=True, rebuild_dims=False)
if not self_view.identical(other_view):
return False

return all(
node.ds.identical(other_node.ds)
for node, other_node in zip(self.subtree, other.subtree)
)
to_ds = lambda x: x._to_dataset_view(inherited=False, rebuild_dims=False)
matcher = lambda x, y: to_ds(x).identical(to_ds(y))
return self._matching(other, nodes_match=matcher)

def filter(self: DataTree, filterfunc: Callable[[DataTree], bool]) -> DataTree:
"""
Expand Down
28 changes: 5 additions & 23 deletions xarray/testing/assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def assert_equal(a: DataTree, b: DataTree, from_root: bool = True): ...


@ensure_warnings
def assert_equal(a, b, from_root=True, check_dim_order: bool = True):
def assert_equal(a, b, check_dim_order: bool = True):
"""Like :py:func:`numpy.testing.assert_array_equal`, but for xarray
objects.

Expand All @@ -135,10 +135,6 @@ def assert_equal(a, b, from_root=True, check_dim_order: bool = True):
or xarray.core.datatree.DataTree. The first object to compare.
b : xarray.Dataset, xarray.DataArray, xarray.Variable, xarray.Coordinates
or xarray.core.datatree.DataTree. The second object to compare.
from_root : bool, optional, default is True
Only used when comparing DataTree objects. Indicates whether or not to
first traverse to the root of the trees before checking for isomorphism.
If a & b have no parents then this has no effect.
check_dim_order : bool, optional, default is True
Whether dimensions must be in the same order.

Expand All @@ -159,11 +155,7 @@ def assert_equal(a, b, from_root=True, check_dim_order: bool = True):
elif isinstance(a, Coordinates):
assert a.equals(b), formatting.diff_coords_repr(a, b, "equals")
elif isinstance(a, DataTree):
if from_root:
a = a.root
b = b.root

assert a.equals(b, from_root=from_root), diff_datatree_repr(a, b, "equals")
assert a.equals(b), diff_datatree_repr(a, b, "equals")
else:
raise TypeError(f"{type(a)} not supported by assertion comparison")

Expand All @@ -173,11 +165,11 @@ def assert_identical(a, b): ...


@overload
def assert_identical(a: DataTree, b: DataTree, from_root: bool = True): ...
def assert_identical(a: DataTree, b: DataTree): ...


@ensure_warnings
def assert_identical(a, b, from_root=True):
def assert_identical(a, b):
"""Like :py:func:`xarray.testing.assert_equal`, but also matches the
objects' names and attributes.

Expand All @@ -193,10 +185,6 @@ def assert_identical(a, b, from_root=True):
The first object to compare.
b : xarray.Dataset, xarray.DataArray, xarray.Variable or xarray.Coordinates
The second object to compare.
from_root : bool, optional, default is True
Only used when comparing DataTree objects. Indicates whether or not to
first traverse to the root of the trees before checking for isomorphism.
If a & b have no parents then this has no effect.
check_dim_order : bool, optional, default is True
Whether dimensions must be in the same order.

Expand All @@ -220,13 +208,7 @@ def assert_identical(a, b, from_root=True):
elif isinstance(a, Coordinates):
assert a.identical(b), formatting.diff_coords_repr(a, b, "identical")
elif isinstance(a, DataTree):
if from_root:
a = a.root
b = b.root

assert a.identical(b, from_root=from_root), diff_datatree_repr(
a, b, "identical"
)
assert a.identical(b), diff_datatree_repr(a, b, "identical")
else:
raise TypeError(f"{type(a)} not supported by assertion comparison")

Expand Down
56 changes: 56 additions & 0 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1091,6 +1091,62 @@ def f(x, tree, y):
assert actual is dt and actual.attrs == attrs


class TestEqualsAndIdentical:
def test_basics(self, create_test_datatree):
dt = create_test_datatree()
assert dt.equals(dt)
assert dt.identical(dt)

other = DataTree()
assert not dt.equals(other)
assert not dt.identical(other)
assert not other.equals(dt)
assert not other.identical(dt)

def test_isomormorphic(self, create_test_datatree):
dt = create_test_datatree()

dt_new_node = dt.copy()
dt_new_node["/something_new"] = DataTree()
assert not dt.equals(dt_new_node)
assert not dt.identical(dt_new_node)
assert not dt_new_node.equals(dt)
assert not dt_new_node.identical(dt)

dt_new_array = dt.copy()
dt_new_array["/something_else"] = xr.DataArray(1234)
assert not dt.equals(dt_new_array)
assert not dt.identical(dt_new_array)
assert not dt_new_array.equals(dt)
assert not dt_new_array.identical(dt)

def test_equal_but_not_identical_inheritance(self):
duplicated_coords = DataTree.from_dict(
{
"/": Dataset(coords={"x": [1]}),
"/sub": Dataset(coords={"x": [1]}),
}
)
inherited_coords = DataTree.from_dict(
{
"/": Dataset(coords={"x": [1]}),
"/sub": Dataset(),
}
)
assert duplicated_coords.equals(inherited_coords)
assert inherited_coords.equals(duplicated_coords)
assert not duplicated_coords.identical(inherited_coords)
assert not inherited_coords.identical(duplicated_coords)

def test_equal_but_not_identical_attrs(self):
without_attrs = DataTree()
with_attrs = DataTree(Dataset(attrs={"foo": "bar"}))
assert without_attrs.equals(with_attrs)
assert with_attrs.equals(without_attrs)
assert not without_attrs.identical(with_attrs)
assert not with_attrs.identical(without_attrs)


class TestSubset:
def test_match(self):
# TODO is this example going to cause problems with case sensitivity?
Expand Down
4 changes: 2 additions & 2 deletions xarray/tests/test_datatree_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,9 @@ def test_discard_ancestry(self, create_test_datatree):
def times_ten(ds):
return 10.0 * ds

expected = create_test_datatree(modify=lambda ds: 10.0 * ds)["set1"]
expected = create_test_datatree(modify=lambda ds: 10.0 * ds)["set1"].copy()
result_tree = times_ten(subtree)
assert_equal(result_tree, expected, from_root=False)
assert_equal(result_tree, expected)

def test_skip_empty_nodes_with_attrs(self, create_test_datatree):
# inspired by xarray-datatree GH262
Expand Down
Loading