diff --git a/src/coola/equality/handlers/__init__.py b/src/coola/equality/handlers/__init__.py index 6d181376..5b7566d1 100644 --- a/src/coola/equality/handlers/__init__.py +++ b/src/coola/equality/handlers/__init__.py @@ -15,6 +15,7 @@ "MappingSameValuesHandler", "NumpyArrayEqualHandler", "ObjectEqualHandler", + "PandasDataFrameEqualHandler", "PandasSeriesEqualHandler", "SameAttributeHandler", "SameDTypeHandler", @@ -46,7 +47,10 @@ TrueHandler, ) from coola.equality.handlers.numpy_ import NumpyArrayEqualHandler -from coola.equality.handlers.pandas_ import PandasSeriesEqualHandler +from coola.equality.handlers.pandas_ import ( + PandasDataFrameEqualHandler, + PandasSeriesEqualHandler, +) from coola.equality.handlers.sequence import SequenceSameValuesHandler from coola.equality.handlers.shape import SameShapeHandler from coola.equality.handlers.torch_ import TorchTensorEqualHandler diff --git a/src/coola/equality/handlers/pandas_.py b/src/coola/equality/handlers/pandas_.py index 4e27eae1..9e346098 100644 --- a/src/coola/equality/handlers/pandas_.py +++ b/src/coola/equality/handlers/pandas_.py @@ -3,7 +3,7 @@ from __future__ import annotations -__all__ = ["PandasSeriesEqualHandler"] +__all__ = ["PandasDataFrameEqualHandler", "PandasSeriesEqualHandler"] import logging from typing import TYPE_CHECKING @@ -23,6 +23,80 @@ logger = logging.getLogger(__name__) +class PandasDataFrameEqualHandler(BaseEqualityHandler): + r"""Check if the two ``pandas.DataFrame`` are equal. + + This handler returns ``True`` if the two ``pandas.DataFrame``s + equal, otherwise ``False``. This handler is designed to be used + at the end of the chain of responsibility. This handler does + not call the next handler. + + Example usage: + + ```pycon + >>> import pandas + >>> from coola.equality import EqualityConfig + >>> from coola.equality.handlers import PandasDataFrameEqualHandler + >>> from coola.testers import EqualityTester + >>> config = EqualityConfig(tester=EqualityTester()) + >>> handler = PandasDataFrameEqualHandler() + >>> handler.handle( + ... pandas.DataFrame([1, 2, 3, 4, 5]), pandas.DataFrame([1, 2, 3, 4, 5]), config + ... ) + True + >>> handler.handle( + ... pandas.DataFrame([1, 2, 3, 4, 5]), pandas.DataFrame([1, 2, 3, 4, 0]), config + ... ) + False + + ``` + """ + + def __eq__(self, other: object) -> bool: + return isinstance(other, self.__class__) + + def __repr__(self) -> str: + return f"{self.__class__.__qualname__}()" + + def handle( + self, + object1: pandas.DataFrame, + object2: pandas.DataFrame, + config: EqualityConfig, + ) -> bool: + object_equal = self._compare_dataframes(object1, object2, config) + if config.show_difference and not object_equal: + logger.info( + f"pandas.DataFrames have different elements:\n" + f"object1:\n{object1}\nobject2:\n{object2}" + ) + return object_equal + + def set_next_handler(self, handler: BaseEqualityHandler) -> None: + pass # Do nothing because the next handler is never called. + + def _compare_dataframes( + self, df1: pandas.DataFrame, df2: pandas.DataFrame, config: EqualityConfig + ) -> bool: + r"""Indicate if the two series are equal or not. + + Args: + df1: Specifies the first DataFrame to compare. + df2: Specifies the second DataFrame to compare. + config: Specifies the equality configuration. + + Returns: + ``True``if the two DataFrame are equal, otherwise ``False``. + """ + if not config.equal_nan and df1.isna().any().any(): + return False + try: + pandas.testing.assert_frame_equal(df1, df2, check_exact=True) + except AssertionError: + return False + return True + + class PandasSeriesEqualHandler(BaseEqualityHandler): r"""Check if the two ``pandas.Series`` are equal. diff --git a/tests/unit/equality/handlers/test_pandas.py b/tests/unit/equality/handlers/test_pandas.py index bcfacaff..9df805da 100644 --- a/tests/unit/equality/handlers/test_pandas.py +++ b/tests/unit/equality/handlers/test_pandas.py @@ -7,7 +7,11 @@ from coola import EqualityTester from coola.equality import EqualityConfig -from coola.equality.handlers import FalseHandler, PandasSeriesEqualHandler +from coola.equality.handlers import ( + FalseHandler, + PandasDataFrameEqualHandler, + PandasSeriesEqualHandler, +) from coola.testing import pandas_available from coola.utils import is_pandas_available @@ -22,6 +26,169 @@ def config() -> EqualityConfig: return EqualityConfig(tester=EqualityTester()) +################################################# +# Tests for PandasDataFrameEqualHandler # +################################################# + + +def test_pandas_dataframe_equal_handler_eq_true() -> None: + assert PandasDataFrameEqualHandler() == PandasDataFrameEqualHandler() + + +def test_pandas_dataframe_equal_handler_eq_false() -> None: + assert PandasDataFrameEqualHandler() != FalseHandler() + + +def test_pandas_dataframe_equal_handler_repr() -> None: + assert repr(PandasDataFrameEqualHandler()).startswith("PandasDataFrameEqualHandler(") + + +def test_pandas_dataframe_equal_handler_str() -> None: + assert str(PandasDataFrameEqualHandler()).startswith("PandasDataFrameEqualHandler(") + + +@pandas_available +@pytest.mark.parametrize( + ("object1", "object2"), + [ + (pandas.DataFrame({}), pandas.DataFrame({})), + (pandas.DataFrame({"col": [1, 2, 3]}), pandas.DataFrame({"col": [1, 2, 3]})), + ( + pandas.DataFrame( + { + "col1": [1, 2, 3, 4, 5], + "col2": [1.1, 2.2, 3.3, 4.4, 5.5], + "col3": ["a", "b", "c", "d", "e"], + } + ), + pandas.DataFrame( + { + "col1": [1, 2, 3, 4, 5], + "col2": [1.1, 2.2, 3.3, 4.4, 5.5], + "col3": ["a", "b", "c", "d", "e"], + } + ), + ), + ], +) +def test_pandas_dataframe_equal_handler_handle_true( + object1: pandas.DataFrame, + object2: pandas.DataFrame, + config: EqualityConfig, + caplog: pytest.LogCaptureFixture, +) -> None: + handler = PandasDataFrameEqualHandler() + with caplog.at_level(logging.INFO): + assert handler.handle(object1, object2, config) + assert not caplog.messages + + +@pandas_available +def test_pandas_dataframe_equal_handler_handle_true_show_difference( + config: EqualityConfig, + caplog: pytest.LogCaptureFixture, +) -> None: + config.show_difference = True + handler = PandasDataFrameEqualHandler() + with caplog.at_level(logging.INFO): + assert handler.handle( + pandas.DataFrame({"col": [1, 2, 3]}), pandas.DataFrame({"col": [1, 2, 3]}), config + ) + assert not caplog.messages + + +@pandas_available +def test_pandas_dataframe_equal_handler_handle_false( + config: EqualityConfig, + caplog: pytest.LogCaptureFixture, +) -> None: + handler = PandasDataFrameEqualHandler() + with caplog.at_level(logging.INFO): + assert not handler.handle( + pandas.DataFrame({}), pandas.DataFrame({"col": [1, 2, 3]}), config + ) + assert not caplog.messages + + +@pandas_available +def test_pandas_dataframe_equal_handler_handle_false_different_column( + config: EqualityConfig, + caplog: pytest.LogCaptureFixture, +) -> None: + handler = PandasDataFrameEqualHandler() + with caplog.at_level(logging.INFO): + assert not handler.handle( + pandas.DataFrame({"col1": [1, 2, 3]}), pandas.DataFrame({"col2": [1, 2, 3]}), config + ) + assert not caplog.messages + + +@pandas_available +def test_pandas_dataframe_equal_handler_handle_false_different_value( + config: EqualityConfig, + caplog: pytest.LogCaptureFixture, +) -> None: + handler = PandasDataFrameEqualHandler() + with caplog.at_level(logging.INFO): + assert not handler.handle( + pandas.DataFrame({"col": [1, 2, 3]}), pandas.DataFrame({"col": [1, 2, 4]}), config + ) + assert not caplog.messages + + +@pandas_available +def test_pandas_dataframe_equal_handler_handle_false_different_dtype( + config: EqualityConfig, + caplog: pytest.LogCaptureFixture, +) -> None: + handler = PandasDataFrameEqualHandler() + with caplog.at_level(logging.INFO): + assert not handler.handle( + pandas.DataFrame(data={"col": [1, 2, 3]}, dtype=float), + pandas.DataFrame(data={"col": [1, 2, 3]}, dtype=int), + config, + ) + assert not caplog.messages + + +@pandas_available +def test_pandas_dataframe_equal_handler_handle_false_show_difference( + config: EqualityConfig, caplog: pytest.LogCaptureFixture +) -> None: + config.show_difference = True + handler = PandasDataFrameEqualHandler() + with caplog.at_level(logging.INFO): + assert not handler.handle( + pandas.DataFrame({"col": [1, 2, 3]}), + pandas.DataFrame({"col": [1, 2, 4]}), + config=config, + ) + assert caplog.messages[0].startswith("pandas.DataFrames have different elements:") + + +@pandas_available +def test_pandas_dataframe_equal_handler_handle_equal_nan_false(config: EqualityConfig) -> None: + assert not PandasDataFrameEqualHandler().handle( + pandas.DataFrame({"col": [0.0, float("nan"), float("nan"), 1.2]}), + pandas.DataFrame({"col": [0.0, float("nan"), float("nan"), 1.2]}), + config, + ) + + +@pandas_available +def test_pandas_dataframe_equal_handler_handle_equal_nan_true(config: EqualityConfig) -> None: + config.equal_nan = True + assert PandasDataFrameEqualHandler().handle( + pandas.DataFrame({"col": [0.0, float("nan"), float("nan"), 1.2]}), + pandas.DataFrame({"col": [0.0, float("nan"), float("nan"), 1.2]}), + config, + ) + + +def test_pandas_dataframe_equal_handler_set_next_handler() -> None: + PandasDataFrameEqualHandler().set_next_handler(FalseHandler()) + + ############################################## # Tests for PandasSeriesEqualHandler # ############################################## @@ -77,23 +244,43 @@ def test_pandas_series_equal_handler_handle_true_show_difference( @pandas_available -@pytest.mark.parametrize( - ("object1", "object2"), - [ - (pandas.Series(data=[1, 2, 3]), pandas.Series(data=[1, 2, 3, 4])), - (pandas.Series(data=[1, 2, 3], dtype=int), pandas.Series(data=[1, 2, 3], dtype=float)), - (pandas.Series(data=[1, 2, 3]), pandas.Series(data=[1, 2, 4])), - ], -) -def test_pandas_series_equal_handler_handle_false( - object1: pandas.Series, - object2: pandas.Series, +def test_pandas_series_equal_handler_handle_false_different_shape( + config: EqualityConfig, + caplog: pytest.LogCaptureFixture, +) -> None: + handler = PandasSeriesEqualHandler() + with caplog.at_level(logging.INFO): + assert not handler.handle( + pandas.Series(data=[1, 2, 3]), pandas.Series(data=[1, 2, 3, 4]), config + ) + assert not caplog.messages + + +@pandas_available +def test_pandas_series_equal_handler_handle_false_different_dtype( + config: EqualityConfig, + caplog: pytest.LogCaptureFixture, +) -> None: + handler = PandasSeriesEqualHandler() + with caplog.at_level(logging.INFO): + assert not handler.handle( + pandas.Series(data=[1, 2, 3], dtype=int), + pandas.Series(data=[1, 2, 3], dtype=float), + config, + ) + assert not caplog.messages + + +@pandas_available +def test_pandas_series_equal_handler_handle_false_different_value( config: EqualityConfig, caplog: pytest.LogCaptureFixture, ) -> None: handler = PandasSeriesEqualHandler() with caplog.at_level(logging.INFO): - assert not handler.handle(object1, object2, config) + assert not handler.handle( + pandas.Series(data=[1, 2, 3]), pandas.Series(data=[1, 2, 4]), config + ) assert not caplog.messages