Skip to content

Commit

Permalink
Add PandasDataFrameEqualityComparator (#398)
Browse files Browse the repository at this point in the history
  • Loading branch information
durandtibo authored Jan 8, 2024
1 parent c78803e commit 28ed765
Show file tree
Hide file tree
Showing 4 changed files with 239 additions and 24 deletions.
6 changes: 5 additions & 1 deletion src/coola/equality/comparators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"JaxArrayEqualityComparator",
"MappingEqualityComparator",
"NumpyArrayEqualityComparator",
"PandasDataFrameEqualityComparator",
"PandasSeriesEqualityComparator",
"SequenceEqualityComparator",
"TorchPackedSequenceEqualityComparator",
Expand All @@ -26,7 +27,10 @@
from coola.equality.comparators.default import DefaultEqualityComparator
from coola.equality.comparators.jax_ import JaxArrayEqualityComparator
from coola.equality.comparators.numpy_ import NumpyArrayEqualityComparator
from coola.equality.comparators.pandas_ import PandasSeriesEqualityComparator
from coola.equality.comparators.pandas_ import (
PandasDataFrameEqualityComparator,
PandasSeriesEqualityComparator,
)
from coola.equality.comparators.torch_ import (
TorchPackedSequenceEqualityComparator,
TorchTensorEqualityComparator,
Expand Down
52 changes: 50 additions & 2 deletions src/coola/equality/comparators/pandas_.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from coola.equality.comparators.base import BaseEqualityComparator
from coola.equality.handlers import (
PandasDataFrameEqualHandler,
PandasSeriesEqualHandler,
SameObjectHandler,
SameTypeHandler,
Expand All @@ -28,6 +29,49 @@
logger = logging.getLogger(__name__)


class PandasDataFrameEqualityComparator(BaseEqualityComparator[pandas.DataFrame]):
r"""Implement an equality comparator for ``pandas.DataFrame``.
Example usage:
```pycon
>>> import pandas as np
>>> from coola.equality import EqualityConfig
>>> from coola.equality.comparators import PandasDataFrameEqualityComparator
>>> from coola.testers import EqualityTester
>>> config = EqualityConfig(tester=EqualityTester())
>>> comparator = PandasDataFrameEqualityComparator()
>>> comparator.equal(
... pandas.DataFrame({"col": [1, 2, 3]}),
... pandas.DataFrame({"col": [1, 2, 3]}),
... config,
... )
True
>>> comparator.equal(
... pandas.DataFrame({"col": [1, 2, 3]}),
... pandas.DataFrame({"col": [1, 2, 4]}),
... config,
... )
False
```
"""

def __init__(self) -> None:
check_pandas()
self._handler = SameObjectHandler()
self._handler.chain(SameTypeHandler()).chain(PandasDataFrameEqualHandler())

def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__)

def clone(self) -> PandasDataFrameEqualityComparator:
return self.__class__()

def equal(self, object1: pandas.DataFrame, object2: Any, config: EqualityConfig) -> bool:
return self._handler.handle(object1=object1, object2=object2, config=config)


class PandasSeriesEqualityComparator(BaseEqualityComparator[pandas.Series]):
r"""Implement an equality comparator for ``pandas.Series``.
Expand Down Expand Up @@ -78,10 +122,14 @@ def get_type_comparator_mapping() -> dict[type, BaseEqualityComparator]:
```pycon
>>> from coola.equality.comparators.pandas_ import get_type_comparator_mapping
>>> get_type_comparator_mapping()
{<class 'pandas.core.series.Series'>: PandasSeriesEqualityComparator()}
{<class 'pandas.core.frame.DataFrame'>: PandasDataFrameEqualityComparator(),
<class 'pandas.core.series.Series'>: PandasSeriesEqualityComparator()}
```
"""
if not is_pandas_available():
return {}
return {pandas.Series: PandasSeriesEqualityComparator()}
return {
pandas.DataFrame: PandasDataFrameEqualityComparator(),
pandas.Series: PandasSeriesEqualityComparator(),
}
8 changes: 6 additions & 2 deletions src/coola/equality/handlers/pandas_.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,15 @@ class PandasDataFrameEqualHandler(BaseEqualityHandler):
>>> config = EqualityConfig(tester=EqualityTester())
>>> handler = PandasDataFrameEqualHandler()
>>> handler.handle(
... pandas.DataFrame([1, 2, 3, 4, 5]), pandas.DataFrame([1, 2, 3, 4, 5]), config
... pandas.DataFrame({"col": [1, 2, 3]}),
... pandas.DataFrame({"col": [1, 2, 3]}),
... config,
... )
True
>>> handler.handle(
... pandas.DataFrame([1, 2, 3, 4, 5]), pandas.DataFrame([1, 2, 3, 4, 0]), config
... pandas.DataFrame({"col": [1, 2, 3]}),
... pandas.DataFrame({"col": [1, 2, 4]}),
... config,
... )
False
Expand Down
Loading

0 comments on commit 28ed765

Please sign in to comment.