diff --git a/doc/source/whatsnew/v0.24.0.txt b/doc/source/whatsnew/v0.24.0.txt index e98acabeeb3a69..b67f4899356052 100644 --- a/doc/source/whatsnew/v0.24.0.txt +++ b/doc/source/whatsnew/v0.24.0.txt @@ -1239,6 +1239,7 @@ Groupby/Resample/Rolling - Bug in :meth:`SeriesGroupBy.mean` when values were integral but could not fit inside of int64, overflowing instead. (:issue:`22487`) - :func:`RollingGroupby.agg` and :func:`ExpandingGroupby.agg` now support multiple aggregation functions as parameters (:issue:`15072`) - Bug in :meth:`DataFrame.resample` and :meth:`Series.resample` when resampling by a weekly offset (``'W'``) across a DST transition (:issue:`9119`, :issue:`21459`) +- Bug in :meth:`DataFrame.expanding` in which the ``axis`` argument was not being respected during aggregations (:issue:`23372`) Reshaping ^^^^^^^^^ diff --git a/pandas/core/window.py b/pandas/core/window.py index 7d48967602bc15..5256532a318708 100644 --- a/pandas/core/window.py +++ b/pandas/core/window.py @@ -1866,12 +1866,25 @@ def _constructor(self): return Expanding def _get_window(self, other=None): - obj = self._selected_obj - if other is None: - return (max(len(obj), self.min_periods) if self.min_periods - else len(obj)) - return (max((len(obj) + len(obj)), self.min_periods) - if self.min_periods else (len(obj) + len(obj))) + """ + Get the window length over which to perform some operation. + + Parameters + ---------- + other : object, default None + The other object that is involved in the operation. + Such an object is involved for operations like covariance. + + Returns + ------- + window : int + The window length. + """ + axis = self.obj._get_axis(self.axis) + length = len(axis) + (other is not None) * len(axis) + + other = self.min_periods or -1 + return max(length, other) _agg_doc = dedent(""" Examples diff --git a/pandas/tests/test_window.py b/pandas/tests/test_window.py index 4b0c4d581a0081..c7cd04deac6c89 100644 --- a/pandas/tests/test_window.py +++ b/pandas/tests/test_window.py @@ -627,6 +627,25 @@ def test_iter_raises(self, klass): with pytest.raises(NotImplementedError): iter(obj.rolling(2)) + def test_rolling_axis(self, axis_frame): + # see gh-23372. + df = DataFrame(np.ones((10, 20))) + axis = df._get_axis_number(axis_frame) + + if axis == 0: + expected = DataFrame({ + i: [np.nan] * 2 + [3.0] * 8 + for i in range(20) + }) + else: + # axis == 1 + expected = DataFrame([ + [np.nan] * 2 + [3.0] * 18 + ] * 10) + + result = df.rolling(3, axis=axis_frame).sum() + tm.assert_frame_equal(result, expected) + class TestExpanding(Base): @@ -714,6 +733,25 @@ def test_iter_raises(self, klass): with pytest.raises(NotImplementedError): iter(obj.expanding(2)) + def test_expanding_axis(self, axis_frame): + # see gh-23372. + df = DataFrame(np.ones((10, 20))) + axis = df._get_axis_number(axis_frame) + + if axis == 0: + expected = DataFrame({ + i: [np.nan] * 2 + [float(j) for j in range(3, 11)] + for i in range(20) + }) + else: + # axis == 1 + expected = DataFrame([ + [np.nan] * 2 + [float(i) for i in range(3, 21)] + ] * 10) + + result = df.expanding(3, axis=axis_frame).sum() + tm.assert_frame_equal(result, expected) + class TestEWM(Base):