From 28ed765bd86edddc6b7e4d27ede48622984c8cf0 Mon Sep 17 00:00:00 2001 From: Thibaut Durand Date: Sun, 7 Jan 2024 20:01:21 -0800 Subject: [PATCH] Add `PandasDataFrameEqualityComparator` (#398) --- src/coola/equality/comparators/__init__.py | 6 +- src/coola/equality/comparators/pandas_.py | 52 ++++- src/coola/equality/handlers/pandas_.py | 8 +- .../unit/equality/comparators/test_pandas.py | 197 ++++++++++++++++-- 4 files changed, 239 insertions(+), 24 deletions(-) diff --git a/src/coola/equality/comparators/__init__.py b/src/coola/equality/comparators/__init__.py index c330098d..d37a95f2 100644 --- a/src/coola/equality/comparators/__init__.py +++ b/src/coola/equality/comparators/__init__.py @@ -8,6 +8,7 @@ "JaxArrayEqualityComparator", "MappingEqualityComparator", "NumpyArrayEqualityComparator", + "PandasDataFrameEqualityComparator", "PandasSeriesEqualityComparator", "SequenceEqualityComparator", "TorchPackedSequenceEqualityComparator", @@ -26,7 +27,10 @@ from coola.equality.comparators.default import DefaultEqualityComparator from coola.equality.comparators.jax_ import JaxArrayEqualityComparator from coola.equality.comparators.numpy_ import NumpyArrayEqualityComparator -from coola.equality.comparators.pandas_ import PandasSeriesEqualityComparator +from coola.equality.comparators.pandas_ import ( + PandasDataFrameEqualityComparator, + PandasSeriesEqualityComparator, +) from coola.equality.comparators.torch_ import ( TorchPackedSequenceEqualityComparator, TorchTensorEqualityComparator, diff --git a/src/coola/equality/comparators/pandas_.py b/src/coola/equality/comparators/pandas_.py index d74ef625..872bf0ca 100644 --- a/src/coola/equality/comparators/pandas_.py +++ b/src/coola/equality/comparators/pandas_.py @@ -11,6 +11,7 @@ from coola.equality.comparators.base import BaseEqualityComparator from coola.equality.handlers import ( + PandasDataFrameEqualHandler, PandasSeriesEqualHandler, SameObjectHandler, SameTypeHandler, @@ -28,6 +29,49 @@ logger = logging.getLogger(__name__) +class PandasDataFrameEqualityComparator(BaseEqualityComparator[pandas.DataFrame]): + r"""Implement an equality comparator for ``pandas.DataFrame``. + + Example usage: + + ```pycon + >>> import pandas as np + >>> from coola.equality import EqualityConfig + >>> from coola.equality.comparators import PandasDataFrameEqualityComparator + >>> from coola.testers import EqualityTester + >>> config = EqualityConfig(tester=EqualityTester()) + >>> comparator = PandasDataFrameEqualityComparator() + >>> comparator.equal( + ... pandas.DataFrame({"col": [1, 2, 3]}), + ... pandas.DataFrame({"col": [1, 2, 3]}), + ... config, + ... ) + True + >>> comparator.equal( + ... pandas.DataFrame({"col": [1, 2, 3]}), + ... pandas.DataFrame({"col": [1, 2, 4]}), + ... config, + ... ) + False + + ``` + """ + + def __init__(self) -> None: + check_pandas() + self._handler = SameObjectHandler() + self._handler.chain(SameTypeHandler()).chain(PandasDataFrameEqualHandler()) + + def __eq__(self, other: object) -> bool: + return isinstance(other, self.__class__) + + def clone(self) -> PandasDataFrameEqualityComparator: + return self.__class__() + + def equal(self, object1: pandas.DataFrame, object2: Any, config: EqualityConfig) -> bool: + return self._handler.handle(object1=object1, object2=object2, config=config) + + class PandasSeriesEqualityComparator(BaseEqualityComparator[pandas.Series]): r"""Implement an equality comparator for ``pandas.Series``. @@ -78,10 +122,14 @@ def get_type_comparator_mapping() -> dict[type, BaseEqualityComparator]: ```pycon >>> from coola.equality.comparators.pandas_ import get_type_comparator_mapping >>> get_type_comparator_mapping() - {: PandasSeriesEqualityComparator()} + {: PandasDataFrameEqualityComparator(), + : PandasSeriesEqualityComparator()} ``` """ if not is_pandas_available(): return {} - return {pandas.Series: PandasSeriesEqualityComparator()} + return { + pandas.DataFrame: PandasDataFrameEqualityComparator(), + pandas.Series: PandasSeriesEqualityComparator(), + } diff --git a/src/coola/equality/handlers/pandas_.py b/src/coola/equality/handlers/pandas_.py index 9cd18e4b..aedcff44 100644 --- a/src/coola/equality/handlers/pandas_.py +++ b/src/coola/equality/handlers/pandas_.py @@ -41,11 +41,15 @@ class PandasDataFrameEqualHandler(BaseEqualityHandler): >>> config = EqualityConfig(tester=EqualityTester()) >>> handler = PandasDataFrameEqualHandler() >>> handler.handle( - ... pandas.DataFrame([1, 2, 3, 4, 5]), pandas.DataFrame([1, 2, 3, 4, 5]), config + ... pandas.DataFrame({"col": [1, 2, 3]}), + ... pandas.DataFrame({"col": [1, 2, 3]}), + ... config, ... ) True >>> handler.handle( - ... pandas.DataFrame([1, 2, 3, 4, 5]), pandas.DataFrame([1, 2, 3, 4, 0]), config + ... pandas.DataFrame({"col": [1, 2, 3]}), + ... pandas.DataFrame({"col": [1, 2, 4]}), + ... config, ... ) False diff --git a/tests/unit/equality/comparators/test_pandas.py b/tests/unit/equality/comparators/test_pandas.py index 5090f9c6..2a9640a0 100644 --- a/tests/unit/equality/comparators/test_pandas.py +++ b/tests/unit/equality/comparators/test_pandas.py @@ -8,6 +8,7 @@ from coola import objects_are_equal from coola.equality import EqualityConfig from coola.equality.comparators.pandas_ import ( + PandasDataFrameEqualityComparator, PandasSeriesEqualityComparator, get_type_comparator_mapping, ) @@ -26,33 +27,188 @@ def config() -> EqualityConfig: return EqualityConfig(tester=EqualityTester()) -################################################## +####################################################### +# Tests for PandasDataFrameEqualityComparator # +####################################################### + + +@pandas_available +def test_objects_are_equal_dataframe() -> None: + assert objects_are_equal( + pandas.DataFrame({"col": [1, 2, 3]}), pandas.DataFrame({"col": [1, 2, 3]}) + ) + + +@pandas_available +def test_pandas_dataframe_equality_comparator_str() -> None: + assert str(PandasDataFrameEqualityComparator()).startswith("PandasDataFrameEqualityComparator(") + + +@pandas_available +def test_pandas_dataframe_equality_comparator__eq__true() -> None: + assert PandasDataFrameEqualityComparator() == PandasDataFrameEqualityComparator() + + +@pandas_available +def test_pandas_dataframe_equality_comparator__eq__false_different_type() -> None: + assert PandasDataFrameEqualityComparator() != 123 + + +@pandas_available +def test_pandas_dataframe_equality_comparator_clone() -> None: + op = PandasDataFrameEqualityComparator() + op_cloned = op.clone() + assert op is not op_cloned + assert op == op_cloned + + +@pandas_available +def test_pandas_dataframe_equality_comparator_equal_true_same_object( + config: EqualityConfig, +) -> None: + val = pandas.DataFrame({"col": [1, 2, 3]}) + assert PandasDataFrameEqualityComparator().equal(val, val, config) + + +@pandas_available +def test_pandas_dataframe_equality_comparator_equal_true( + caplog: pytest.LogCaptureFixture, config: EqualityConfig +) -> None: + comparator = PandasDataFrameEqualityComparator() + with caplog.at_level(logging.INFO): + assert comparator.equal( + object1=pandas.DataFrame({"col": [1, 2, 3]}), + object2=pandas.DataFrame({"col": [1, 2, 3]}), + config=config, + ) + assert not caplog.messages + + +@pandas_available +def test_pandas_dataframe_equality_comparator_equal_true_show_difference( + caplog: pytest.LogCaptureFixture, config: EqualityConfig +) -> None: + config.show_difference = True + comparator = PandasDataFrameEqualityComparator() + with caplog.at_level(logging.INFO): + assert comparator.equal( + object1=pandas.DataFrame({"col": [1, 2, 3]}), + object2=pandas.DataFrame({"col": [1, 2, 3]}), + config=config, + ) + assert not caplog.messages + + +@pandas_available +def test_pandas_dataframe_equality_comparator_equal_false_different_value( + caplog: pytest.LogCaptureFixture, config: EqualityConfig +) -> None: + comparator = PandasDataFrameEqualityComparator() + with caplog.at_level(logging.INFO): + assert not comparator.equal( + object1=pandas.DataFrame({"col": [1, 2, 3]}), + object2=pandas.DataFrame({"col": [1, 2, 4]}), + config=config, + ) + assert not caplog.messages + + +@pandas_available +def test_pandas_dataframe_equality_comparator_equal_false_different_value_show_difference( + caplog: pytest.LogCaptureFixture, config: EqualityConfig +) -> None: + config.show_difference = True + comparator = PandasDataFrameEqualityComparator() + with caplog.at_level(logging.INFO): + assert not comparator.equal( + object1=pandas.DataFrame({"col": [1, 2, 3]}), + object2=pandas.DataFrame({"col": [1, 2, 4]}), + config=config, + ) + assert caplog.messages[0].startswith("pandas.DataFrames have different elements:") + + +@pandas_available +def test_pandas_dataframe_equality_comparator_equal_false_different_type( + caplog: pytest.LogCaptureFixture, config: EqualityConfig +) -> None: + comparator = PandasDataFrameEqualityComparator() + with caplog.at_level(logging.INFO): + assert not comparator.equal( + object1=pandas.DataFrame({"col": [1, 2, 3]}), object2=42, config=config + ) + assert not caplog.messages + + +@pandas_available +def test_pandas_dataframe_equality_comparator_equal_false_different_type_show_difference( + caplog: pytest.LogCaptureFixture, config: EqualityConfig +) -> None: + config.show_difference = True + comparator = PandasDataFrameEqualityComparator() + with caplog.at_level(logging.INFO): + assert not comparator.equal( + object1=pandas.DataFrame({"col": [1, 2, 3]}), object2=42, config=config + ) + assert caplog.messages[0].startswith("objects have different types:") + + +@pandas_available +def test_pandas_dataframe_equality_comparator_equal_nan_false(config: EqualityConfig) -> None: + assert not PandasDataFrameEqualityComparator().equal( + object1=pandas.DataFrame({"col": [1, float("nan"), 3]}), + object2=pandas.DataFrame({"col": [1, float("nan"), 3]}), + config=config, + ) + + +@pandas_available +def test_pandas_dataframe_equality_comparator_equal_nan_true(config: EqualityConfig) -> None: + config.equal_nan = True + assert PandasDataFrameEqualityComparator().equal( + object1=pandas.DataFrame({"col": [1, float("nan"), 3]}), + object2=pandas.DataFrame({"col": [1, float("nan"), 3]}), + config=config, + ) + + +@pandas_available +def test_pandas_dataframe_equality_comparator_no_pandas() -> None: + with patch( + "coola.utils.imports.is_pandas_available", lambda *args, **kwargs: False + ), pytest.raises(RuntimeError, match="`pandas` package is required but not installed."): + PandasDataFrameEqualityComparator() + + +#################################################### # Tests for PandasSeriesEqualityComparator # -################################################## +#################################################### @pandas_available -def test_objects_are_equal_array() -> None: - assert objects_are_equal(pandas.Series([1, 2, 3]), pandas.Series([1, 2, 3])) +def test_objects_are_equal_series() -> None: + assert objects_are_equal( + pandas.DataFrame({"col": [1, 2, 3]}), pandas.DataFrame({"col": [1, 2, 3]}) + ) @pandas_available -def test_pandas_array_equality_comparator_str() -> None: +def test_pandas_series_equality_comparator_str() -> None: assert str(PandasSeriesEqualityComparator()).startswith("PandasSeriesEqualityComparator(") @pandas_available -def test_pandas_array_equality_comparator__eq__true() -> None: +def test_pandas_series_equality_comparator__eq__true() -> None: assert PandasSeriesEqualityComparator() == PandasSeriesEqualityComparator() @pandas_available -def test_pandas_array_equality_comparator__eq__false_different_type() -> None: +def test_pandas_series_equality_comparator__eq__false_different_type() -> None: assert PandasSeriesEqualityComparator() != 123 @pandas_available -def test_pandas_array_equality_comparator_clone() -> None: +def test_pandas_series_equality_comparator_clone() -> None: op = PandasSeriesEqualityComparator() op_cloned = op.clone() assert op is not op_cloned @@ -60,13 +216,13 @@ def test_pandas_array_equality_comparator_clone() -> None: @pandas_available -def test_pandas_array_equality_comparator_equal_true_same_object(config: EqualityConfig) -> None: +def test_pandas_series_equality_comparator_equal_true_same_object(config: EqualityConfig) -> None: series = pandas.Series([1, 2, 3]) assert PandasSeriesEqualityComparator().equal(series, series, config) @pandas_available -def test_pandas_array_equality_comparator_equal_true( +def test_pandas_series_equality_comparator_equal_true( caplog: pytest.LogCaptureFixture, config: EqualityConfig ) -> None: comparator = PandasSeriesEqualityComparator() @@ -80,7 +236,7 @@ def test_pandas_array_equality_comparator_equal_true( @pandas_available -def test_pandas_array_equality_comparator_equal_true_show_difference( +def test_pandas_series_equality_comparator_equal_true_show_difference( caplog: pytest.LogCaptureFixture, config: EqualityConfig ) -> None: config.show_difference = True @@ -95,7 +251,7 @@ def test_pandas_array_equality_comparator_equal_true_show_difference( @pandas_available -def test_pandas_array_equality_comparator_equal_false_different_value( +def test_pandas_series_equality_comparator_equal_false_different_value( caplog: pytest.LogCaptureFixture, config: EqualityConfig ) -> None: comparator = PandasSeriesEqualityComparator() @@ -107,7 +263,7 @@ def test_pandas_array_equality_comparator_equal_false_different_value( @pandas_available -def test_pandas_array_equality_comparator_equal_false_different_value_show_difference( +def test_pandas_series_equality_comparator_equal_false_different_value_show_difference( caplog: pytest.LogCaptureFixture, config: EqualityConfig ) -> None: config.show_difference = True @@ -120,7 +276,7 @@ def test_pandas_array_equality_comparator_equal_false_different_value_show_diffe @pandas_available -def test_pandas_array_equality_comparator_equal_false_different_type( +def test_pandas_series_equality_comparator_equal_false_different_type( caplog: pytest.LogCaptureFixture, config: EqualityConfig ) -> None: comparator = PandasSeriesEqualityComparator() @@ -130,7 +286,7 @@ def test_pandas_array_equality_comparator_equal_false_different_type( @pandas_available -def test_pandas_array_equality_comparator_equal_false_different_type_show_difference( +def test_pandas_series_equality_comparator_equal_false_different_type_show_difference( caplog: pytest.LogCaptureFixture, config: EqualityConfig ) -> None: config.show_difference = True @@ -141,7 +297,7 @@ def test_pandas_array_equality_comparator_equal_false_different_type_show_differ @pandas_available -def test_pandas_array_equality_comparator_equal_nan_false(config: EqualityConfig) -> None: +def test_pandas_series_equality_comparator_equal_nan_false(config: EqualityConfig) -> None: assert not PandasSeriesEqualityComparator().equal( object1=pandas.Series([0.0, float("nan"), float("nan"), 1.2]), object2=pandas.Series([0.0, float("nan"), float("nan"), 1.2]), @@ -150,7 +306,7 @@ def test_pandas_array_equality_comparator_equal_nan_false(config: EqualityConfig @pandas_available -def test_pandas_array_equality_comparator_equal_nan_true(config: EqualityConfig) -> None: +def test_pandas_series_equality_comparator_equal_nan_true(config: EqualityConfig) -> None: config.equal_nan = True assert PandasSeriesEqualityComparator().equal( object1=pandas.Series([0.0, float("nan"), float("nan"), 1.2]), @@ -160,7 +316,7 @@ def test_pandas_array_equality_comparator_equal_nan_true(config: EqualityConfig) @pandas_available -def test_pandas_array_equality_comparator_no_pandas() -> None: +def test_pandas_series_equality_comparator_no_pandas() -> None: with patch( "coola.utils.imports.is_pandas_available", lambda *args, **kwargs: False ), pytest.raises(RuntimeError, match="`pandas` package is required but not installed."): @@ -174,7 +330,10 @@ def test_pandas_array_equality_comparator_no_pandas() -> None: @pandas_available def test_get_type_comparator_mapping() -> None: - assert get_type_comparator_mapping() == {pandas.Series: PandasSeriesEqualityComparator()} + assert get_type_comparator_mapping() == { + pandas.DataFrame: PandasDataFrameEqualityComparator(), + pandas.Series: PandasSeriesEqualityComparator(), + } def test_get_type_comparator_mapping_no_pandas() -> None: