From 7b59e3c176c2f9806f314fddc5f18662af240d96 Mon Sep 17 00:00:00 2001 From: Thibaut Durand Date: Sun, 7 Jan 2024 14:06:42 -0800 Subject: [PATCH] Add `XarrayVariableEqualityComparator` --- src/coola/equality/comparators/__init__.py | 2 + src/coola/equality/comparators/jax_.py | 2 +- src/coola/equality/comparators/numpy_.py | 2 +- src/coola/equality/comparators/torch_.py | 2 +- src/coola/equality/comparators/xarray_.py | 99 ++++++++ .../unit/equality/comparators/test_xarray.py | 233 ++++++++++++++++++ 6 files changed, 337 insertions(+), 3 deletions(-) create mode 100644 src/coola/equality/comparators/xarray_.py create mode 100644 tests/unit/equality/comparators/test_xarray.py diff --git a/src/coola/equality/comparators/__init__.py b/src/coola/equality/comparators/__init__.py index 9508cb9f..18c3c8e0 100644 --- a/src/coola/equality/comparators/__init__.py +++ b/src/coola/equality/comparators/__init__.py @@ -11,6 +11,7 @@ "SequenceEqualityComparator", "TorchPackedSequenceEqualityComparator", "TorchTensorEqualityComparator", + "XarrayVariableEqualityComparator", "get_type_comparator_mapping", ] @@ -27,3 +28,4 @@ TorchTensorEqualityComparator, ) from coola.equality.comparators.utils import get_type_comparator_mapping +from coola.equality.comparators.xarray_ import XarrayVariableEqualityComparator diff --git a/src/coola/equality/comparators/jax_.py b/src/coola/equality/comparators/jax_.py index 916a2fe5..b48d00f1 100644 --- a/src/coola/equality/comparators/jax_.py +++ b/src/coola/equality/comparators/jax_.py @@ -68,7 +68,7 @@ def get_type_comparator_mapping() -> dict[type, BaseEqualityComparator]: r"""Get a default mapping between the types and the equality comparators. - This function returns an empty dictionary if jax is not + This function returns an empty dictionary if ``jax`` is not installed. Returns: diff --git a/src/coola/equality/comparators/numpy_.py b/src/coola/equality/comparators/numpy_.py index 5bfdf6ec..ffded244 100644 --- a/src/coola/equality/comparators/numpy_.py +++ b/src/coola/equality/comparators/numpy_.py @@ -67,7 +67,7 @@ def get_type_comparator_mapping() -> dict[type, BaseEqualityComparator]: r"""Get a default mapping between the types and the equality comparators. - This function returns an empty dictionary if numpy is not + This function returns an empty dictionary if ``numpy`` is not installed. Returns: diff --git a/src/coola/equality/comparators/torch_.py b/src/coola/equality/comparators/torch_.py index d1348425..ebe30abe 100644 --- a/src/coola/equality/comparators/torch_.py +++ b/src/coola/equality/comparators/torch_.py @@ -126,7 +126,7 @@ def get_type_comparator_mapping() -> dict[type, BaseEqualityComparator]: r"""Get a default mapping between the types and the equality comparators. - This function returns an empty dictionary if torch is not + This function returns an empty dictionary if ``torch`` is not installed. Returns: diff --git a/src/coola/equality/comparators/xarray_.py b/src/coola/equality/comparators/xarray_.py new file mode 100644 index 00000000..9d2990a3 --- /dev/null +++ b/src/coola/equality/comparators/xarray_.py @@ -0,0 +1,99 @@ +r"""Implement an equality comparator for ``xarray`` objects.""" + +from __future__ import annotations + +__all__ = ["XarrayVariableEqualityComparator", "get_type_comparator_mapping"] + +import logging +from typing import TYPE_CHECKING, Any +from unittest.mock import Mock + +from coola.equality.comparators.base import BaseEqualityComparator +from coola.equality.handlers import ( + SameAttributeHandler, + SameDataHandler, + SameObjectHandler, + SameTypeHandler, + TrueHandler, +) +from coola.utils import check_xarray, is_xarray_available + +if is_xarray_available(): + import xarray as xr +else: # pragma: no cover + xr = Mock() + +if TYPE_CHECKING: + from coola.equality import EqualityConfig + +logger = logging.getLogger(__name__) + + +class XarrayVariableEqualityComparator(BaseEqualityComparator[xr.Variable]): + r"""Implement an equality comparator for ``xarray.Variable``. + + Example usage: + + ```pycon + >>> import numpy as np + >>> import xarray as xr + >>> from coola.equality import EqualityConfig + >>> from coola.equality.comparators import NumpyArrayEqualityComparator + >>> from coola.testers import EqualityTester + >>> config = EqualityConfig(tester=EqualityTester()) + >>> comparator = NumpyArrayEqualityComparator() + >>> comparator.equal( + ... xr.Variable(dims=["z"], data=np.arange(6)), + ... xr.Variable(dims=["z"], data=np.arange(6)), + ... config, + ... ) + True + >>> comparator.equal( + ... xr.Variable(dims=["z"], data=np.zeros(6)), + ... xr.Variable(dims=["z"], data=np.ones(6)), + ... config, + ... ) + False + + ``` + """ + + def __init__(self) -> None: + check_xarray() + self._handler = SameObjectHandler() + self._handler.chain(SameTypeHandler()).chain(SameDataHandler()).chain( + SameAttributeHandler(name="dims") + ).chain(SameAttributeHandler(name="attrs")).chain(TrueHandler()) + + def __eq__(self, other: object) -> bool: + return isinstance(other, self.__class__) + + def clone(self) -> XarrayVariableEqualityComparator: + return self.__class__() + + def equal(self, object1: Any, object2: Any, config: EqualityConfig) -> bool: + return self._handler.handle(object1=object1, object2=object2, config=config) + + +def get_type_comparator_mapping() -> dict[type, BaseEqualityComparator]: + r"""Get a default mapping between the types and the equality + comparators. + + This function returns an empty dictionary if ``xarray`` is not + installed. + + Returns: + The mapping between the types and the equality comparators. + + Example usage: + + ```pycon + >>> from coola.equality.comparators.xarray_ import get_type_comparator_mapping + >>> get_type_comparator_mapping() + {: XarrayVariableEqualityComparator()} + + ``` + """ + if not is_xarray_available(): + return {} + return {xr.Variable: XarrayVariableEqualityComparator()} diff --git a/tests/unit/equality/comparators/test_xarray.py b/tests/unit/equality/comparators/test_xarray.py new file mode 100644 index 00000000..ded4e63d --- /dev/null +++ b/tests/unit/equality/comparators/test_xarray.py @@ -0,0 +1,233 @@ +from __future__ import annotations + +import logging +from unittest.mock import Mock, patch + +import pytest + +from coola import objects_are_equal +from coola.equality import EqualityConfig +from coola.equality.comparators.xarray_ import ( + XarrayVariableEqualityComparator, + get_type_comparator_mapping, +) +from coola.testers import EqualityTester +from coola.testing import xarray_available +from coola.utils.imports import is_numpy_available, is_xarray_available + +if is_numpy_available(): + import numpy as np +else: + np = Mock() + +if is_xarray_available(): + import xarray as xr +else: + xr = Mock() + + +@pytest.fixture() +def config() -> EqualityConfig: + return EqualityConfig(tester=EqualityTester()) + + +###################################################### +# Tests for XarrayVariableEqualityComparator # +###################################################### + + +@xarray_available +def test_objects_are_equal_variable() -> None: + assert objects_are_equal( + xr.Variable(dims=["z"], data=np.arange(6)), xr.Variable(dims=["z"], data=np.arange(6)) + ) + + +@xarray_available +def test_variable_equality_operator_str() -> None: + assert str(XarrayVariableEqualityComparator()).startswith("XarrayVariableEqualityComparator(") + + +@xarray_available +def test_variable_equality_operator__eq__true() -> None: + assert XarrayVariableEqualityComparator() == XarrayVariableEqualityComparator() + + +@xarray_available +def test_variable_equality_operator__eq__false() -> None: + assert XarrayVariableEqualityComparator() != 123 + + +@xarray_available +def test_variable_equality_operator_clone() -> None: + op = XarrayVariableEqualityComparator() + op_cloned = op.clone() + assert op is not op_cloned + assert op == op_cloned + + +@xarray_available +def test_variable_equality_operator_equal_true(config: EqualityConfig) -> None: + assert XarrayVariableEqualityComparator().equal( + xr.Variable(dims=["z"], data=np.arange(6)), + xr.Variable(dims=["z"], data=np.arange(6)), + config, + ) + + +@xarray_available +def test_variable_equality_operator_equal_true_same_object(config: EqualityConfig) -> None: + obj = xr.Variable(dims=["z"], data=np.arange(6)) + assert XarrayVariableEqualityComparator().equal(obj, obj, config) + + +@xarray_available +def test_variable_equality_operator_equal_true_show_difference( + caplog: pytest.LogCaptureFixture, config: EqualityConfig +) -> None: + config.show_difference = True + comparator = XarrayVariableEqualityComparator() + with caplog.at_level(logging.INFO): + assert comparator.equal( + xr.Variable(dims=["z"], data=np.arange(6)), + xr.Variable(dims=["z"], data=np.arange(6)), + config, + ) + assert not caplog.messages + + +@xarray_available +def test_variable_equality_operator_equal_false_data(config: EqualityConfig) -> None: + assert not XarrayVariableEqualityComparator().equal( + xr.Variable(dims=["z"], data=np.ones(6)), + xr.Variable(dims=["z"], data=np.zeros(6)), + config, + ) + + +@xarray_available +def test_variable_equality_operator_equal_false_data_show_difference( + caplog: pytest.LogCaptureFixture, config: EqualityConfig +) -> None: + config.show_difference = True + comparator = XarrayVariableEqualityComparator() + with caplog.at_level(logging.INFO): + assert not comparator.equal( + xr.Variable(dims=["z"], data=np.ones(6)), + xr.Variable(dims=["z"], data=np.zeros(6)), + config, + ) + assert caplog.messages[-1].startswith("objects have different data:") + + +@xarray_available +def test_variable_equality_operator_equal_false_dims(config: EqualityConfig) -> None: + assert not XarrayVariableEqualityComparator().equal( + xr.Variable(dims=["z"], data=np.arange(6)), + xr.Variable(dims=["x"], data=np.arange(6)), + config, + ) + + +@xarray_available +def test_variable_equality_operator_equal_false_dims_show_difference( + caplog: pytest.LogCaptureFixture, config: EqualityConfig +) -> None: + config.show_difference = True + comparator = XarrayVariableEqualityComparator() + with caplog.at_level(logging.INFO): + assert not comparator.equal( + xr.Variable(dims=["z"], data=np.arange(6)), + xr.Variable(dims=["x"], data=np.arange(6)), + config, + ) + assert caplog.messages[-1].startswith("objects have different dims:") + + +@xarray_available +def test_variable_equality_operator_equal_false_different_attrs(config: EqualityConfig) -> None: + assert not XarrayVariableEqualityComparator().equal( + xr.Variable(dims=["z"], data=np.arange(6), attrs={"global": "meow"}), + xr.Variable(dims=["z"], data=np.arange(6), attrs={"global": "meoowww"}), + config, + ) + + +@xarray_available +def test_variable_equality_operator_equal_false_attrs_show_difference( + caplog: pytest.LogCaptureFixture, config: EqualityConfig +) -> None: + config.show_difference = True + comparator = XarrayVariableEqualityComparator() + with caplog.at_level(logging.INFO): + assert not comparator.equal( + xr.Variable(dims=["z"], data=np.arange(6), attrs={"global": "meow"}), + xr.Variable(dims=["z"], data=np.arange(6), attrs={"global": "meoowww"}), + config, + ) + assert caplog.messages[-1].startswith("objects have different attrs:") + + +@xarray_available +def test_variable_equality_operator_equal_false_different_type(config: EqualityConfig) -> None: + assert not XarrayVariableEqualityComparator().equal( + xr.Variable(dims=["z"], data=np.arange(6)), np.arange(6), config + ) + + +@xarray_available +def test_variable_equality_operator_equal_false_different_type_show_difference( + caplog: pytest.LogCaptureFixture, config: EqualityConfig +) -> None: + config.show_difference = True + comparator = XarrayVariableEqualityComparator() + with caplog.at_level(logging.INFO): + assert not comparator.equal( + xr.Variable(dims=["z"], data=np.arange(6)), np.arange(6), config + ) + assert caplog.messages[0].startswith("objects have different types:") + + +@xarray_available +def test_variable_equality_operator_no_xarray() -> None: + with patch( + "coola.utils.imports.is_xarray_available", lambda *args, **kwargs: False + ), pytest.raises(RuntimeError, match="`xarray` package is required but not installed."): + XarrayVariableEqualityComparator() + + +@xarray_available +def test_variable_equality_operator_equal_equal_nan_false(config: EqualityConfig) -> None: + assert not XarrayVariableEqualityComparator().equal( + xr.Variable(dims=["z"], data=np.array([0.0, float("nan"), 2.0])), + xr.Variable(dims=["z"], data=np.array([0.0, float("nan"), 2.0])), + config, + ) + + +@xarray_available +def test_variable_equality_operator_equal_equal_nan_true(config: EqualityConfig) -> None: + config.equal_nan = True + # TODO(TIBO): update after the new version is finished # noqa: TD003 + assert not XarrayVariableEqualityComparator().equal( + xr.Variable(dims=["z"], data=np.array([0.0, float("nan"), 2.0])), + xr.Variable(dims=["z"], data=np.array([0.0, float("nan"), 2.0])), + config, + ) + + +########################################## +# Tests for get_mapping_equality # +########################################## + + +@xarray_available +def test_get_type_comparator_mapping() -> None: + assert get_type_comparator_mapping() == {xr.Variable: XarrayVariableEqualityComparator()} + + +def test_get_type_comparator_mapping_no_xarray() -> None: + with patch( + "coola.equality.comparators.xarray_.is_xarray_available", lambda *args, **kwargs: False + ): + assert get_type_comparator_mapping() == {}