-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add
XarrayVariableEqualityComparator
- Loading branch information
1 parent
bb31843
commit 7b59e3c
Showing
6 changed files
with
337 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
r"""Implement an equality comparator for ``xarray`` objects.""" | ||
|
||
from __future__ import annotations | ||
|
||
__all__ = ["XarrayVariableEqualityComparator", "get_type_comparator_mapping"] | ||
|
||
import logging | ||
from typing import TYPE_CHECKING, Any | ||
from unittest.mock import Mock | ||
|
||
from coola.equality.comparators.base import BaseEqualityComparator | ||
from coola.equality.handlers import ( | ||
SameAttributeHandler, | ||
SameDataHandler, | ||
SameObjectHandler, | ||
SameTypeHandler, | ||
TrueHandler, | ||
) | ||
from coola.utils import check_xarray, is_xarray_available | ||
|
||
if is_xarray_available(): | ||
import xarray as xr | ||
else: # pragma: no cover | ||
xr = Mock() | ||
|
||
if TYPE_CHECKING: | ||
from coola.equality import EqualityConfig | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class XarrayVariableEqualityComparator(BaseEqualityComparator[xr.Variable]): | ||
r"""Implement an equality comparator for ``xarray.Variable``. | ||
Example usage: | ||
```pycon | ||
>>> import numpy as np | ||
>>> import xarray as xr | ||
>>> from coola.equality import EqualityConfig | ||
>>> from coola.equality.comparators import NumpyArrayEqualityComparator | ||
>>> from coola.testers import EqualityTester | ||
>>> config = EqualityConfig(tester=EqualityTester()) | ||
>>> comparator = NumpyArrayEqualityComparator() | ||
>>> comparator.equal( | ||
... xr.Variable(dims=["z"], data=np.arange(6)), | ||
... xr.Variable(dims=["z"], data=np.arange(6)), | ||
... config, | ||
... ) | ||
True | ||
>>> comparator.equal( | ||
... xr.Variable(dims=["z"], data=np.zeros(6)), | ||
... xr.Variable(dims=["z"], data=np.ones(6)), | ||
... config, | ||
... ) | ||
False | ||
``` | ||
""" | ||
|
||
def __init__(self) -> None: | ||
check_xarray() | ||
self._handler = SameObjectHandler() | ||
self._handler.chain(SameTypeHandler()).chain(SameDataHandler()).chain( | ||
SameAttributeHandler(name="dims") | ||
).chain(SameAttributeHandler(name="attrs")).chain(TrueHandler()) | ||
|
||
def __eq__(self, other: object) -> bool: | ||
return isinstance(other, self.__class__) | ||
|
||
def clone(self) -> XarrayVariableEqualityComparator: | ||
return self.__class__() | ||
|
||
def equal(self, object1: Any, object2: Any, config: EqualityConfig) -> bool: | ||
return self._handler.handle(object1=object1, object2=object2, config=config) | ||
|
||
|
||
def get_type_comparator_mapping() -> dict[type, BaseEqualityComparator]: | ||
r"""Get a default mapping between the types and the equality | ||
comparators. | ||
This function returns an empty dictionary if ``xarray`` is not | ||
installed. | ||
Returns: | ||
The mapping between the types and the equality comparators. | ||
Example usage: | ||
```pycon | ||
>>> from coola.equality.comparators.xarray_ import get_type_comparator_mapping | ||
>>> get_type_comparator_mapping() | ||
{<class 'xarray.core.variable.Variable'>: XarrayVariableEqualityComparator()} | ||
``` | ||
""" | ||
if not is_xarray_available(): | ||
return {} | ||
return {xr.Variable: XarrayVariableEqualityComparator()} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,233 @@ | ||
from __future__ import annotations | ||
|
||
import logging | ||
from unittest.mock import Mock, patch | ||
|
||
import pytest | ||
|
||
from coola import objects_are_equal | ||
from coola.equality import EqualityConfig | ||
from coola.equality.comparators.xarray_ import ( | ||
XarrayVariableEqualityComparator, | ||
get_type_comparator_mapping, | ||
) | ||
from coola.testers import EqualityTester | ||
from coola.testing import xarray_available | ||
from coola.utils.imports import is_numpy_available, is_xarray_available | ||
|
||
if is_numpy_available(): | ||
import numpy as np | ||
else: | ||
np = Mock() | ||
|
||
if is_xarray_available(): | ||
import xarray as xr | ||
else: | ||
xr = Mock() | ||
|
||
|
||
@pytest.fixture() | ||
def config() -> EqualityConfig: | ||
return EqualityConfig(tester=EqualityTester()) | ||
|
||
|
||
###################################################### | ||
# Tests for XarrayVariableEqualityComparator # | ||
###################################################### | ||
|
||
|
||
@xarray_available | ||
def test_objects_are_equal_variable() -> None: | ||
assert objects_are_equal( | ||
xr.Variable(dims=["z"], data=np.arange(6)), xr.Variable(dims=["z"], data=np.arange(6)) | ||
) | ||
|
||
|
||
@xarray_available | ||
def test_variable_equality_operator_str() -> None: | ||
assert str(XarrayVariableEqualityComparator()).startswith("XarrayVariableEqualityComparator(") | ||
|
||
|
||
@xarray_available | ||
def test_variable_equality_operator__eq__true() -> None: | ||
assert XarrayVariableEqualityComparator() == XarrayVariableEqualityComparator() | ||
|
||
|
||
@xarray_available | ||
def test_variable_equality_operator__eq__false() -> None: | ||
assert XarrayVariableEqualityComparator() != 123 | ||
|
||
|
||
@xarray_available | ||
def test_variable_equality_operator_clone() -> None: | ||
op = XarrayVariableEqualityComparator() | ||
op_cloned = op.clone() | ||
assert op is not op_cloned | ||
assert op == op_cloned | ||
|
||
|
||
@xarray_available | ||
def test_variable_equality_operator_equal_true(config: EqualityConfig) -> None: | ||
assert XarrayVariableEqualityComparator().equal( | ||
xr.Variable(dims=["z"], data=np.arange(6)), | ||
xr.Variable(dims=["z"], data=np.arange(6)), | ||
config, | ||
) | ||
|
||
|
||
@xarray_available | ||
def test_variable_equality_operator_equal_true_same_object(config: EqualityConfig) -> None: | ||
obj = xr.Variable(dims=["z"], data=np.arange(6)) | ||
assert XarrayVariableEqualityComparator().equal(obj, obj, config) | ||
|
||
|
||
@xarray_available | ||
def test_variable_equality_operator_equal_true_show_difference( | ||
caplog: pytest.LogCaptureFixture, config: EqualityConfig | ||
) -> None: | ||
config.show_difference = True | ||
comparator = XarrayVariableEqualityComparator() | ||
with caplog.at_level(logging.INFO): | ||
assert comparator.equal( | ||
xr.Variable(dims=["z"], data=np.arange(6)), | ||
xr.Variable(dims=["z"], data=np.arange(6)), | ||
config, | ||
) | ||
assert not caplog.messages | ||
|
||
|
||
@xarray_available | ||
def test_variable_equality_operator_equal_false_data(config: EqualityConfig) -> None: | ||
assert not XarrayVariableEqualityComparator().equal( | ||
xr.Variable(dims=["z"], data=np.ones(6)), | ||
xr.Variable(dims=["z"], data=np.zeros(6)), | ||
config, | ||
) | ||
|
||
|
||
@xarray_available | ||
def test_variable_equality_operator_equal_false_data_show_difference( | ||
caplog: pytest.LogCaptureFixture, config: EqualityConfig | ||
) -> None: | ||
config.show_difference = True | ||
comparator = XarrayVariableEqualityComparator() | ||
with caplog.at_level(logging.INFO): | ||
assert not comparator.equal( | ||
xr.Variable(dims=["z"], data=np.ones(6)), | ||
xr.Variable(dims=["z"], data=np.zeros(6)), | ||
config, | ||
) | ||
assert caplog.messages[-1].startswith("objects have different data:") | ||
|
||
|
||
@xarray_available | ||
def test_variable_equality_operator_equal_false_dims(config: EqualityConfig) -> None: | ||
assert not XarrayVariableEqualityComparator().equal( | ||
xr.Variable(dims=["z"], data=np.arange(6)), | ||
xr.Variable(dims=["x"], data=np.arange(6)), | ||
config, | ||
) | ||
|
||
|
||
@xarray_available | ||
def test_variable_equality_operator_equal_false_dims_show_difference( | ||
caplog: pytest.LogCaptureFixture, config: EqualityConfig | ||
) -> None: | ||
config.show_difference = True | ||
comparator = XarrayVariableEqualityComparator() | ||
with caplog.at_level(logging.INFO): | ||
assert not comparator.equal( | ||
xr.Variable(dims=["z"], data=np.arange(6)), | ||
xr.Variable(dims=["x"], data=np.arange(6)), | ||
config, | ||
) | ||
assert caplog.messages[-1].startswith("objects have different dims:") | ||
|
||
|
||
@xarray_available | ||
def test_variable_equality_operator_equal_false_different_attrs(config: EqualityConfig) -> None: | ||
assert not XarrayVariableEqualityComparator().equal( | ||
xr.Variable(dims=["z"], data=np.arange(6), attrs={"global": "meow"}), | ||
xr.Variable(dims=["z"], data=np.arange(6), attrs={"global": "meoowww"}), | ||
config, | ||
) | ||
|
||
|
||
@xarray_available | ||
def test_variable_equality_operator_equal_false_attrs_show_difference( | ||
caplog: pytest.LogCaptureFixture, config: EqualityConfig | ||
) -> None: | ||
config.show_difference = True | ||
comparator = XarrayVariableEqualityComparator() | ||
with caplog.at_level(logging.INFO): | ||
assert not comparator.equal( | ||
xr.Variable(dims=["z"], data=np.arange(6), attrs={"global": "meow"}), | ||
xr.Variable(dims=["z"], data=np.arange(6), attrs={"global": "meoowww"}), | ||
config, | ||
) | ||
assert caplog.messages[-1].startswith("objects have different attrs:") | ||
|
||
|
||
@xarray_available | ||
def test_variable_equality_operator_equal_false_different_type(config: EqualityConfig) -> None: | ||
assert not XarrayVariableEqualityComparator().equal( | ||
xr.Variable(dims=["z"], data=np.arange(6)), np.arange(6), config | ||
) | ||
|
||
|
||
@xarray_available | ||
def test_variable_equality_operator_equal_false_different_type_show_difference( | ||
caplog: pytest.LogCaptureFixture, config: EqualityConfig | ||
) -> None: | ||
config.show_difference = True | ||
comparator = XarrayVariableEqualityComparator() | ||
with caplog.at_level(logging.INFO): | ||
assert not comparator.equal( | ||
xr.Variable(dims=["z"], data=np.arange(6)), np.arange(6), config | ||
) | ||
assert caplog.messages[0].startswith("objects have different types:") | ||
|
||
|
||
@xarray_available | ||
def test_variable_equality_operator_no_xarray() -> None: | ||
with patch( | ||
"coola.utils.imports.is_xarray_available", lambda *args, **kwargs: False | ||
), pytest.raises(RuntimeError, match="`xarray` package is required but not installed."): | ||
XarrayVariableEqualityComparator() | ||
|
||
|
||
@xarray_available | ||
def test_variable_equality_operator_equal_equal_nan_false(config: EqualityConfig) -> None: | ||
assert not XarrayVariableEqualityComparator().equal( | ||
xr.Variable(dims=["z"], data=np.array([0.0, float("nan"), 2.0])), | ||
xr.Variable(dims=["z"], data=np.array([0.0, float("nan"), 2.0])), | ||
config, | ||
) | ||
|
||
|
||
@xarray_available | ||
def test_variable_equality_operator_equal_equal_nan_true(config: EqualityConfig) -> None: | ||
config.equal_nan = True | ||
# TODO(TIBO): update after the new version is finished # noqa: TD003 | ||
assert not XarrayVariableEqualityComparator().equal( | ||
xr.Variable(dims=["z"], data=np.array([0.0, float("nan"), 2.0])), | ||
xr.Variable(dims=["z"], data=np.array([0.0, float("nan"), 2.0])), | ||
config, | ||
) | ||
|
||
|
||
########################################## | ||
# Tests for get_mapping_equality # | ||
########################################## | ||
|
||
|
||
@xarray_available | ||
def test_get_type_comparator_mapping() -> None: | ||
assert get_type_comparator_mapping() == {xr.Variable: XarrayVariableEqualityComparator()} | ||
|
||
|
||
def test_get_type_comparator_mapping_no_xarray() -> None: | ||
with patch( | ||
"coola.equality.comparators.xarray_.is_xarray_available", lambda *args, **kwargs: False | ||
): | ||
assert get_type_comparator_mapping() == {} |