Skip to content

Commit

Permalink
Add PandasDataFrameEqualHandler (#396)
Browse files Browse the repository at this point in the history
  • Loading branch information
durandtibo authored Jan 8, 2024
1 parent 7c77d2e commit 87e70fa
Show file tree
Hide file tree
Showing 3 changed files with 280 additions and 15 deletions.
6 changes: 5 additions & 1 deletion src/coola/equality/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"MappingSameValuesHandler",
"NumpyArrayEqualHandler",
"ObjectEqualHandler",
"PandasDataFrameEqualHandler",
"PandasSeriesEqualHandler",
"SameAttributeHandler",
"SameDTypeHandler",
Expand Down Expand Up @@ -46,7 +47,10 @@
TrueHandler,
)
from coola.equality.handlers.numpy_ import NumpyArrayEqualHandler
from coola.equality.handlers.pandas_ import PandasSeriesEqualHandler
from coola.equality.handlers.pandas_ import (
PandasDataFrameEqualHandler,
PandasSeriesEqualHandler,
)
from coola.equality.handlers.sequence import SequenceSameValuesHandler
from coola.equality.handlers.shape import SameShapeHandler
from coola.equality.handlers.torch_ import TorchTensorEqualHandler
76 changes: 75 additions & 1 deletion src/coola/equality/handlers/pandas_.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from __future__ import annotations

__all__ = ["PandasSeriesEqualHandler"]
__all__ = ["PandasDataFrameEqualHandler", "PandasSeriesEqualHandler"]

import logging
from typing import TYPE_CHECKING
Expand All @@ -23,6 +23,80 @@
logger = logging.getLogger(__name__)


class PandasDataFrameEqualHandler(BaseEqualityHandler):
    r"""Compare two ``pandas.DataFrame`` objects for equality.

    The handler returns ``True`` when the two ``pandas.DataFrame``s
    are equal and ``False`` otherwise. It is meant to be the last
    link of the chain of responsibility: it never calls a next
    handler.

    Example usage:

    ```pycon
    >>> import pandas
    >>> from coola.equality import EqualityConfig
    >>> from coola.equality.handlers import PandasDataFrameEqualHandler
    >>> from coola.testers import EqualityTester
    >>> config = EqualityConfig(tester=EqualityTester())
    >>> handler = PandasDataFrameEqualHandler()
    >>> handler.handle(
    ...     pandas.DataFrame([1, 2, 3, 4, 5]), pandas.DataFrame([1, 2, 3, 4, 5]), config
    ... )
    True
    >>> handler.handle(
    ...     pandas.DataFrame([1, 2, 3, 4, 5]), pandas.DataFrame([1, 2, 3, 4, 0]), config
    ... )
    False

    ```
    """

    def __eq__(self, other: object) -> bool:
        # The handler holds no state, so sharing the class is enough to be equal.
        own_class = self.__class__
        return isinstance(other, own_class)

    def __repr__(self) -> str:
        name = self.__class__.__qualname__
        return f"{name}()"

    def handle(
        self,
        object1: pandas.DataFrame,
        object2: pandas.DataFrame,
        config: EqualityConfig,
    ) -> bool:
        r"""Return whether the two DataFrames are equal.

        Args:
            object1: Specifies the first DataFrame to compare.
            object2: Specifies the second DataFrame to compare.
            config: Specifies the equality configuration. When
                ``config.show_difference`` is true and the
                DataFrames differ, both are logged at INFO level.

        Returns:
            ``True`` if the two DataFrames are equal, otherwise
                ``False``.
        """
        if self._compare_dataframes(object1, object2, config):
            return True
        if config.show_difference:
            logger.info(
                f"pandas.DataFrames have different elements:\n"
                f"object1:\n{object1}\nobject2:\n{object2}"
            )
        return False

    def set_next_handler(self, handler: BaseEqualityHandler) -> None:
        # Intentionally a no-op: ``handle`` never delegates to a next handler.
        pass

    def _compare_dataframes(
        self, df1: pandas.DataFrame, df2: pandas.DataFrame, config: EqualityConfig
    ) -> bool:
        r"""Indicate if the two DataFrames are equal or not.

        Args:
            df1: Specifies the first DataFrame to compare.
            df2: Specifies the second DataFrame to compare.
            config: Specifies the equality configuration. When
                ``config.equal_nan`` is false, a missing value
                anywhere in ``df1`` makes the comparison fail.

        Returns:
            ``True`` if the two DataFrames are equal, otherwise
                ``False``.
        """
        nan_is_mismatch = not config.equal_nan
        # Short-circuit before the pandas comparison: with ``equal_nan``
        # disabled, a DataFrame holding a missing value is never reported equal.
        if nan_is_mismatch and df1.isna().any().any():
            return False
        try:
            pandas.testing.assert_frame_equal(df1, df2, check_exact=True)
        except AssertionError:
            # pandas reports a mismatch by raising; convert it to a boolean.
            return False
        else:
            return True


