Skip to content

Commit

Permalink
Add SameAttributeHandler (#384)
Browse files Browse the repository at this point in the history
  • Loading branch information
durandtibo authored Jan 7, 2024
1 parent 3bb2654 commit dcee4ec
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/coola/equality/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"MappingSameValuesHandler",
"NumpyArrayEqualHandler",
"ObjectEqualHandler",
"SameAttributeHandler",
"SameDTypeHandler",
"SameLengthHandler",
"SameObjectHandler",
Expand All @@ -35,6 +36,7 @@
from coola.equality.handlers.native import (
FalseHandler,
ObjectEqualHandler,
SameAttributeHandler,
SameLengthHandler,
SameObjectHandler,
SameTypeHandler,
Expand Down
5 changes: 5 additions & 0 deletions src/coola/equality/handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ def next_handler(self) -> BaseEqualityHandler | None:
def _handle_next(self, object1: Any, object2: Any, config: EqualityConfig) -> bool:
r"""Return the output from the next handler.
Args:
object1: Specifies the first object to compare.
object2: Specifies the second object to compare.
config: Specifies the equality configuration.
Returns:
The output from the next handler.
Expand Down
56 changes: 56 additions & 0 deletions src/coola/equality/handlers/native.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
__all__ = [
"FalseHandler",
"ObjectEqualHandler",
"SameAttributeHandler",
"SameLengthHandler",
"SameObjectHandler",
"SameTypeHandler",
Expand All @@ -15,6 +16,7 @@
from typing import TYPE_CHECKING, Any

from coola.equality.handlers.base import AbstractEqualityHandler, BaseEqualityHandler
from coola.utils import repr_indent, repr_mapping

if TYPE_CHECKING:
from collections.abc import Sized
Expand Down Expand Up @@ -150,6 +152,60 @@ def set_next_handler(self, handler: BaseEqualityHandler) -> None:
pass # Do nothing because the next handler is never called.


class SameAttributeHandler(AbstractEqualityHandler):
r"""Check if the two objects have the same attribute.
This handler returns ``False`` if the two objects have different
attributes, otherwise it passes the inputs to the next handler.
The objects must have the attribute.
Example usage:
```pycon
>>> import numpy as np
>>> from coola.equality import EqualityConfig
>>> from coola.equality.handlers import SameAttributeHandler, TrueHandler
>>> from coola.testers import EqualityTester
>>> config = EqualityConfig(tester=EqualityTester())
>>> handler = SameAttributeHandler(name="shape", next_handler=TrueHandler())
>>> handler.handle(np.ones((2, 3)), np.ones((2, 3)), config)
True
>>> handler.handle(np.ones((2, 3)), np.ones((3, 2)), config)
False
```
"""

def __init__(self, name: str, next_handler: BaseEqualityHandler | None = None) -> None:
super().__init__(next_handler=next_handler)
self._name = name

def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return False
return self.name == other.name

def __repr__(self) -> str:
args = repr_indent(repr_mapping({"name": self._name, "next_handler": self._next_handler}))
return f"{self.__class__.__qualname__}(\n {args}\n)"

def __str__(self) -> str:
return f"{self.__class__.__qualname__}(name={self._name})"

@property
def name(self) -> str:
return self._name

def handle(self, object1: Any, object2: Any, config: EqualityConfig) -> bool:
value1 = getattr(object1, self._name)
value2 = getattr(object2, self._name)
if not config.tester.equal(value1, value2, config.show_difference):
if config.show_difference:
logger.info(f"objects have different {self._name}: {value1} vs {value2}")
return False
return self._handle_next(object1=object1, object2=object2, config=config)


class SameLengthHandler(AbstractEqualityHandler):
r"""Check if the two objects have the same length.
Expand Down
114 changes: 114 additions & 0 deletions tests/unit/equality/handlers/test_native.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
from typing import TYPE_CHECKING, Any
from unittest.mock import Mock

import pytest

Expand All @@ -10,11 +11,19 @@
from coola.equality.handlers import (
FalseHandler,
ObjectEqualHandler,
SameAttributeHandler,
SameLengthHandler,
SameObjectHandler,
SameTypeHandler,
TrueHandler,
)
from coola.testing import numpy_available
from coola.utils import is_numpy_available

if is_numpy_available():
import numpy as np
else:
np = Mock()

if TYPE_CHECKING:
from collections.abc import Sized
Expand Down Expand Up @@ -126,6 +135,111 @@ def test_object_equal_handler_set_next_handler() -> None:
ObjectEqualHandler().set_next_handler(FalseHandler())


##########################################
# Tests for SameAttributeHandler #
##########################################


def test_same_attribute_handler_eq_true() -> None:
assert SameAttributeHandler(name="name") == SameAttributeHandler(name="name")


def test_same_attribute_handler_eq_false_different_type() -> None:
assert SameAttributeHandler(name="name") != FalseHandler()


def test_same_attribute_handler_eq_false_different_name() -> None:
assert SameAttributeHandler(name="name1") != SameAttributeHandler(name="name2")


def test_same_attribute_handler_repr() -> None:
assert repr(SameAttributeHandler(name="name")).startswith("SameAttributeHandler(")


def test_same_attribute_handler_str() -> None:
assert str(SameAttributeHandler(name="name")).startswith("SameAttributeHandler(")


@pytest.mark.parametrize(
("object1", "object2"),
[
(Mock(data=1), Mock(data=1)),
(Mock(data="abc"), Mock(data="abc")),
(Mock(data=[1, 2, 3]), Mock(data=[1, 2, 3])),
],
)
def test_same_attribute_handler_handle_true(
object1: Any, object2: Any, config: EqualityConfig
) -> None:
assert SameAttributeHandler(name="data", next_handler=TrueHandler()).handle(
object1, object2, config
)


@pytest.mark.parametrize(
("object1", "object2"),
[
(Mock(data=1), Mock(data=2)),
(Mock(data="abc"), Mock(data="abcd")),
(Mock(data=[1, 2, 3]), Mock(data=[1, 2, 4])),
],
)
def test_same_attribute_handler_handle_false(
object1: Any, object2: Any, config: EqualityConfig
) -> None:
assert not SameAttributeHandler(name="data").handle(object1, object2, config)


@numpy_available
def test_same_attribute_handler_handle_false_show_difference(
config: EqualityConfig, caplog: pytest.LogCaptureFixture
) -> None:
config.show_difference = True
handler = SameAttributeHandler(name="data")
with caplog.at_level(logging.INFO):
assert not handler.handle(
object1=Mock(data=1),
object2=Mock(data=2),
config=config,
)
assert caplog.messages[-1].startswith("objects have different data:")


@numpy_available
def test_same_attribute_handler_handle_without_next_handler(config: EqualityConfig) -> None:
handler = SameAttributeHandler(name="data")
with pytest.raises(RuntimeError, match="next handler is not defined"):
handler.handle(
object1=Mock(spec=Any, data=1), object2=Mock(spec=Any, data=1), config=config
)


def test_same_attribute_handler_set_next_handler() -> None:
handler = SameAttributeHandler(name="data")
handler.set_next_handler(FalseHandler())
assert handler.next_handler == FalseHandler()


def test_same_attribute_handler_set_next_handler_incorrect() -> None:
handler = SameAttributeHandler(name="data")
with pytest.raises(TypeError, match="Incorrect type for `handler`."):
handler.set_next_handler(None)


@numpy_available
def test_same_attribute_handler_handle_true_numpy(config: EqualityConfig) -> None:
assert SameAttributeHandler(name="dtype", next_handler=TrueHandler()).handle(
np.ones(shape=(2, 3)), np.ones(shape=(2, 3)), config
)


@numpy_available
def test_same_attribute_handler_handle_false_numpy(config: EqualityConfig) -> None:
assert not SameAttributeHandler(name="dtype", next_handler=TrueHandler()).handle(
np.ones(shape=(2, 3), dtype=float), np.ones(shape=(2, 3), dtype=int), config
)


#######################################
# Tests for SameLengthHandler #
#######################################
Expand Down

0 comments on commit dcee4ec

Please sign in to comment.