Skip to content

Commit

Permalink
Add NanEqualHandler (#423)
Browse files Browse the repository at this point in the history
  • Loading branch information
durandtibo authored Jan 11, 2024
1 parent 4d24a32 commit 0b16bfd
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 5 deletions.
5 changes: 3 additions & 2 deletions src/coola/equality/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
"BaseEqualityHandler",
"EqualHandler",
"FalseHandler",
"ScalarEqualHandler",
"JaxArrayEqualHandler",
"MappingSameKeysHandler",
"MappingSameValuesHandler",
"NanEqualHandler",
"NumpyArrayEqualHandler",
"ObjectEqualHandler",
"PandasDataFrameEqualHandler",
Expand All @@ -28,6 +28,7 @@
"SameObjectHandler",
"SameShapeHandler",
"SameTypeHandler",
"ScalarEqualHandler",
"SequenceSameValuesHandler",
"TorchTensorEqualHandler",
"TrueHandler",
Expand Down Expand Up @@ -60,7 +61,7 @@
PolarsDataFrameEqualHandler,
PolarsSeriesEqualHandler,
)
from coola.equality.handlers.scalar import ScalarEqualHandler
from coola.equality.handlers.scalar import NanEqualHandler, ScalarEqualHandler
from coola.equality.handlers.sequence import SequenceSameValuesHandler
from coola.equality.handlers.shape import SameShapeHandler
from coola.equality.handlers.torch_ import TorchTensorEqualHandler
41 changes: 39 additions & 2 deletions src/coola/equality/handlers/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

from __future__ import annotations

__all__ = ["ScalarEqualHandler"]
__all__ = ["NanEqualHandler", "ScalarEqualHandler"]

import logging
import math
from typing import TYPE_CHECKING

from coola.equality.handlers.base import BaseEqualityHandler
from coola.equality.handlers.base import AbstractEqualityHandler, BaseEqualityHandler

if TYPE_CHECKING:
from coola.equality.config import EqualityConfig
Expand All @@ -17,6 +17,43 @@
logger = logging.getLogger(__name__)


class NanEqualHandler(AbstractEqualityHandler):
r"""Check if the two NaNs are equal.
This handler returns ``True`` if the two numbers are NaNs,
otherwise it passes the inputs to the next handler.
Example usage:
```pycon
>>> from coola.equality import EqualityConfig
>>> from coola.equality.handlers import NanEqualHandler
>>> from coola.equality.testers import EqualityTester
>>> config = EqualityConfig(tester=EqualityTester())
>>> handler = NanEqualHandler()
>>> handler.handle(float("nan"), float("nan"), config)
False
>>> config.equal_nan = True
>>> handler.handle(float("nan"), float("nan"), config)
True
```
"""

def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__)

def handle(
self,
object1: float,
object2: float,
config: EqualityConfig,
) -> bool:
if config.equal_nan and math.isnan(object1) and math.isnan(object2):
return True
return self._handle_next(object1=object1, object2=object2, config=config)


class ScalarEqualHandler(BaseEqualityHandler):
r"""Check if the two numbers are equal or not.
Expand Down
60 changes: 59 additions & 1 deletion tests/unit/equality/handlers/test_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

from coola.equality import EqualityConfig
from coola.equality.handlers import FalseHandler, ScalarEqualHandler
from coola.equality.handlers import FalseHandler, NanEqualHandler, ScalarEqualHandler
from coola.equality.testers import EqualityTester
from tests.unit.equality.comparators.utils import ExamplePair

Expand All @@ -15,6 +15,64 @@ def config() -> EqualityConfig:
return EqualityConfig(tester=EqualityTester())


#####################################
# Tests for NanEqualHandler #
#####################################


def test_nan_equal_handler_eq_true() -> None:
assert NanEqualHandler() == NanEqualHandler()


def test_nan_equal_handler_eq_false() -> None:
assert NanEqualHandler() != FalseHandler()


def test_nan_equal_handler_repr() -> None:
assert repr(NanEqualHandler()).startswith("NanEqualHandler(")


def test_nan_equal_handler_str() -> None:
assert str(NanEqualHandler()) == "NanEqualHandler()"


def test_nan_equal_handler_handle_true(config: EqualityConfig) -> None:
config.equal_nan = True
assert NanEqualHandler().handle(float("nan"), float("nan"), config)


@pytest.mark.parametrize(
("object1", "object2"),
[
(float("nan"), float("nan")),
(4.2, 4.2),
(1, 0),
],
)
def test_nan_equal_handler_handle_false(
object1: float, object2: float, config: EqualityConfig
) -> None:
assert not NanEqualHandler(next_handler=FalseHandler()).handle(object1, object2, config)


def test_nan_equal_handler_handle_without_next_handler(config: EqualityConfig) -> None:
handler = NanEqualHandler()
with pytest.raises(RuntimeError, match="next handler is not defined"):
handler.handle(object1=42, object2=42, config=config)


def test_nan_equal_handler_set_next_handler() -> None:
handler = NanEqualHandler()
handler.set_next_handler(FalseHandler())
assert handler.next_handler == FalseHandler()


def test_nan_equal_handler_set_next_handler_incorrect() -> None:
handler = NanEqualHandler()
with pytest.raises(TypeError, match="Incorrect type for `handler`."):
handler.set_next_handler(None)


########################################
# Tests for ScalarEqualHandler #
########################################
Expand Down

0 comments on commit 0b16bfd

Please sign in to comment.