Skip to content

Commit

Permalink
Add FloatEqualHandler (#416)
Browse files Browse the repository at this point in the history
  • Loading branch information
durandtibo authored Jan 10, 2024
1 parent a3de571 commit 62e9e5f
Show file tree
Hide file tree
Showing 7 changed files with 259 additions and 34 deletions.
2 changes: 2 additions & 0 deletions src/coola/equality/comparators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
__all__ = [
"BaseEqualityComparator",
"DefaultEqualityComparator",
"FloatEqualityComparator",
"JaxArrayEqualityComparator",
"MappingEqualityComparator",
"NumpyArrayEqualityComparator",
Expand Down Expand Up @@ -41,6 +42,7 @@
PolarsDataFrameEqualityComparator,
PolarsSeriesEqualityComparator,
)
from coola.equality.comparators.scalar import FloatEqualityComparator
from coola.equality.comparators.torch_ import (
TorchPackedSequenceEqualityComparator,
TorchTensorEqualityComparator,
Expand Down
71 changes: 71 additions & 0 deletions src/coola/equality/comparators/scalar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
r"""Implement scalar equality comparators."""

from __future__ import annotations

__all__ = ["FloatEqualityComparator", "get_type_comparator_mapping"]

import logging
from typing import TYPE_CHECKING, Any

from coola.equality.comparators.base import BaseEqualityComparator
from coola.equality.handlers import (
FloatEqualHandler,
SameObjectHandler,
SameTypeHandler,
)

if TYPE_CHECKING:
from coola.equality import EqualityConfig

logger = logging.getLogger(__name__)


class FloatEqualityComparator(BaseEqualityComparator[Any]):
r"""Implement a default equality comparator.
Example usage:
```pycon
>>> from coola.equality import EqualityConfig
>>> from coola.equality.comparators import FloatEqualityComparator
>>> from coola.equality.testers import EqualityTester
>>> config = EqualityConfig(tester=EqualityTester())
>>> comparator = FloatEqualityComparator()
>>> comparator.equal(42.0, 42.0, config)
True
>>> comparator.equal(42.0, 1.0, config)
False
```
"""

def __init__(self) -> None:
self._handler = SameObjectHandler()
self._handler.chain(SameTypeHandler()).chain(FloatEqualHandler())

def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__)

def clone(self) -> FloatEqualityComparator:
return self.__class__()

def equal(self, object1: Any, object2: Any, config: EqualityConfig) -> bool:
return self._handler.handle(object1=object1, object2=object2, config=config)


def get_type_comparator_mapping() -> dict[type, BaseEqualityComparator]:
r"""Get a mapping between the types and the equality comparators.
Returns:
The mapping between the types and the equality comparators.
Example usage:
```pycon
>>> from coola.equality.comparators.scalar import get_type_comparator_mapping
>>> get_type_comparator_mapping()
{<class 'float'>: FloatEqualityComparator()}
```
"""
return {float: FloatEqualityComparator()}
19 changes: 2 additions & 17 deletions src/coola/equality/comparators/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,7 @@ def get_type_comparator_mapping() -> dict[type, BaseEqualityComparator]:
>>> get_type_comparator_mapping()
{<class 'object'>: DefaultEqualityComparator(),
<class 'collections.abc.Mapping'>: MappingEqualityComparator(),
<class 'collections.abc.Sequence'>: SequenceEqualityComparator(),
<class 'dict'>: MappingEqualityComparator(),
<class 'list'>: SequenceEqualityComparator(),
<class 'tuple'>: SequenceEqualityComparator(),
<class 'jax.Array'>: JaxArrayEqualityComparator(),
<class 'jaxlib.xla_extension.ArrayImpl'>: JaxArrayEqualityComparator(),
<class 'numpy.ndarray'>: NumpyArrayEqualityComparator(),
<class 'numpy.ma...MaskedArray'>: NumpyMaskedArrayEqualityComparator(),
<class 'pandas...DataFrame'>: PandasDataFrameEqualityComparator(),
<class 'pandas...Series'>: PandasSeriesEqualityComparator(),
<class 'polars...DataFrame'>: PolarsDataFrameEqualityComparator(),
<class 'polars...Series'>: PolarsSeriesEqualityComparator(),
<class 'torch.nn.utils.rnn.PackedSequence'>: TorchPackedSequenceEqualityComparator(),
<class 'torch.Tensor'>: TorchTensorEqualityComparator(),
<class 'xarray...DataArray'>: XarrayDataArrayEqualityComparator(),
<class 'xarray...Dataset'>: XarrayDatasetEqualityComparator(),
<class 'xarray...Variable'>: XarrayVariableEqualityComparator()}
<class 'collections.abc.Sequence'>: SequenceEqualityComparator(), ...}
```
"""
Expand All @@ -52,6 +36,7 @@ def get_type_comparator_mapping() -> dict[type, BaseEqualityComparator]:
| comparators.numpy_.get_type_comparator_mapping()
| comparators.pandas_.get_type_comparator_mapping()
| comparators.polars_.get_type_comparator_mapping()
| comparators.scalar.get_type_comparator_mapping()
| comparators.torch_.get_type_comparator_mapping()
| comparators.xarray_.get_type_comparator_mapping()
)
17 changes: 1 addition & 16 deletions src/coola/equality/testers/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,22 +280,7 @@ def register_equality() -> None:
(<class 'object'>): DefaultEqualityComparator()
(<class 'collections.abc.Mapping'>): MappingEqualityComparator()
(<class 'collections.abc.Sequence'>): SequenceEqualityComparator()
(<class 'dict'>): MappingEqualityComparator()
(<class 'list'>): SequenceEqualityComparator()
(<class 'tuple'>): SequenceEqualityComparator()
(<class 'jax.Array'>): JaxArrayEqualityComparator()
(<class 'jaxlib.xla_extension.ArrayImpl'>): JaxArrayEqualityComparator()
(<class 'numpy.ndarray'>): NumpyArrayEqualityComparator()
(<class 'numpy.ma...MaskedArray'>): NumpyMaskedArrayEqualityComparator()
(<class 'pandas...DataFrame'>): PandasDataFrameEqualityComparator()
(<class 'pandas...Series'>): PandasSeriesEqualityComparator()
(<class 'polars...DataFrame'>): PolarsDataFrameEqualityComparator()
(<class 'polars...Series'>): PolarsSeriesEqualityComparator()
(<class 'torch.nn.utils.rnn.PackedSequence'>): TorchPackedSequenceEqualityComparator()
(<class 'torch.Tensor'>): TorchTensorEqualityComparator()
(<class 'xarray...DataArray'>): XarrayDataArrayEqualityComparator()
(<class 'xarray...Dataset'>): XarrayDatasetEqualityComparator()
(<class 'xarray...Variable'>): XarrayVariableEqualityComparator()
...
)
```
Expand Down
38 changes: 38 additions & 0 deletions tests/unit/equality/checks/test_scalar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING

import pytest

from coola import objects_are_equal
from tests.unit.equality.comparators.test_scalar import FLOAT_EQUAL, FLOAT_NOT_EQUAL

if TYPE_CHECKING:
from tests.unit.equality.comparators.utils import ExamplePair


@pytest.mark.parametrize("example", FLOAT_EQUAL)
@pytest.mark.parametrize("show_difference", [True, False])
def test_objects_are_equal_true(
example: ExamplePair, show_difference: bool, caplog: pytest.LogCaptureFixture
) -> None:
with caplog.at_level(logging.INFO):
assert objects_are_equal(example.object1, example.object2, show_difference)
assert not caplog.messages


@pytest.mark.parametrize("example", FLOAT_NOT_EQUAL)
def test_objects_are_equal_false(example: ExamplePair, caplog: pytest.LogCaptureFixture) -> None:
with caplog.at_level(logging.INFO):
assert not objects_are_equal(example.object1, example.object2)
assert not caplog.messages


@pytest.mark.parametrize("example", FLOAT_NOT_EQUAL)
def test_objects_are_equal_false_show_difference(
example: ExamplePair, caplog: pytest.LogCaptureFixture
) -> None:
with caplog.at_level(logging.INFO):
assert not objects_are_equal(example.object1, example.object2, show_difference=True)
assert caplog.messages[-1].startswith(example.expected_message)
142 changes: 142 additions & 0 deletions tests/unit/equality/comparators/test_scalar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
from __future__ import annotations

import logging

import pytest

from coola.equality import EqualityConfig
from coola.equality.comparators.scalar import (
FloatEqualityComparator,
get_type_comparator_mapping,
)
from coola.equality.testers import EqualityTester
from tests.unit.equality.comparators.utils import ExamplePair


@pytest.fixture()
def config() -> EqualityConfig:
return EqualityConfig(tester=EqualityTester())


FLOAT_EQUAL = [
pytest.param(ExamplePair(object1=4.2, object2=4.2), id="positive"),
pytest.param(ExamplePair(object1=0.0, object2=0.0), id="zero"),
pytest.param(ExamplePair(object1=-4.2, object2=-4.2), id="negative"),
pytest.param(ExamplePair(object1=float("inf"), object2=float("inf")), id="infinity"),
pytest.param(ExamplePair(object1=float("-inf"), object2=float("-inf")), id="-infinity"),
]


FLOAT_NOT_EQUAL = [
pytest.param(
ExamplePair(object1=4.2, object2=1.0, expected_message="numbers are not equal:"),
id="different values",
),
pytest.param(
ExamplePair(object1=4.2, object2="meow", expected_message="objects have different types:"),
id="different types",
),
]


#############################################
# Tests for FloatEqualityComparator #
#############################################


def test_float_equality_comparator_str() -> None:
assert str(FloatEqualityComparator()).startswith("FloatEqualityComparator(")


def test_float_equality_comparator__eq__true() -> None:
assert FloatEqualityComparator() == FloatEqualityComparator()


def test_float_equality_comparator__eq__false_different_type() -> None:
assert FloatEqualityComparator() != 123


def test_float_equality_comparator_clone() -> None:
op = FloatEqualityComparator()
op_cloned = op.clone()
assert op is not op_cloned
assert op == op_cloned


def test_float_equality_comparator_equal_true_same_object(config: EqualityConfig) -> None:
x = 4.2
assert FloatEqualityComparator().equal(x, x, config)


@pytest.mark.parametrize("example", FLOAT_EQUAL)
def test_float_equality_comparator_equal_yes(
example: ExamplePair,
config: EqualityConfig,
caplog: pytest.LogCaptureFixture,
) -> None:
comparator = FloatEqualityComparator()
with caplog.at_level(logging.INFO):
assert comparator.equal(object1=example.object1, object2=example.object2, config=config)
assert not caplog.messages


@pytest.mark.parametrize("example", FLOAT_EQUAL)
def test_float_equality_comparator_equal_yes_show_difference(
example: ExamplePair,
config: EqualityConfig,
caplog: pytest.LogCaptureFixture,
) -> None:
config.show_difference = True
comparator = FloatEqualityComparator()
with caplog.at_level(logging.INFO):
assert comparator.equal(object1=example.object1, object2=example.object2, config=config)
assert not caplog.messages


@pytest.mark.parametrize("example", FLOAT_NOT_EQUAL)
def test_float_equality_comparator_equal_false(
example: ExamplePair,
config: EqualityConfig,
caplog: pytest.LogCaptureFixture,
) -> None:
comparator = FloatEqualityComparator()
with caplog.at_level(logging.INFO):
assert not comparator.equal(object1=example.object1, object2=example.object2, config=config)
assert not caplog.messages


@pytest.mark.parametrize("example", FLOAT_NOT_EQUAL)
def test_float_equality_comparator_equal_false_show_difference(
example: ExamplePair,
config: EqualityConfig,
caplog: pytest.LogCaptureFixture,
) -> None:
config.show_difference = True
comparator = FloatEqualityComparator()
with caplog.at_level(logging.INFO):
assert not comparator.equal(object1=example.object1, object2=example.object2, config=config)
assert caplog.messages[-1].startswith(example.expected_message)


@pytest.mark.parametrize("equal_nan", [False, True])
def test_float_equality_comparator_equal_nan(config: EqualityConfig, equal_nan: bool) -> None:
config.equal_nan = equal_nan
assert (
FloatEqualityComparator().equal(
object1=float("nan"),
object2=float("nan"),
config=config,
)
== equal_nan
)


#################################################
# Tests for get_type_comparator_mapping #
#################################################


def test_get_type_comparator_mapping() -> None:
assert get_type_comparator_mapping() == {
float: FloatEqualityComparator(),
}
4 changes: 3 additions & 1 deletion tests/unit/equality/comparators/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from coola.equality.comparators import (
DefaultEqualityComparator,
FloatEqualityComparator,
JaxArrayEqualityComparator,
MappingEqualityComparator,
NumpyArrayEqualityComparator,
Expand Down Expand Up @@ -57,10 +58,11 @@

def test_get_type_comparator_mapping() -> None:
mapping = get_type_comparator_mapping()
assert len(mapping) >= 6
assert len(mapping) >= 7
assert isinstance(mapping[Mapping], MappingEqualityComparator)
assert isinstance(mapping[Sequence], SequenceEqualityComparator)
assert isinstance(mapping[dict], MappingEqualityComparator)
assert isinstance(mapping[float], FloatEqualityComparator)
assert isinstance(mapping[list], SequenceEqualityComparator)
assert isinstance(mapping[object], DefaultEqualityComparator)
assert isinstance(mapping[tuple], SequenceEqualityComparator)
Expand Down

0 comments on commit 62e9e5f

Please sign in to comment.