-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
a3de571
commit 62e9e5f
Showing
7 changed files
with
259 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
r"""Implement scalar equality comparators.""" | ||
|
||
from __future__ import annotations | ||
|
||
__all__ = ["FloatEqualityComparator", "get_type_comparator_mapping"] | ||
|
||
import logging | ||
from typing import TYPE_CHECKING, Any | ||
|
||
from coola.equality.comparators.base import BaseEqualityComparator | ||
from coola.equality.handlers import ( | ||
FloatEqualHandler, | ||
SameObjectHandler, | ||
SameTypeHandler, | ||
) | ||
|
||
if TYPE_CHECKING: | ||
from coola.equality import EqualityConfig | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class FloatEqualityComparator(BaseEqualityComparator[Any]): | ||
r"""Implement a default equality comparator. | ||
Example usage: | ||
```pycon | ||
>>> from coola.equality import EqualityConfig | ||
>>> from coola.equality.comparators import FloatEqualityComparator | ||
>>> from coola.equality.testers import EqualityTester | ||
>>> config = EqualityConfig(tester=EqualityTester()) | ||
>>> comparator = FloatEqualityComparator() | ||
>>> comparator.equal(42.0, 42.0, config) | ||
True | ||
>>> comparator.equal(42.0, 1.0, config) | ||
False | ||
``` | ||
""" | ||
|
||
def __init__(self) -> None: | ||
self._handler = SameObjectHandler() | ||
self._handler.chain(SameTypeHandler()).chain(FloatEqualHandler()) | ||
|
||
def __eq__(self, other: object) -> bool: | ||
return isinstance(other, self.__class__) | ||
|
||
def clone(self) -> FloatEqualityComparator: | ||
return self.__class__() | ||
|
||
def equal(self, object1: Any, object2: Any, config: EqualityConfig) -> bool: | ||
return self._handler.handle(object1=object1, object2=object2, config=config) | ||
|
||
|
||
def get_type_comparator_mapping() -> dict[type, BaseEqualityComparator]: | ||
r"""Get a mapping between the types and the equality comparators. | ||
Returns: | ||
The mapping between the types and the equality comparators. | ||
Example usage: | ||
```pycon | ||
>>> from coola.equality.comparators.scalar import get_type_comparator_mapping | ||
>>> get_type_comparator_mapping() | ||
{<class 'float'>: FloatEqualityComparator()} | ||
``` | ||
""" | ||
return {float: FloatEqualityComparator()} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from __future__ import annotations | ||
|
||
import logging | ||
from typing import TYPE_CHECKING | ||
|
||
import pytest | ||
|
||
from coola import objects_are_equal | ||
from tests.unit.equality.comparators.test_scalar import FLOAT_EQUAL, FLOAT_NOT_EQUAL | ||
|
||
if TYPE_CHECKING: | ||
from tests.unit.equality.comparators.utils import ExamplePair | ||
|
||
|
||
@pytest.mark.parametrize("example", FLOAT_EQUAL) | ||
@pytest.mark.parametrize("show_difference", [True, False]) | ||
def test_objects_are_equal_true( | ||
example: ExamplePair, show_difference: bool, caplog: pytest.LogCaptureFixture | ||
) -> None: | ||
with caplog.at_level(logging.INFO): | ||
assert objects_are_equal(example.object1, example.object2, show_difference) | ||
assert not caplog.messages | ||
|
||
|
||
@pytest.mark.parametrize("example", FLOAT_NOT_EQUAL) | ||
def test_objects_are_equal_false(example: ExamplePair, caplog: pytest.LogCaptureFixture) -> None: | ||
with caplog.at_level(logging.INFO): | ||
assert not objects_are_equal(example.object1, example.object2) | ||
assert not caplog.messages | ||
|
||
|
||
@pytest.mark.parametrize("example", FLOAT_NOT_EQUAL) | ||
def test_objects_are_equal_false_show_difference( | ||
example: ExamplePair, caplog: pytest.LogCaptureFixture | ||
) -> None: | ||
with caplog.at_level(logging.INFO): | ||
assert not objects_are_equal(example.object1, example.object2, show_difference=True) | ||
assert caplog.messages[-1].startswith(example.expected_message) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
from __future__ import annotations | ||
|
||
import logging | ||
|
||
import pytest | ||
|
||
from coola.equality import EqualityConfig | ||
from coola.equality.comparators.scalar import ( | ||
FloatEqualityComparator, | ||
get_type_comparator_mapping, | ||
) | ||
from coola.equality.testers import EqualityTester | ||
from tests.unit.equality.comparators.utils import ExamplePair | ||
|
||
|
||
@pytest.fixture() | ||
def config() -> EqualityConfig: | ||
return EqualityConfig(tester=EqualityTester()) | ||
|
||
|
||
FLOAT_EQUAL = [ | ||
pytest.param(ExamplePair(object1=4.2, object2=4.2), id="positive"), | ||
pytest.param(ExamplePair(object1=0.0, object2=0.0), id="zero"), | ||
pytest.param(ExamplePair(object1=-4.2, object2=-4.2), id="negative"), | ||
pytest.param(ExamplePair(object1=float("inf"), object2=float("inf")), id="infinity"), | ||
pytest.param(ExamplePair(object1=float("-inf"), object2=float("-inf")), id="-infinity"), | ||
] | ||
|
||
|
||
FLOAT_NOT_EQUAL = [ | ||
pytest.param( | ||
ExamplePair(object1=4.2, object2=1.0, expected_message="numbers are not equal:"), | ||
id="different values", | ||
), | ||
pytest.param( | ||
ExamplePair(object1=4.2, object2="meow", expected_message="objects have different types:"), | ||
id="different types", | ||
), | ||
] | ||
|
||
|
||
############################################# | ||
# Tests for FloatEqualityComparator # | ||
############################################# | ||
|
||
|
||
def test_float_equality_comparator_str() -> None: | ||
assert str(FloatEqualityComparator()).startswith("FloatEqualityComparator(") | ||
|
||
|
||
def test_float_equality_comparator__eq__true() -> None: | ||
assert FloatEqualityComparator() == FloatEqualityComparator() | ||
|
||
|
||
def test_float_equality_comparator__eq__false_different_type() -> None: | ||
assert FloatEqualityComparator() != 123 | ||
|
||
|
||
def test_float_equality_comparator_clone() -> None: | ||
op = FloatEqualityComparator() | ||
op_cloned = op.clone() | ||
assert op is not op_cloned | ||
assert op == op_cloned | ||
|
||
|
||
def test_float_equality_comparator_equal_true_same_object(config: EqualityConfig) -> None: | ||
x = 4.2 | ||
assert FloatEqualityComparator().equal(x, x, config) | ||
|
||
|
||
@pytest.mark.parametrize("example", FLOAT_EQUAL) | ||
def test_float_equality_comparator_equal_yes( | ||
example: ExamplePair, | ||
config: EqualityConfig, | ||
caplog: pytest.LogCaptureFixture, | ||
) -> None: | ||
comparator = FloatEqualityComparator() | ||
with caplog.at_level(logging.INFO): | ||
assert comparator.equal(object1=example.object1, object2=example.object2, config=config) | ||
assert not caplog.messages | ||
|
||
|
||
@pytest.mark.parametrize("example", FLOAT_EQUAL) | ||
def test_float_equality_comparator_equal_yes_show_difference( | ||
example: ExamplePair, | ||
config: EqualityConfig, | ||
caplog: pytest.LogCaptureFixture, | ||
) -> None: | ||
config.show_difference = True | ||
comparator = FloatEqualityComparator() | ||
with caplog.at_level(logging.INFO): | ||
assert comparator.equal(object1=example.object1, object2=example.object2, config=config) | ||
assert not caplog.messages | ||
|
||
|
||
@pytest.mark.parametrize("example", FLOAT_NOT_EQUAL) | ||
def test_float_equality_comparator_equal_false( | ||
example: ExamplePair, | ||
config: EqualityConfig, | ||
caplog: pytest.LogCaptureFixture, | ||
) -> None: | ||
comparator = FloatEqualityComparator() | ||
with caplog.at_level(logging.INFO): | ||
assert not comparator.equal(object1=example.object1, object2=example.object2, config=config) | ||
assert not caplog.messages | ||
|
||
|
||
@pytest.mark.parametrize("example", FLOAT_NOT_EQUAL) | ||
def test_float_equality_comparator_equal_false_show_difference( | ||
example: ExamplePair, | ||
config: EqualityConfig, | ||
caplog: pytest.LogCaptureFixture, | ||
) -> None: | ||
config.show_difference = True | ||
comparator = FloatEqualityComparator() | ||
with caplog.at_level(logging.INFO): | ||
assert not comparator.equal(object1=example.object1, object2=example.object2, config=config) | ||
assert caplog.messages[-1].startswith(example.expected_message) | ||
|
||
|
||
@pytest.mark.parametrize("equal_nan", [False, True]) | ||
def test_float_equality_comparator_equal_nan(config: EqualityConfig, equal_nan: bool) -> None: | ||
config.equal_nan = equal_nan | ||
assert ( | ||
FloatEqualityComparator().equal( | ||
object1=float("nan"), | ||
object2=float("nan"), | ||
config=config, | ||
) | ||
== equal_nan | ||
) | ||
|
||
|
||
################################################# | ||
# Tests for get_type_comparator_mapping # | ||
################################################# | ||
|
||
|
||
def test_get_type_comparator_mapping() -> None: | ||
assert get_type_comparator_mapping() == { | ||
float: FloatEqualityComparator(), | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters