-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
343741c
commit 4119b25
Showing
3 changed files
with
238 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
r"""Implement an equality comparator for ``numpy.ndarray``s.""" | ||
|
||
from __future__ import annotations | ||
|
||
__all__ = ["ArrayEqualityComparator"] | ||
|
||
import logging | ||
from typing import TYPE_CHECKING, Any | ||
|
||
from coola.equality.comparators.base import BaseEqualityComparator | ||
from coola.equality.handlers import ( | ||
ArraySameDTypeHandler, | ||
ArraySameShapeHandler, | ||
SameObjectHandler, | ||
SameTypeHandler, | ||
) | ||
from coola.equality.handlers.numpy_ import ArrayEqualHandler | ||
from coola.utils import check_numpy | ||
|
||
if TYPE_CHECKING: | ||
from coola.equality import EqualityConfig | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class ArrayEqualityComparator(BaseEqualityComparator[Any]): | ||
r"""Implement an equality comparator for ``numpy.ndarray``. | ||
Example usage: | ||
```pycon | ||
>>> import numpy as np | ||
>>> from coola.equality import EqualityConfig | ||
>>> from coola.equality.comparators import ArrayEqualityComparator | ||
>>> from coola.testers import EqualityTester | ||
>>> config = EqualityConfig(tester=EqualityTester()) | ||
>>> comparator = ArrayEqualityComparator() | ||
>>> comparator.equal(np.ones((2, 3)), np.ones((2, 3)), config) | ||
True | ||
>>> comparator.equal(np.ones((2, 3)), np.zeros((2, 3)), config) | ||
False | ||
``` | ||
""" | ||
|
||
def __init__(self) -> None: | ||
check_numpy() | ||
self._handler = SameObjectHandler() | ||
self._handler.chain(SameTypeHandler()).chain(ArraySameDTypeHandler()).chain( | ||
ArraySameShapeHandler() | ||
).chain(ArrayEqualHandler()) | ||
|
||
def __eq__(self, other: object) -> bool: | ||
return isinstance(other, self.__class__) | ||
|
||
def clone(self) -> ArrayEqualityComparator: | ||
return self.__class__() | ||
|
||
def equal(self, object1: Any, object2: Any, config: EqualityConfig) -> bool: | ||
return self._handler.handle(object1=object1, object2=object2, config=config) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
from __future__ import annotations | ||
|
||
import logging | ||
from unittest.mock import Mock, patch | ||
|
||
import pytest | ||
|
||
from coola import objects_are_equal | ||
from coola.equality import EqualityConfig | ||
from coola.equality.comparators import ArrayEqualityComparator | ||
from coola.testers import EqualityTester | ||
from coola.testing import numpy_available | ||
from coola.utils.imports import is_numpy_available | ||
|
||
if is_numpy_available(): | ||
import numpy as np | ||
else: | ||
np = Mock() | ||
|
||
|
||
@pytest.fixture() | ||
def config() -> EqualityConfig: | ||
return EqualityConfig(tester=EqualityTester()) | ||
|
||
|
||
############################################# | ||
# Tests for ArrayEqualityComparator # | ||
############################################# | ||
|
||
|
||
@numpy_available | ||
def test_objects_are_equal_array() -> None: | ||
assert objects_are_equal(np.ones((2, 3)), np.ones((2, 3))) | ||
|
||
|
||
@numpy_available | ||
def test_array_equality_comparator_str() -> None: | ||
assert str(ArrayEqualityComparator()).startswith("ArrayEqualityComparator(") | ||
|
||
|
||
@numpy_available | ||
def test_array_equality_comparator__eq__true() -> None: | ||
assert ArrayEqualityComparator() == ArrayEqualityComparator() | ||
|
||
|
||
@numpy_available | ||
def test_array_equality_comparator__eq__false_different_type() -> None: | ||
assert ArrayEqualityComparator() != 123 | ||
|
||
|
||
@numpy_available | ||
def test_array_equality_comparator_clone() -> None: | ||
op = ArrayEqualityComparator() | ||
op_cloned = op.clone() | ||
assert op is not op_cloned | ||
assert op == op_cloned | ||
|
||
|
||
@numpy_available | ||
def test_array_equality_comparator_equal_true(config: EqualityConfig) -> None: | ||
assert ArrayEqualityComparator().equal(np.ones((2, 3)), np.ones((2, 3)), config) | ||
|
||
|
||
@numpy_available | ||
def test_array_equality_comparator_equal_true_same_object(config: EqualityConfig) -> None: | ||
array = np.ones((2, 3)) | ||
assert ArrayEqualityComparator().equal(array, array, config) | ||
|
||
|
||
@numpy_available | ||
def test_array_equality_comparator_equal_true_show_difference( | ||
caplog: pytest.LogCaptureFixture, config: EqualityConfig | ||
) -> None: | ||
config.show_difference = True | ||
comparator = ArrayEqualityComparator() | ||
with caplog.at_level(logging.INFO): | ||
assert comparator.equal( | ||
object1=np.ones((2, 3)), | ||
object2=np.ones((2, 3)), | ||
config=config, | ||
) | ||
assert not caplog.messages | ||
|
||
|
||
@numpy_available | ||
def test_array_equality_comparator_equal_false_different_dtype(config: EqualityConfig) -> None: | ||
assert not ArrayEqualityComparator().equal( | ||
np.ones(shape=(2, 3), dtype=float), np.ones(shape=(2, 3), dtype=int), config | ||
) | ||
|
||
|
||
@numpy_available | ||
def test_array_equality_comparator_equal_false_different_dtype_show_difference( | ||
caplog: pytest.LogCaptureFixture, config: EqualityConfig | ||
) -> None: | ||
config.show_difference = True | ||
comparator = ArrayEqualityComparator() | ||
with caplog.at_level(logging.INFO): | ||
assert not comparator.equal( | ||
object1=np.ones(shape=(2, 3), dtype=float), | ||
object2=np.ones(shape=(2, 3), dtype=int), | ||
config=config, | ||
) | ||
assert caplog.messages[0].startswith("objects have different data types:") | ||
|
||
|
||
@numpy_available | ||
def test_array_equality_comparator_equal_false_different_shape(config: EqualityConfig) -> None: | ||
assert not ArrayEqualityComparator().equal(np.ones((2, 3)), np.zeros((6,)), config) | ||
|
||
|
||
@numpy_available | ||
def test_array_equality_comparator_equal_false_different_shape_show_difference( | ||
caplog: pytest.LogCaptureFixture, config: EqualityConfig | ||
) -> None: | ||
config.show_difference = True | ||
comparator = ArrayEqualityComparator() | ||
with caplog.at_level(logging.INFO): | ||
assert not comparator.equal( | ||
object1=np.ones((2, 3)), | ||
object2=np.zeros((6,)), | ||
config=config, | ||
) | ||
assert caplog.messages[0].startswith("objects have different shapes:") | ||
|
||
|
||
@numpy_available | ||
def test_array_equality_comparator_equal_false_different_value(config: EqualityConfig) -> None: | ||
assert not ArrayEqualityComparator().equal(np.ones((2, 3)), np.zeros((2, 3)), config) | ||
|
||
|
||
@numpy_available | ||
def test_array_equality_comparator_equal_false_different_value_show_difference( | ||
caplog: pytest.LogCaptureFixture, config: EqualityConfig | ||
) -> None: | ||
config.show_difference = True | ||
comparator = ArrayEqualityComparator() | ||
with caplog.at_level(logging.INFO): | ||
assert not comparator.equal( | ||
object1=np.ones((2, 3)), | ||
object2=np.zeros((2, 3)), | ||
config=config, | ||
) | ||
assert caplog.messages[0].startswith("numpy.ndarrays have different elements:") | ||
|
||
|
||
@numpy_available | ||
def test_array_equality_comparator_equal_false_different_type(config: EqualityConfig) -> None: | ||
assert not ArrayEqualityComparator().equal(object1=np.ones((2, 3)), object2=42, config=config) | ||
|
||
|
||
@numpy_available | ||
def test_array_equality_comparator_equal_false_different_type_show_difference( | ||
caplog: pytest.LogCaptureFixture, config: EqualityConfig | ||
) -> None: | ||
config.show_difference = True | ||
comparator = ArrayEqualityComparator() | ||
with caplog.at_level(logging.INFO): | ||
assert not comparator.equal( | ||
object1=np.ones((2, 3)), | ||
object2=42, | ||
config=config, | ||
) | ||
assert caplog.messages[0].startswith("objects have different types:") | ||
|
||
|
||
@numpy_available | ||
def test_array_equality_comparator_no_numpy() -> None: | ||
with patch( | ||
"coola.utils.imports.is_numpy_available", lambda *args, **kwargs: False | ||
), pytest.raises(RuntimeError, match="`numpy` package is required but not installed."): | ||
ArrayEqualityComparator() |