Skip to content

Commit

Permalink
update types
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Jul 12, 2023
1 parent b2586df commit 6e300a9
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 18 deletions.
2 changes: 1 addition & 1 deletion dataframe_api_compat/pandas_standard/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def convert_to_standard_compliant_dataframe(df: pd.DataFrame) -> PandasDataFrame
return PandasDataFrame(df)


def convert_to_standard_compliant_column(df: pd.Series) -> PandasColumn:
def convert_to_standard_compliant_column(df: pd.Series[Any]) -> PandasColumn[Any]:
return PandasColumn(df)


Expand Down
24 changes: 16 additions & 8 deletions dataframe_api_compat/pandas_standard/pandas_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def get_rows(self, indices: Column[Any]) -> PandasColumn[DType]:

def slice_rows(
self, start: int | None, stop: int | None, step: int | None
) -> PandasColumn:
) -> PandasColumn[DType]:
if start is None:
start = 0
if stop is None:
Expand All @@ -86,7 +86,7 @@ def slice_rows(
step = 1
return PandasColumn(self.column.iloc[start:stop:step])

def get_rows_by_mask(self, mask: Column[Bool]) -> PandasColumn:
def get_rows_by_mask(self, mask: Column[Bool]) -> PandasColumn[DType]:
series = mask.column
self._validate_index(series.index)
return PandasColumn(self.column.loc[series])
Expand Down Expand Up @@ -210,10 +210,10 @@ def median(self, *, skip_nulls: bool = True) -> Any:
def mean(self, *, skip_nulls: bool = True) -> Any:
return self.column.mean()

def std(self, *, skip_nulls: bool = True) -> Any:
def std(self, *, correction: int | float = 1.0, skip_nulls: bool = True) -> Any:
return self.column.std()

def var(self, *, skip_nulls: bool = True) -> Any:
def var(self, *, correction: int | float = 1.0, skip_nulls: bool = True) -> Any:
return self.column.var()

def is_null(self) -> PandasColumn[Bool]:
Expand Down Expand Up @@ -319,12 +319,16 @@ def mean(self, *, skip_nulls: bool = True) -> PandasDataFrame:
self._validate_result(result)
return PandasDataFrame(result)

def std(self, *, skip_nulls: bool = True) -> PandasDataFrame:
def std(
self, *, correction: int | float = 1.0, skip_nulls: bool = True
) -> PandasDataFrame:
result = self.grouped.std()
self._validate_result(result)
return PandasDataFrame(result)

def var(self, *, skip_nulls: bool = True) -> PandasDataFrame:
def var(
self, *, correction: int | float = 1.0, skip_nulls: bool = True
) -> PandasDataFrame:
result = self.grouped.var()
self._validate_result(result)
return PandasDataFrame(result)
Expand Down Expand Up @@ -595,10 +599,14 @@ def median(self, *, skip_nulls: bool = True) -> PandasDataFrame:
def mean(self, *, skip_nulls: bool = True) -> PandasDataFrame:
return PandasDataFrame(self.dataframe.mean().to_frame().T)

def std(self, *, skip_nulls: bool = True) -> PandasDataFrame:
def std(
self, *, correction: int | float = 1.0, skip_nulls: bool = True
) -> PandasDataFrame:
return PandasDataFrame(self.dataframe.std().to_frame().T)

def var(self, *, skip_nulls: bool = True) -> PandasDataFrame:
def var(
self, *, correction: int | float = 1.0, skip_nulls: bool = True
) -> PandasDataFrame:
return PandasDataFrame(self.dataframe.var().to_frame().T)

def is_null(self, *, skip_nulls: bool = True) -> PandasDataFrame:
Expand Down
2 changes: 1 addition & 1 deletion dataframe_api_compat/polars_standard/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,5 +72,5 @@ def convert_to_standard_compliant_dataframe(df: pl.DataFrame) -> PolarsDataFrame
return PolarsDataFrame(df)


def convert_to_standard_compliant_column(ser: pl.Series) -> PolarsColumn:
def convert_to_standard_compliant_column(ser: pl.Series) -> PolarsColumn[Any]:
return PolarsColumn(ser)
24 changes: 16 additions & 8 deletions dataframe_api_compat/polars_standard/polars_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def get_rows(self, indices: Column[Any]) -> PolarsColumn[DType]:

def slice_rows(
self, start: int | None, stop: int | None, step: int | None
) -> PolarsColumn:
) -> PolarsColumn[DType]:
if start is None:
start = 0
if stop is None:
Expand All @@ -71,7 +71,7 @@ def slice_rows(
step = 1
return PolarsColumn(self.column[start:stop:step])

def get_rows_by_mask(self, mask: Column[Bool]) -> PolarsDataFrame:
def get_rows_by_mask(self, mask: Column[Bool]) -> PolarsColumn[DType]:
name = self.column.name
return PolarsColumn(self.column.to_frame().filter(mask.column)[name])

Expand Down Expand Up @@ -121,10 +121,10 @@ def mean(self, *, skip_nulls: bool = True) -> Any:
def median(self, *, skip_nulls: bool = True) -> Any:
return self.column.median()

def std(self, *, skip_nulls: bool = True) -> Any:
def std(self, *, correction: int | float = 1.0, skip_nulls: bool = True) -> Any:
return self.column.std()

def var(self, *, skip_nulls: bool = True) -> Any:
def var(self, *, correction: int | float = 1.0, skip_nulls: bool = True) -> Any:
return self.column.var()

def __eq__( # type: ignore[override]
Expand Down Expand Up @@ -272,11 +272,15 @@ def mean(self, skip_nulls: bool = True) -> PolarsDataFrame:
result = self.df.groupby(self.keys).agg(pl.col("*").mean())
return PolarsDataFrame(result)

def std(self, skip_nulls: bool = True) -> PolarsDataFrame:
def std(
self, correction: int | float = 1.0, skip_nulls: bool = True
) -> PolarsDataFrame:
result = self.df.groupby(self.keys).agg(pl.col("*").std())
return PolarsDataFrame(result)

def var(self, skip_nulls: bool = True) -> PolarsDataFrame:
def var(
self, correction: int | float = 1.0, skip_nulls: bool = True
) -> PolarsDataFrame:
result = self.df.groupby(self.keys).agg(pl.col("*").var())
return PolarsDataFrame(result)

Expand Down Expand Up @@ -486,10 +490,14 @@ def mean(self, *, skip_nulls: bool = True) -> PolarsDataFrame:
def median(self, *, skip_nulls: bool = True) -> PolarsDataFrame:
return PolarsDataFrame(self.dataframe.select(pl.col("*").median()))

def std(self, *, skip_nulls: bool = True) -> PolarsDataFrame:
def std(
self, *, correction: int | float = 1.0, skip_nulls: bool = True
) -> PolarsDataFrame:
return PolarsDataFrame(self.dataframe.select(pl.col("*").std()))

def var(self, *, skip_nulls: bool = True) -> PolarsDataFrame:
def var(
self, *, correction: int | float = 1.0, skip_nulls: bool = True
) -> PolarsDataFrame:
return PolarsDataFrame(self.dataframe.select(pl.col("*").var()))

def sorted_indices(
Expand Down

0 comments on commit 6e300a9

Please sign in to comment.