diff --git a/src/coola/equality/comparators/__init__.py b/src/coola/equality/comparators/__init__.py index 93c9d73c..99783156 100644 --- a/src/coola/equality/comparators/__init__.py +++ b/src/coola/equality/comparators/__init__.py @@ -5,6 +5,7 @@ __all__ = [ "BaseEqualityComparator", "DefaultEqualityComparator", + "FloatEqualityComparator", "JaxArrayEqualityComparator", "MappingEqualityComparator", "NumpyArrayEqualityComparator", @@ -41,6 +42,7 @@ PolarsDataFrameEqualityComparator, PolarsSeriesEqualityComparator, ) +from coola.equality.comparators.scalar import FloatEqualityComparator from coola.equality.comparators.torch_ import ( TorchPackedSequenceEqualityComparator, TorchTensorEqualityComparator, diff --git a/src/coola/equality/comparators/scalar.py b/src/coola/equality/comparators/scalar.py new file mode 100644 index 00000000..ddb3010f --- /dev/null +++ b/src/coola/equality/comparators/scalar.py @@ -0,0 +1,71 @@ +r"""Implement scalar equality comparators.""" + +from __future__ import annotations + +__all__ = ["FloatEqualityComparator", "get_type_comparator_mapping"] + +import logging +from typing import TYPE_CHECKING, Any + +from coola.equality.comparators.base import BaseEqualityComparator +from coola.equality.handlers import ( + FloatEqualHandler, + SameObjectHandler, + SameTypeHandler, +) + +if TYPE_CHECKING: + from coola.equality import EqualityConfig + +logger = logging.getLogger(__name__) + + +class FloatEqualityComparator(BaseEqualityComparator[Any]): + r"""Implement a default equality comparator. + + Example usage: + + ```pycon + >>> from coola.equality import EqualityConfig + >>> from coola.equality.comparators import FloatEqualityComparator + >>> from coola.equality.testers import EqualityTester + >>> config = EqualityConfig(tester=EqualityTester()) + >>> comparator = FloatEqualityComparator() + >>> comparator.equal(42.0, 42.0, config) + True + >>> comparator.equal(42.0, 1.0, config) + False + + ``` + """ + + def __init__(self) -> None: + self._handler = SameObjectHandler() + self._handler.chain(SameTypeHandler()).chain(FloatEqualHandler()) + + def __eq__(self, other: object) -> bool: + return isinstance(other, self.__class__) + + def clone(self) -> FloatEqualityComparator: + return self.__class__() + + def equal(self, object1: Any, object2: Any, config: EqualityConfig) -> bool: + return self._handler.handle(object1=object1, object2=object2, config=config) + + +def get_type_comparator_mapping() -> dict[type, BaseEqualityComparator]: + r"""Get a mapping between the types and the equality comparators. + + Returns: + The mapping between the types and the equality comparators. + + Example usage: + + ```pycon + >>> from coola.equality.comparators.scalar import get_type_comparator_mapping + >>> get_type_comparator_mapping() + {: FloatEqualityComparator()} + + ``` + """ + return {float: FloatEqualityComparator()} diff --git a/src/coola/equality/comparators/utils.py b/src/coola/equality/comparators/utils.py index 07cecd5a..5ee56a24 100644 --- a/src/coola/equality/comparators/utils.py +++ b/src/coola/equality/comparators/utils.py @@ -23,23 +23,7 @@ def get_type_comparator_mapping() -> dict[type, BaseEqualityComparator]: >>> get_type_comparator_mapping() {: DefaultEqualityComparator(), : MappingEqualityComparator(), - : SequenceEqualityComparator(), - : MappingEqualityComparator(), - : SequenceEqualityComparator(), - : SequenceEqualityComparator(), - : JaxArrayEqualityComparator(), - : JaxArrayEqualityComparator(), - : NumpyArrayEqualityComparator(), - : NumpyMaskedArrayEqualityComparator(), - : PandasDataFrameEqualityComparator(), - : PandasSeriesEqualityComparator(), - : PolarsDataFrameEqualityComparator(), - : PolarsSeriesEqualityComparator(), - : TorchPackedSequenceEqualityComparator(), - : TorchTensorEqualityComparator(), - : XarrayDataArrayEqualityComparator(), - : XarrayDatasetEqualityComparator(), - : XarrayVariableEqualityComparator()} + : SequenceEqualityComparator(), ...} ``` """ @@ -52,6 +36,7 @@ def get_type_comparator_mapping() -> dict[type, BaseEqualityComparator]: | comparators.numpy_.get_type_comparator_mapping() | comparators.pandas_.get_type_comparator_mapping() | comparators.polars_.get_type_comparator_mapping() + | comparators.scalar.get_type_comparator_mapping() | comparators.torch_.get_type_comparator_mapping() | comparators.xarray_.get_type_comparator_mapping() ) diff --git a/src/coola/equality/testers/default.py b/src/coola/equality/testers/default.py index 98aede25..1d3a93d4 100644 --- a/src/coola/equality/testers/default.py +++ b/src/coola/equality/testers/default.py @@ -280,22 +280,7 @@ def register_equality() -> None: (): DefaultEqualityComparator() (): MappingEqualityComparator() (): SequenceEqualityComparator() - (): MappingEqualityComparator() - (): SequenceEqualityComparator() - (): SequenceEqualityComparator() - (): JaxArrayEqualityComparator() - (): JaxArrayEqualityComparator() - (): NumpyArrayEqualityComparator() - (): NumpyMaskedArrayEqualityComparator() - (): PandasDataFrameEqualityComparator() - (): PandasSeriesEqualityComparator() - (): PolarsDataFrameEqualityComparator() - (): PolarsSeriesEqualityComparator() - (): TorchPackedSequenceEqualityComparator() - (): TorchTensorEqualityComparator() - (): XarrayDataArrayEqualityComparator() - (): XarrayDatasetEqualityComparator() - (): XarrayVariableEqualityComparator() + ... ) ``` diff --git a/tests/unit/equality/checks/test_scalar.py b/tests/unit/equality/checks/test_scalar.py new file mode 100644 index 00000000..a38f2afe --- /dev/null +++ b/tests/unit/equality/checks/test_scalar.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import pytest + +from coola import objects_are_equal +from tests.unit.equality.comparators.test_scalar import FLOAT_EQUAL, FLOAT_NOT_EQUAL + +if TYPE_CHECKING: + from tests.unit.equality.comparators.utils import ExamplePair + + +@pytest.mark.parametrize("example", FLOAT_EQUAL) +@pytest.mark.parametrize("show_difference", [True, False]) +def test_objects_are_equal_true( + example: ExamplePair, show_difference: bool, caplog: pytest.LogCaptureFixture +) -> None: + with caplog.at_level(logging.INFO): + assert objects_are_equal(example.object1, example.object2, show_difference) + assert not caplog.messages + + +@pytest.mark.parametrize("example", FLOAT_NOT_EQUAL) +def test_objects_are_equal_false(example: ExamplePair, caplog: pytest.LogCaptureFixture) -> None: + with caplog.at_level(logging.INFO): + assert not objects_are_equal(example.object1, example.object2) + assert not caplog.messages + + +@pytest.mark.parametrize("example", FLOAT_NOT_EQUAL) +def test_objects_are_equal_false_show_difference( + example: ExamplePair, caplog: pytest.LogCaptureFixture +) -> None: + with caplog.at_level(logging.INFO): + assert not objects_are_equal(example.object1, example.object2, show_difference=True) + assert caplog.messages[-1].startswith(example.expected_message) diff --git a/tests/unit/equality/comparators/test_scalar.py b/tests/unit/equality/comparators/test_scalar.py new file mode 100644 index 00000000..83fc7dc9 --- /dev/null +++ b/tests/unit/equality/comparators/test_scalar.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +import logging + +import pytest + +from coola.equality import EqualityConfig +from coola.equality.comparators.scalar import ( + FloatEqualityComparator, + get_type_comparator_mapping, +) +from coola.equality.testers import EqualityTester +from tests.unit.equality.comparators.utils import ExamplePair + + +@pytest.fixture() +def config() -> EqualityConfig: + return EqualityConfig(tester=EqualityTester()) + + +FLOAT_EQUAL = [ + pytest.param(ExamplePair(object1=4.2, object2=4.2), id="positive"), + pytest.param(ExamplePair(object1=0.0, object2=0.0), id="zero"), + pytest.param(ExamplePair(object1=-4.2, object2=-4.2), id="negative"), + pytest.param(ExamplePair(object1=float("inf"), object2=float("inf")), id="infinity"), + pytest.param(ExamplePair(object1=float("-inf"), object2=float("-inf")), id="-infinity"), +] + + +FLOAT_NOT_EQUAL = [ + pytest.param( + ExamplePair(object1=4.2, object2=1.0, expected_message="numbers are not equal:"), + id="different values", + ), + pytest.param( + ExamplePair(object1=4.2, object2="meow", expected_message="objects have different types:"), + id="different types", + ), +] + + +############################################# +# Tests for FloatEqualityComparator # +############################################# + + +def test_float_equality_comparator_str() -> None: + assert str(FloatEqualityComparator()).startswith("FloatEqualityComparator(") + + +def test_float_equality_comparator__eq__true() -> None: + assert FloatEqualityComparator() == FloatEqualityComparator() + + +def test_float_equality_comparator__eq__false_different_type() -> None: + assert FloatEqualityComparator() != 123 + + +def test_float_equality_comparator_clone() -> None: + op = FloatEqualityComparator() + op_cloned = op.clone() + assert op is not op_cloned + assert op == op_cloned + + +def test_float_equality_comparator_equal_true_same_object(config: EqualityConfig) -> None: + x = 4.2 + assert FloatEqualityComparator().equal(x, x, config) + + +@pytest.mark.parametrize("example", FLOAT_EQUAL) +def test_float_equality_comparator_equal_yes( + example: ExamplePair, + config: EqualityConfig, + caplog: pytest.LogCaptureFixture, +) -> None: + comparator = FloatEqualityComparator() + with caplog.at_level(logging.INFO): + assert comparator.equal(object1=example.object1, object2=example.object2, config=config) + assert not caplog.messages + + +@pytest.mark.parametrize("example", FLOAT_EQUAL) +def test_float_equality_comparator_equal_yes_show_difference( + example: ExamplePair, + config: EqualityConfig, + caplog: pytest.LogCaptureFixture, +) -> None: + config.show_difference = True + comparator = FloatEqualityComparator() + with caplog.at_level(logging.INFO): + assert comparator.equal(object1=example.object1, object2=example.object2, config=config) + assert not caplog.messages + + +@pytest.mark.parametrize("example", FLOAT_NOT_EQUAL) +def test_float_equality_comparator_equal_false( + example: ExamplePair, + config: EqualityConfig, + caplog: pytest.LogCaptureFixture, +) -> None: + comparator = FloatEqualityComparator() + with caplog.at_level(logging.INFO): + assert not comparator.equal(object1=example.object1, object2=example.object2, config=config) + assert not caplog.messages + + +@pytest.mark.parametrize("example", FLOAT_NOT_EQUAL) +def test_float_equality_comparator_equal_false_show_difference( + example: ExamplePair, + config: EqualityConfig, + caplog: pytest.LogCaptureFixture, +) -> None: + config.show_difference = True + comparator = FloatEqualityComparator() + with caplog.at_level(logging.INFO): + assert not comparator.equal(object1=example.object1, object2=example.object2, config=config) + assert caplog.messages[-1].startswith(example.expected_message) + + +@pytest.mark.parametrize("equal_nan", [False, True]) +def test_float_equality_comparator_equal_nan(config: EqualityConfig, equal_nan: bool) -> None: + config.equal_nan = equal_nan + assert ( + FloatEqualityComparator().equal( + object1=float("nan"), + object2=float("nan"), + config=config, + ) + == equal_nan + ) + + +################################################# +# Tests for get_type_comparator_mapping # +################################################# + + +def test_get_type_comparator_mapping() -> None: + assert get_type_comparator_mapping() == { + float: FloatEqualityComparator(), + } diff --git a/tests/unit/equality/comparators/test_utils.py b/tests/unit/equality/comparators/test_utils.py index edf32148..02aaad81 100644 --- a/tests/unit/equality/comparators/test_utils.py +++ b/tests/unit/equality/comparators/test_utils.py @@ -4,6 +4,7 @@ from coola.equality.comparators import ( DefaultEqualityComparator, + FloatEqualityComparator, JaxArrayEqualityComparator, MappingEqualityComparator, NumpyArrayEqualityComparator, @@ -57,10 +58,11 @@ def test_get_type_comparator_mapping() -> None: mapping = get_type_comparator_mapping() - assert len(mapping) >= 6 + assert len(mapping) >= 7 assert isinstance(mapping[Mapping], MappingEqualityComparator) assert isinstance(mapping[Sequence], SequenceEqualityComparator) assert isinstance(mapping[dict], MappingEqualityComparator) + assert isinstance(mapping[float], FloatEqualityComparator) assert isinstance(mapping[list], SequenceEqualityComparator) assert isinstance(mapping[object], DefaultEqualityComparator) assert isinstance(mapping[tuple], SequenceEqualityComparator)