Skip to content

Commit

Permalink
GH203 Split groupby with as_index
Browse files Browse the repository at this point in the history
  • Loading branch information
loicdiridollou committed Oct 13, 2024
1 parent fecd8e9 commit d4ea91e
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 5 deletions.
21 changes: 18 additions & 3 deletions pandas-stubs/core/frame.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ from pandas import (
)
from pandas.core.arraylike import OpsMixin
from pandas.core.generic import NDFrame
from pandas.core.groupby.generic import DataFrameGroupBy
from pandas.core.groupby.generic import (
DataFrameGroupBy,
SeriesGroupBy,
)
from pandas.core.groupby.grouper import Grouper
from pandas.core.indexers import BaseIndexer
from pandas.core.indexes.base import Index
Expand Down Expand Up @@ -1052,18 +1055,30 @@ class DataFrame(NDFrame, OpsMixin):
errors: IgnoreRaise = ...,
) -> None: ...
@overload
def groupby(
def groupby( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
self,
by: Scalar,
axis: AxisIndex | NoDefault = ...,
level: IndexLabel | None = ...,
as_index: _bool = ...,
as_index: Literal[False] = ...,
sort: _bool = ...,
group_keys: _bool = ...,
observed: _bool | NoDefault = ...,
dropna: _bool = ...,
) -> DataFrameGroupBy[Scalar]: ...
@overload
def groupby(
self,
by: Scalar,
axis: AxisIndex | NoDefault = ...,
level: IndexLabel | None = ...,
as_index: Literal[True] = True,
sort: _bool = ...,
group_keys: _bool = ...,
observed: _bool | NoDefault = ...,
dropna: _bool = ...,
) -> SeriesGroupBy: ...
@overload
def groupby(
self,
by: DatetimeIndex,
Expand Down
2 changes: 1 addition & 1 deletion pandas-stubs/core/groupby/groupby.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ class GroupBy(BaseGroupBy[NDFrameT]):
@overload
def size(self: GroupBy[Series]) -> Series[int]: ...
@overload # return type depends on `as_index` for dataframe groupby
def size(self: GroupBy[DataFrame]) -> DataFrame | Series[int]: ...
def size(self: GroupBy[DataFrame]) -> DataFrame: ...
@final
def sum(
self,
Expand Down
20 changes: 19 additions & 1 deletion tests/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,24 @@ def test_types_pivot_table() -> None:
)


def test_types_groupby_as_index() -> None:
df = pd.DataFrame({"a": [1, 2, 3]})
check(
assert_type(
df.groupby("a", as_index=False).size(),
pd.DataFrame,
),
pd.DataFrame,
)
check(
assert_type(
df.groupby("a", as_index=True).size(),
"pd.Series[int]",
),
pd.Series,
)


def test_types_groupby() -> None:
df = pd.DataFrame(data={"col1": [1, 1, 2], "col2": [3, 4, 5], "col3": [0, 1, 0]})
df.index.name = "ind"
Expand All @@ -1048,7 +1066,7 @@ def test_types_groupby() -> None:

df1: pd.DataFrame = df.groupby(by="col1").agg("sum")
df2: pd.DataFrame = df.groupby(level="ind").aggregate("sum")
df3: pd.DataFrame = df.groupby(by="col1", sort=False, as_index=True).transform(
df3: pd.Series = df.groupby(by="col1", sort=False, as_index=True).transform(
lambda x: x.max()
)
df4: pd.DataFrame = df.groupby(by=["col1", "col2"]).count()
Expand Down

0 comments on commit d4ea91e

Please sign in to comment.