diff --git a/tests/dataframe/test_groupby_pytest.py b/tests/dataframe/test_groupby_pytest.py index e6be6c05..73ab836c 100644 --- a/tests/dataframe/test_groupby_pytest.py +++ b/tests/dataframe/test_groupby_pytest.py @@ -23,6 +23,17 @@ from tests.common import TestData +PANDAS_MAJOR_VERSION = int(pd.__version__.split(".")[0]) + + +# The mean absolute difference (mad) aggregation has been removed from +# pandas with major version 2: +# https://github.com/pandas-dev/pandas/issues/11787 +# To compare whether eland's version of it works, we need to implement +# it here ourselves. +def mad(x): + return abs(x - x.mean()).mean() + class TestGroupbyDataFrame(TestData): funcs = ["max", "min", "mean", "sum"] @@ -71,7 +82,7 @@ def test_groupby_aggregate_single_aggs(self, pd_agg): @pytest.mark.parametrize("dropna", [True, False]) @pytest.mark.parametrize("pd_agg", ["max", "min", "mean", "sum", "median"]) def test_groupby_aggs_numeric_only_true(self, pd_agg, dropna): - # Pandas has numeric_only applicable for the above aggs with groupby only. + # Pandas has numeric_only applicable for the above aggs with groupby only. pd_flights = self.pd_flights().filter(self.filter_data) ed_flights = self.ed_flights().filter(self.filter_data) @@ -95,7 +106,14 @@ def test_groupby_aggs_mad_var_std(self, pd_agg, dropna): pd_flights = self.pd_flights().filter(self.filter_data) ed_flights = self.ed_flights().filter(self.filter_data) - pd_groupby = getattr(pd_flights.groupby("Cancelled", dropna=dropna), pd_agg)() + # The mad aggregation has been removed in Pandas 2, so we need to use + # our own implementation if we run the tests with Pandas 2 or higher + if PANDAS_MAJOR_VERSION >= 2 and pd_agg == "mad": + pd_groupby = pd_flights.groupby("Cancelled", dropna=dropna).aggregate(mad) + else: + pd_groupby = getattr( + pd_flights.groupby("Cancelled", dropna=dropna), pd_agg + )() ed_groupby = getattr(ed_flights.groupby("Cancelled", dropna=dropna), pd_agg)( numeric_only=True ) @@ -211,14 +229,20 @@ def test_groupby_dataframe_mad(self): pd_flights = self.pd_flights().filter(self.filter_data + ["DestCountry"]) ed_flights = self.ed_flights().filter(self.filter_data + ["DestCountry"]) - pd_mad = pd_flights.groupby("DestCountry").mad() + if PANDAS_MAJOR_VERSION < 2: + pd_mad = pd_flights.groupby("DestCountry").mad() + else: + pd_mad = pd_flights.groupby("DestCountry").aggregate(mad) ed_mad = ed_flights.groupby("DestCountry").mad() assert_index_equal(pd_mad.columns, ed_mad.columns) assert_index_equal(pd_mad.index, ed_mad.index) assert_series_equal(pd_mad.dtypes, ed_mad.dtypes) - pd_min_mad = pd_flights.groupby("DestCountry").aggregate(["min", "mad"]) + if PANDAS_MAJOR_VERSION < 2: + pd_min_mad = pd_flights.groupby("DestCountry").aggregate(["min", "mad"]) + else: + pd_min_mad = pd_flights.groupby("DestCountry").aggregate(["min", mad]) ed_min_mad = ed_flights.groupby("DestCountry").aggregate(["min", "mad"]) assert_index_equal(pd_min_mad.columns, ed_min_mad.columns)