# Review notes: leading text placed before the first "diff --git" header; it
# belongs to no hunk.
#
# What this patch changes:
#   * pandas-stubs/core/frame.pyi: imports SeriesGroupBy next to DataFrameGroupBy
#     and splits the `by: Scalar` overload of DataFrame.groupby in two, keyed on
#     `as_index`: Literal[False] (default `...`) -> DataFrameGroupBy[Scalar], and
#     Literal[True] (default True) -> SeriesGroupBy. The first overload of the
#     pair carries mypy/pyright overlapping-overload suppressions.
#   * pandas-stubs/core/groupby/groupby.pyi: narrows
#     `size(self: GroupBy[DataFrame])` from `DataFrame | Series[int]` to
#     `DataFrame`.
#   * tests/test_frame.py: adds test_types_groupby_as_index, which asserts that
#     `.size()` is typed as pd.DataFrame for as_index=False and as
#     "pd.Series[int]" for as_index=True, and re-annotates `df3` in
#     test_types_groupby from pd.DataFrame to pd.Series.
#
# NOTE(review): the Literal[False] overload is listed first and gives `as_index`
#   a default, so a call that omits `as_index` presumably resolves to it and
#   `.size()` is then typed as DataFrame -- confirm this matches the runtime
#   default of `as_index`.
# NOTE(review): presumably DataFrame.groupby builds a DataFrameGroupBy at runtime
#   whatever `as_index` is; if so, `-> SeriesGroupBy` and the `df3: pd.Series`
#   annotation on a `.transform(...)` result misdescribe the runtime objects --
#   verify against pandas before merging.
# NOTE(review): SeriesGroupBy is written without type arguments while
#   DataFrameGroupBy is subscripted with [Scalar] -- TODO confirm that is
#   intended.
# NOTE(review): the context comment "return type depends on `as_index` for
#   dataframe groupby" still sits on the size() overload whose return type is no
#   longer a union -- TODO confirm the comment is still accurate.
#
diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 3f643f57..b61b03ff 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -25,7 +25,10 @@ from pandas import ( ) from pandas.core.arraylike import OpsMixin from pandas.core.generic import NDFrame -from pandas.core.groupby.generic import DataFrameGroupBy +from pandas.core.groupby.generic import ( + DataFrameGroupBy, + SeriesGroupBy, +) from pandas.core.groupby.grouper import Grouper from pandas.core.indexers import BaseIndexer from pandas.core.indexes.base import Index @@ -1052,18 +1055,30 @@ class DataFrame(NDFrame, OpsMixin): errors: IgnoreRaise = ..., ) -> None: ... @overload - def groupby( + def groupby( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload] self, by: Scalar, axis: AxisIndex | NoDefault = ..., level: IndexLabel | None = ..., - as_index: _bool = ..., + as_index: Literal[False] = ..., sort: _bool = ..., group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., ) -> DataFrameGroupBy[Scalar]: ... @overload + def groupby( + self, + by: Scalar, + axis: AxisIndex | NoDefault = ..., + level: IndexLabel | None = ..., + as_index: Literal[True] = True, + sort: _bool = ..., + group_keys: _bool = ..., + observed: _bool | NoDefault = ..., + dropna: _bool = ..., + ) -> SeriesGroupBy: ... + @overload def groupby( self, by: DatetimeIndex, diff --git a/pandas-stubs/core/groupby/groupby.pyi b/pandas-stubs/core/groupby/groupby.pyi index 75be9578..5e942306 100644 --- a/pandas-stubs/core/groupby/groupby.pyi +++ b/pandas-stubs/core/groupby/groupby.pyi @@ -236,7 +236,7 @@ class GroupBy(BaseGroupBy[NDFrameT]): @overload def size(self: GroupBy[Series]) -> Series[int]: ... @overload # return type depends on `as_index` for dataframe groupby - def size(self: GroupBy[DataFrame]) -> DataFrame | Series[int]: ... + def size(self: GroupBy[DataFrame]) -> DataFrame: ... 
@final def sum( self, diff --git a/tests/test_frame.py b/tests/test_frame.py index 64198952..342ce50c 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -1025,6 +1025,24 @@ def test_types_pivot_table() -> None: ) +def test_types_groupby_as_index() -> None: + df = pd.DataFrame({"a": [1, 2, 3]}) + check( + assert_type( + df.groupby("a", as_index=False).size(), + pd.DataFrame, + ), + pd.DataFrame, + ) + check( + assert_type( + df.groupby("a", as_index=True).size(), + "pd.Series[int]", + ), + pd.Series, + ) + + def test_types_groupby() -> None: df = pd.DataFrame(data={"col1": [1, 1, 2], "col2": [3, 4, 5], "col3": [0, 1, 0]}) df.index.name = "ind" @@ -1048,7 +1066,7 @@ def test_types_groupby() -> None: df1: pd.DataFrame = df.groupby(by="col1").agg("sum") df2: pd.DataFrame = df.groupby(level="ind").aggregate("sum") - df3: pd.DataFrame = df.groupby(by="col1", sort=False, as_index=True).transform( + df3: pd.Series = df.groupby(by="col1", sort=False, as_index=True).transform( lambda x: x.max() ) df4: pd.DataFrame = df.groupby(by=["col1", "col2"]).count()