Skip to content

Commit

Permalink
Add ArrayEqualityComparator (#363)
Browse files Browse the repository at this point in the history
  • Loading branch information
durandtibo authored Jan 6, 2024
1 parent 343741c commit 4119b25
Show file tree
Hide file tree
Showing 3 changed files with 238 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/coola/equality/comparators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@

from __future__ import annotations

__all__ = ["BaseEqualityComparator", "DefaultEqualityComparator"]
__all__ = [
"ArrayEqualityComparator",
"BaseEqualityComparator",
"DefaultEqualityComparator",
]

from coola.equality.comparators.base import BaseEqualityComparator
from coola.equality.comparators.default import DefaultEqualityComparator
from coola.equality.comparators.numpy_ import ArrayEqualityComparator
60 changes: 60 additions & 0 deletions src/coola/equality/comparators/numpy_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
r"""Implement an equality comparator for ``numpy.ndarray``s."""

from __future__ import annotations

__all__ = ["ArrayEqualityComparator"]

import logging
from typing import TYPE_CHECKING, Any

from coola.equality.comparators.base import BaseEqualityComparator
from coola.equality.handlers import (
ArraySameDTypeHandler,
ArraySameShapeHandler,
SameObjectHandler,
SameTypeHandler,
)
from coola.equality.handlers.numpy_ import ArrayEqualHandler
from coola.utils import check_numpy

if TYPE_CHECKING:
from coola.equality import EqualityConfig

logger = logging.getLogger(__name__)


class ArrayEqualityComparator(BaseEqualityComparator[Any]):
r"""Implement an equality comparator for ``numpy.ndarray``.
Example usage:
```pycon
>>> import numpy as np
>>> from coola.equality import EqualityConfig
>>> from coola.equality.comparators import ArrayEqualityComparator
>>> from coola.testers import EqualityTester
>>> config = EqualityConfig(tester=EqualityTester())
>>> comparator = ArrayEqualityComparator()
>>> comparator.equal(np.ones((2, 3)), np.ones((2, 3)), config)
True
>>> comparator.equal(np.ones((2, 3)), np.zeros((2, 3)), config)
False
```
"""

def __init__(self) -> None:
check_numpy()
self._handler = SameObjectHandler()
self._handler.chain(SameTypeHandler()).chain(ArraySameDTypeHandler()).chain(
ArraySameShapeHandler()
).chain(ArrayEqualHandler())

def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__)

def clone(self) -> ArrayEqualityComparator:
return self.__class__()

def equal(self, object1: Any, object2: Any, config: EqualityConfig) -> bool:
return self._handler.handle(object1=object1, object2=object2, config=config)
172 changes: 172 additions & 0 deletions tests/unit/equality/comparators/test_numpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
from __future__ import annotations

import logging
from unittest.mock import Mock, patch

import pytest

from coola import objects_are_equal
from coola.equality import EqualityConfig
from coola.equality.comparators import ArrayEqualityComparator
from coola.testers import EqualityTester
from coola.testing import numpy_available
from coola.utils.imports import is_numpy_available

if is_numpy_available():
import numpy as np
else:
np = Mock()


@pytest.fixture()
def config() -> EqualityConfig:
return EqualityConfig(tester=EqualityTester())


#############################################
# Tests for ArrayEqualityComparator #
#############################################


@numpy_available
def test_objects_are_equal_array() -> None:
assert objects_are_equal(np.ones((2, 3)), np.ones((2, 3)))


@numpy_available
def test_array_equality_comparator_str() -> None:
assert str(ArrayEqualityComparator()).startswith("ArrayEqualityComparator(")


@numpy_available
def test_array_equality_comparator__eq__true() -> None:
assert ArrayEqualityComparator() == ArrayEqualityComparator()


@numpy_available
def test_array_equality_comparator__eq__false_different_type() -> None:
assert ArrayEqualityComparator() != 123


@numpy_available
def test_array_equality_comparator_clone() -> None:
op = ArrayEqualityComparator()
op_cloned = op.clone()
assert op is not op_cloned
assert op == op_cloned