class PandasSeriesEqualHandler(BaseEqualityHandler):
r"""Check if the two ``pandas.Series`` are equal.
Expand Down
213 changes: 200 additions & 13 deletions tests/unit/equality/handlers/test_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@

from coola import EqualityTester
from coola.equality import EqualityConfig
from coola.equality.handlers import FalseHandler, PandasSeriesEqualHandler
from coola.equality.handlers import (
FalseHandler,
PandasDataFrameEqualHandler,
PandasSeriesEqualHandler,
)
from coola.testing import pandas_available
from coola.utils import is_pandas_available

Expand All @@ -22,6 +26,169 @@ def config() -> EqualityConfig:
return EqualityConfig(tester=EqualityTester())


#################################################
# Tests for PandasDataFrameEqualHandler #
#################################################


def test_pandas_dataframe_equal_handler_eq_true() -> None:
    # Two independently created handlers must compare equal.
    first = PandasDataFrameEqualHandler()
    second = PandasDataFrameEqualHandler()
    assert first == second


def test_pandas_dataframe_equal_handler_eq_false() -> None:
    # A handler of another class must not compare equal.
    handler = PandasDataFrameEqualHandler()
    other = FalseHandler()
    assert handler != other


def test_pandas_dataframe_equal_handler_repr() -> None:
    # Only the prefix is checked so the test does not pin the full format.
    text = repr(PandasDataFrameEqualHandler())
    assert text.startswith("PandasDataFrameEqualHandler(")


def test_pandas_dataframe_equal_handler_str() -> None:
    # Only the prefix is checked so the test does not pin the full format.
    text = str(PandasDataFrameEqualHandler())
    assert text.startswith("PandasDataFrameEqualHandler(")


@pandas_available
@pytest.mark.parametrize(
    ("object1", "object2"),
    [
        # Empty DataFrames.
        (pandas.DataFrame({}), pandas.DataFrame({})),
        # Single integer column.
        (pandas.DataFrame({"col": [1, 2, 3]}), pandas.DataFrame({"col": [1, 2, 3]})),
        # Several columns with integer, float and string values.
        (
            pandas.DataFrame(
                {
                    "col1": [1, 2, 3, 4, 5],
                    "col2": [1.1, 2.2, 3.3, 4.4, 5.5],
                    "col3": ["a", "b", "c", "d", "e"],
                }
            ),
            pandas.DataFrame(
                {
                    "col1": [1, 2, 3, 4, 5],
                    "col2": [1.1, 2.2, 3.3, 4.4, 5.5],
                    "col3": ["a", "b", "c", "d", "e"],
                }
            ),
        ),
    ],
)
def test_pandas_dataframe_equal_handler_handle_true(
    object1: pandas.DataFrame,
    object2: pandas.DataFrame,
    config: EqualityConfig,
    caplog: pytest.LogCaptureFixture,
) -> None:
    # Equal DataFrames must be reported equal without logging anything.
    with caplog.at_level(logging.INFO):
        assert PandasDataFrameEqualHandler().handle(object1, object2, config)
        assert not caplog.messages


@pandas_available
def test_pandas_dataframe_equal_handler_handle_true_show_difference(
    config: EqualityConfig,
    caplog: pytest.LogCaptureFixture,
) -> None:
    # Even with ``show_difference`` enabled, equal DataFrames log nothing.
    config.show_difference = True
    handler = PandasDataFrameEqualHandler()
    with caplog.at_level(logging.INFO):
        left = pandas.DataFrame({"col": [1, 2, 3]})
        right = pandas.DataFrame({"col": [1, 2, 3]})
        assert handler.handle(left, right, config)
        assert not caplog.messages


@pandas_available
def test_pandas_dataframe_equal_handler_handle_false(
    config: EqualityConfig,
    caplog: pytest.LogCaptureFixture,
) -> None:
    # An empty DataFrame differs from a populated one; nothing is logged
    # because ``show_difference`` is left at the fixture's value.
    handler = PandasDataFrameEqualHandler()
    with caplog.at_level(logging.INFO):
        left = pandas.DataFrame({})
        right = pandas.DataFrame({"col": [1, 2, 3]})
        assert not handler.handle(left, right, config)
        assert not caplog.messages


@pandas_available
def test_pandas_dataframe_equal_handler_handle_false_different_column(
    config: EqualityConfig,
    caplog: pytest.LogCaptureFixture,
) -> None:
    # Same values under different column names are not equal.
    handler = PandasDataFrameEqualHandler()
    with caplog.at_level(logging.INFO):
        left = pandas.DataFrame({"col1": [1, 2, 3]})
        right = pandas.DataFrame({"col2": [1, 2, 3]})
        assert not handler.handle(left, right, config)
        assert not caplog.messages


@pandas_available
def test_pandas_dataframe_equal_handler_handle_false_different_value(
    config: EqualityConfig,
    caplog: pytest.LogCaptureFixture,
) -> None:
    # A single differing cell makes the DataFrames unequal.
    handler = PandasDataFrameEqualHandler()
    with caplog.at_level(logging.INFO):
        left = pandas.DataFrame({"col": [1, 2, 3]})
        right = pandas.DataFrame({"col": [1, 2, 4]})
        assert not handler.handle(left, right, config)
        assert not caplog.messages


@pandas_available
def test_pandas_dataframe_equal_handler_handle_false_different_dtype(
    config: EqualityConfig,
    caplog: pytest.LogCaptureFixture,
) -> None:
    # Identical values stored as float vs int are not equal.
    handler = PandasDataFrameEqualHandler()
    with caplog.at_level(logging.INFO):
        as_float = pandas.DataFrame(data={"col": [1, 2, 3]}, dtype=float)
        as_int = pandas.DataFrame(data={"col": [1, 2, 3]}, dtype=int)
        assert not handler.handle(as_float, as_int, config)
        assert not caplog.messages


@pandas_available
def test_pandas_dataframe_equal_handler_handle_false_show_difference(
    config: EqualityConfig, caplog: pytest.LogCaptureFixture
) -> None:
    # With ``show_difference`` enabled, a mismatch must be logged.
    config.show_difference = True
    handler = PandasDataFrameEqualHandler()
    with caplog.at_level(logging.INFO):
        left = pandas.DataFrame({"col": [1, 2, 3]})
        right = pandas.DataFrame({"col": [1, 2, 4]})
        assert not handler.handle(left, right, config=config)
        first_message = caplog.messages[0]
        assert first_message.startswith("pandas.DataFrames have different elements:")


@pandas_available
def test_pandas_dataframe_equal_handler_handle_equal_nan_false(config: EqualityConfig) -> None:
    # ``equal_nan`` is left at the fixture's value: DataFrames with NaN in
    # the same positions must be reported as different.
    left = pandas.DataFrame({"col": [0.0, float("nan"), float("nan"), 1.2]})
    right = pandas.DataFrame({"col": [0.0, float("nan"), float("nan"), 1.2]})
    handler = PandasDataFrameEqualHandler()
    assert not handler.handle(left, right, config)


@pandas_available
def test_pandas_dataframe_equal_handler_handle_equal_nan_true(config: EqualityConfig) -> None:
    # With ``equal_nan`` enabled, NaN in the same positions compare equal.
    config.equal_nan = True
    left = pandas.DataFrame({"col": [0.0, float("nan"), float("nan"), 1.2]})
    right = pandas.DataFrame({"col": [0.0, float("nan"), float("nan"), 1.2]})
    handler = PandasDataFrameEqualHandler()
    assert handler.handle(left, right, config)


def test_pandas_dataframe_equal_handler_set_next_handler() -> None:
    # Setting a next handler is accepted and must not raise.
    handler = PandasDataFrameEqualHandler()
    handler.set_next_handler(FalseHandler())


##############################################
# Tests for PandasSeriesEqualHandler #
##############################################
Expand Down Expand Up @@ -77,23 +244,43 @@ def test_pandas_series_equal_handler_handle_true_show_difference(


@pandas_available
@pytest.mark.parametrize(
("object1", "object2"),
[
(pandas.Series(data=[1, 2, 3]), pandas.Series(data=[1, 2, 3, 4])),
(pandas.Series(data=[1, 2, 3], dtype=int), pandas.Series(data=[1, 2, 3], dtype=float)),
(pandas.Series(data=[1, 2, 3]), pandas.Series(data=[1, 2, 4])),
],
)
def test_pandas_series_equal_handler_handle_false(
object1: pandas.Series,
object2: pandas.Series,
def test_pandas_series_equal_handler_handle_false_different_shape(
config: EqualityConfig,
caplog: pytest.LogCaptureFixture,
) -> None:
handler = PandasSeriesEqualHandler()
with caplog.at_level(logging.INFO):
assert not handler.handle(
pandas.Series(data=[1, 2, 3]), pandas.Series(data=[1, 2, 3, 4]), config
)
assert not caplog.messages


@pandas_available
def test_pandas_series_equal_handler_handle_false_different_dtype(
    config: EqualityConfig,
    caplog: pytest.LogCaptureFixture,
) -> None:
    # Identical values stored as int vs float are not equal; nothing is
    # logged because ``show_difference`` is left at the fixture's value.
    handler = PandasSeriesEqualHandler()
    with caplog.at_level(logging.INFO):
        as_int = pandas.Series(data=[1, 2, 3], dtype=int)
        as_float = pandas.Series(data=[1, 2, 3], dtype=float)
        assert not handler.handle(as_int, as_float, config)
        assert not caplog.messages


@pandas_available
def test_pandas_series_equal_handler_handle_false_different_value(
config: EqualityConfig,
caplog: pytest.LogCaptureFixture,
) -> None:
handler = PandasSeriesEqualHandler()
with caplog.at_level(logging.INFO):
assert not handler.handle(object1, object2, config)
assert not handler.handle(
pandas.Series(data=[1, 2, 3]), pandas.Series(data=[1, 2, 4]), config
)
assert not caplog.messages


Expand Down

0 comments on commit 87e70fa

Please sign in to comment.