Skip to content

Commit

Permalink
Add XarrayVariableEqualityComparator
Browse files Browse the repository at this point in the history
  • Loading branch information
durandtibo committed Jan 7, 2024
1 parent bb31843 commit 7b59e3c
Show file tree
Hide file tree
Showing 6 changed files with 337 additions and 3 deletions.
2 changes: 2 additions & 0 deletions src/coola/equality/comparators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"SequenceEqualityComparator",
"TorchPackedSequenceEqualityComparator",
"TorchTensorEqualityComparator",
"XarrayVariableEqualityComparator",
"get_type_comparator_mapping",
]

Expand All @@ -27,3 +28,4 @@
TorchTensorEqualityComparator,
)
from coola.equality.comparators.utils import get_type_comparator_mapping
from coola.equality.comparators.xarray_ import XarrayVariableEqualityComparator
2 changes: 1 addition & 1 deletion src/coola/equality/comparators/jax_.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def get_type_comparator_mapping() -> dict[type, BaseEqualityComparator]:
r"""Get a default mapping between the types and the equality
comparators.
This function returns an empty dictionary if jax is not
This function returns an empty dictionary if ``jax`` is not
installed.
Returns:
Expand Down
2 changes: 1 addition & 1 deletion src/coola/equality/comparators/numpy_.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def get_type_comparator_mapping() -> dict[type, BaseEqualityComparator]:
r"""Get a default mapping between the types and the equality
comparators.
This function returns an empty dictionary if numpy is not
This function returns an empty dictionary if ``numpy`` is not
installed.
Returns:
Expand Down
2 changes: 1 addition & 1 deletion src/coola/equality/comparators/torch_.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def get_type_comparator_mapping() -> dict[type, BaseEqualityComparator]:
r"""Get a default mapping between the types and the equality
comparators.
This function returns an empty dictionary if torch is not
This function returns an empty dictionary if ``torch`` is not
installed.
Returns:
Expand Down
99 changes: 99 additions & 0 deletions src/coola/equality/comparators/xarray_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
r"""Implement an equality comparator for ``xarray`` objects."""

from __future__ import annotations

__all__ = ["XarrayVariableEqualityComparator", "get_type_comparator_mapping"]

import logging
from typing import TYPE_CHECKING, Any
from unittest.mock import Mock

from coola.equality.comparators.base import BaseEqualityComparator
from coola.equality.handlers import (
SameAttributeHandler,
SameDataHandler,
SameObjectHandler,
SameTypeHandler,
TrueHandler,
)
from coola.utils import check_xarray, is_xarray_available

if is_xarray_available():
import xarray as xr
else: # pragma: no cover
xr = Mock()

if TYPE_CHECKING:
from coola.equality import EqualityConfig

logger = logging.getLogger(__name__)


class XarrayVariableEqualityComparator(BaseEqualityComparator[xr.Variable]):
r"""Implement an equality comparator for ``xarray.Variable``.
Example usage:
```pycon
>>> import numpy as np
>>> import xarray as xr
>>> from coola.equality import EqualityConfig
>>> from coola.equality.comparators import NumpyArrayEqualityComparator
>>> from coola.testers import EqualityTester
>>> config = EqualityConfig(tester=EqualityTester())
>>> comparator = NumpyArrayEqualityComparator()
>>> comparator.equal(
... xr.Variable(dims=["z"], data=np.arange(6)),
... xr.Variable(dims=["z"], data=np.arange(6)),
... config,
... )
True
>>> comparator.equal(
... xr.Variable(dims=["z"], data=np.zeros(6)),
... xr.Variable(dims=["z"], data=np.ones(6)),
... config,
... )
False
```
"""

def __init__(self) -> None:
check_xarray()
self._handler = SameObjectHandler()
self._handler.chain(SameTypeHandler()).chain(SameDataHandler()).chain(
SameAttributeHandler(name="dims")
).chain(SameAttributeHandler(name="attrs")).chain(TrueHandler())

def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__)

def clone(self) -> XarrayVariableEqualityComparator:
return self.__class__()

def equal(self, object1: Any, object2: Any, config: EqualityConfig) -> bool:
return self._handler.handle(object1=object1, object2=object2, config=config)


def get_type_comparator_mapping() -> dict[type, BaseEqualityComparator]:
r"""Get a default mapping between the types and the equality
comparators.
This function returns an empty dictionary if ``xarray`` is not
installed.
Returns:
The mapping between the types and the equality comparators.
Example usage:
```pycon
>>> from coola.equality.comparators.xarray_ import get_type_comparator_mapping
>>> get_type_comparator_mapping()
{<class 'xarray.core.variable.Variable'>: XarrayVariableEqualityComparator()}
```
"""
if not is_xarray_available():
return {}
return {xr.Variable: XarrayVariableEqualityComparator()}
233 changes: 233 additions & 0 deletions tests/unit/equality/comparators/test_xarray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
from __future__ import annotations

import logging
from unittest.mock import Mock, patch

import pytest

from coola import objects_are_equal
from coola.equality import EqualityConfig
from coola.equality.comparators.xarray_ import (
XarrayVariableEqualityComparator,
get_type_comparator_mapping,
)
from coola.testers import EqualityTester
from coola.testing import xarray_available
from coola.utils.imports import is_numpy_available, is_xarray_available

if is_numpy_available():
import numpy as np
else:
np = Mock()

if is_xarray_available():
import xarray as xr
else:
xr = Mock()


@pytest.fixture()
def config() -> EqualityConfig:
return EqualityConfig(tester=EqualityTester())


######################################################
# Tests for XarrayVariableEqualityComparator #
######################################################


@xarray_available
def test_objects_are_equal_variable() -> None:
assert objects_are_equal(
xr.Variable(dims=["z"], data=np.arange(6)), xr.Variable(dims=["z"], data=np.arange(6))
)


@xarray_available
def test_variable_equality_operator_str() -> None:
assert str(XarrayVariableEqualityComparator()).startswith("XarrayVariableEqualityComparator(")


@xarray_available
def test_variable_equality_operator__eq__true() -> None:
assert XarrayVariableEqualityComparator() == XarrayVariableEqualityComparator()


@xarray_available
def test_variable_equality_operator__eq__false() -> None:
assert XarrayVariableEqualityComparator() != 123


@xarray_available
def test_variable_equality_operator_clone() -> None:
op = XarrayVariableEqualityComparator()
op_cloned = op.clone()
assert op is not op_cloned
assert op == op_cloned


@xarray_available
def test_variable_equality_operator_equal_true(config: EqualityConfig) -> None:
assert XarrayVariableEqualityComparator().equal(
xr.Variable(dims=["z"], data=np.arange(6)),
xr.Variable(dims=["z"], data=np.arange(6)),
config,
)


@xarray_available
def test_variable_equality_operator_equal_true_same_object(config: EqualityConfig) -> None:
obj = xr.Variable(dims=["z"], data=np.arange(6))
assert XarrayVariableEqualityComparator().equal(obj, obj, config)


@xarray_available
def test_variable_equality_operator_equal_true_show_difference(
caplog: pytest.LogCaptureFixture, config: EqualityConfig
) -> None:
config.show_difference = True
comparator = XarrayVariableEqualityComparator()
with caplog.at_level(logging.INFO):
assert comparator.equal(
xr.Variable(dims=["z"], data=np.arange(6)),
xr.Variable(dims=["z"], data=np.arange(6)),
config,
)
assert not caplog.messages


@xarray_available
def test_variable_equality_operator_equal_false_data(config: EqualityConfig) -> None:
assert not XarrayVariableEqualityComparator().equal(
xr.Variable(dims=["z"], data=np.ones(6)),
xr.Variable(dims=["z"], data=np.zeros(6)),
config,
)


@xarray_available
def test_variable_equality_operator_equal_false_data_show_difference(
caplog: pytest.LogCaptureFixture, config: EqualityConfig
) -> None:
config.show_difference = True
comparator = XarrayVariableEqualityComparator()
with caplog.at_level(logging.INFO):
assert not comparator.equal(
xr.Variable(dims=["z"], data=np.ones(6)),
xr.Variable(dims=["z"], data=np.zeros(6)),
config,
)
assert caplog.messages[-1].startswith("objects have different data:")


@xarray_available
def test_variable_equality_operator_equal_false_dims(config: EqualityConfig) -> None:
assert not XarrayVariableEqualityComparator().equal(
xr.Variable(dims=["z"], data=np.arange(6)),
xr.Variable(dims=["x"], data=np.arange(6)),
config,
)


@xarray_available
def test_variable_equality_operator_equal_false_dims_show_difference(
caplog: pytest.LogCaptureFixture, config: EqualityConfig
) -> None:
config.show_difference = True
comparator = XarrayVariableEqualityComparator()
with caplog.at_level(logging.INFO):
assert not comparator.equal(
xr.Variable(dims=["z"], data=np.arange(6)),
xr.Variable(dims=["x"], data=np.arange(6)),
config,
)
assert caplog.messages[-1].startswith("objects have different dims:")


@xarray_available
def test_variable_equality_operator_equal_false_different_attrs(config: EqualityConfig) -> None:
assert not XarrayVariableEqualityComparator().equal(
xr.Variable(dims=["z"], data=np.arange(6), attrs={"global": "meow"}),
xr.Variable(dims=["z"], data=np.arange(6), attrs={"global": "meoowww"}),
config,
)


@xarray_available
def test_variable_equality_operator_equal_false_attrs_show_difference(
caplog: pytest.LogCaptureFixture, config: EqualityConfig
) -> None:
config.show_difference = True
comparator = XarrayVariableEqualityComparator()
with caplog.at_level(logging.INFO):
assert not comparator.equal(
xr.Variable(dims=["z"], data=np.arange(6), attrs={"global": "meow"}),
xr.Variable(dims=["z"], data=np.arange(6), attrs={"global": "meoowww"}),
config,
)
assert caplog.messages[-1].startswith("objects have different attrs:")


@xarray_available
def test_variable_equality_operator_equal_false_different_type(config: EqualityConfig) -> None:
assert not XarrayVariableEqualityComparator().equal(
xr.Variable(dims=["z"], data=np.arange(6)), np.arange(6), config
)


@xarray_available
def test_variable_equality_operator_equal_false_different_type_show_difference(
caplog: pytest.LogCaptureFixture, config: EqualityConfig
) -> None:
config.show_difference = True
comparator = XarrayVariableEqualityComparator()
with caplog.at_level(logging.INFO):
assert not comparator.equal(
xr.Variable(dims=["z"], data=np.arange(6)), np.arange(6), config
)
assert caplog.messages[0].startswith("objects have different types:")


@xarray_available
def test_variable_equality_operator_no_xarray() -> None:
with patch(
"coola.utils.imports.is_xarray_available", lambda *args, **kwargs: False
), pytest.raises(RuntimeError, match="`xarray` package is required but not installed."):
XarrayVariableEqualityComparator()


@xarray_available
def test_variable_equality_operator_equal_equal_nan_false(config: EqualityConfig) -> None:
assert not XarrayVariableEqualityComparator().equal(
xr.Variable(dims=["z"], data=np.array([0.0, float("nan"), 2.0])),
xr.Variable(dims=["z"], data=np.array([0.0, float("nan"), 2.0])),
config,
)


@xarray_available
def test_variable_equality_operator_equal_equal_nan_true(config: EqualityConfig) -> None:
config.equal_nan = True
# TODO(TIBO): update after the new version is finished # noqa: TD003
assert not XarrayVariableEqualityComparator().equal(
xr.Variable(dims=["z"], data=np.array([0.0, float("nan"), 2.0])),
xr.Variable(dims=["z"], data=np.array([0.0, float("nan"), 2.0])),
config,
)


##########################################
# Tests for get_mapping_equality #
##########################################


@xarray_available
def test_get_type_comparator_mapping() -> None:
assert get_type_comparator_mapping() == {xr.Variable: XarrayVariableEqualityComparator()}


def test_get_type_comparator_mapping_no_xarray() -> None:
with patch(
"coola.equality.comparators.xarray_.is_xarray_available", lambda *args, **kwargs: False
):
assert get_type_comparator_mapping() == {}

0 comments on commit 7b59e3c

Please sign in to comment.