@numpy_available
def test_array_equality_comparator_equal_true(config: EqualityConfig) -> None:
assert ArrayEqualityComparator().equal(np.ones((2, 3)), np.ones((2, 3)), config)


@numpy_available
def test_array_equality_comparator_equal_true_same_object(config: EqualityConfig) -> None:
array = np.ones((2, 3))
assert ArrayEqualityComparator().equal(array, array, config)


@numpy_available
def test_array_equality_comparator_equal_true_show_difference(
caplog: pytest.LogCaptureFixture, config: EqualityConfig
) -> None:
config.show_difference = True
comparator = ArrayEqualityComparator()
with caplog.at_level(logging.INFO):
assert comparator.equal(
object1=np.ones((2, 3)),
object2=np.ones((2, 3)),
config=config,
)
assert not caplog.messages


@numpy_available
def test_array_equality_comparator_equal_false_different_dtype(config: EqualityConfig) -> None:
assert not ArrayEqualityComparator().equal(
np.ones(shape=(2, 3), dtype=float), np.ones(shape=(2, 3), dtype=int), config
)


@numpy_available
def test_array_equality_comparator_equal_false_different_dtype_show_difference(
caplog: pytest.LogCaptureFixture, config: EqualityConfig
) -> None:
config.show_difference = True
comparator = ArrayEqualityComparator()
with caplog.at_level(logging.INFO):
assert not comparator.equal(
object1=np.ones(shape=(2, 3), dtype=float),
object2=np.ones(shape=(2, 3), dtype=int),
config=config,
)
assert caplog.messages[0].startswith("objects have different data types:")


@numpy_available
def test_array_equality_comparator_equal_false_different_shape(config: EqualityConfig) -> None:
assert not ArrayEqualityComparator().equal(np.ones((2, 3)), np.zeros((6,)), config)


@numpy_available
def test_array_equality_comparator_equal_false_different_shape_show_difference(
caplog: pytest.LogCaptureFixture, config: EqualityConfig
) -> None:
config.show_difference = True
comparator = ArrayEqualityComparator()
with caplog.at_level(logging.INFO):
assert not comparator.equal(
object1=np.ones((2, 3)),
object2=np.zeros((6,)),
config=config,
)
assert caplog.messages[0].startswith("objects have different shapes:")


@numpy_available
def test_array_equality_comparator_equal_false_different_value(config: EqualityConfig) -> None:
assert not ArrayEqualityComparator().equal(np.ones((2, 3)), np.zeros((2, 3)), config)


@numpy_available
def test_array_equality_comparator_equal_false_different_value_show_difference(
caplog: pytest.LogCaptureFixture, config: EqualityConfig
) -> None:
config.show_difference = True
comparator = ArrayEqualityComparator()
with caplog.at_level(logging.INFO):
assert not comparator.equal(
object1=np.ones((2, 3)),
object2=np.zeros((2, 3)),
config=config,
)
assert caplog.messages[0].startswith("numpy.ndarrays have different elements:")


@numpy_available
def test_array_equality_comparator_equal_false_different_type(config: EqualityConfig) -> None:
assert not ArrayEqualityComparator().equal(object1=np.ones((2, 3)), object2=42, config=config)


@numpy_available
def test_array_equality_comparator_equal_false_different_type_show_difference(
caplog: pytest.LogCaptureFixture, config: EqualityConfig
) -> None:
config.show_difference = True
comparator = ArrayEqualityComparator()
with caplog.at_level(logging.INFO):
assert not comparator.equal(
object1=np.ones((2, 3)),
object2=42,
config=config,
)
assert caplog.messages[0].startswith("objects have different types:")


@numpy_available
def test_array_equality_comparator_no_numpy() -> None:
with patch(
"coola.utils.imports.is_numpy_available", lambda *args, **kwargs: False
), pytest.raises(RuntimeError, match="`numpy` package is required but not installed."):
ArrayEqualityComparator()

0 comments on commit 4119b25

Please sign in to comment.