Skip to content

Commit

Permalink
to_array_object
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Aug 1, 2023
1 parent c0cdcca commit 8e9cd9f
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 0 deletions.
Binary file modified .coverage
Binary file not shown.
30 changes: 30 additions & 0 deletions dataframe_api_compat/pandas_standard/pandas_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,22 @@

DType = TypeVar("DType")

_ARRAY_API_DTYPES = frozenset(
(
"bool",
"int8",
"int16",
"int32",
"int64",
"uint8",
"uint16",
"uint32",
"uint64",
"float32",
"float64",
)
)

if TYPE_CHECKING:
from dataframe_api import (
DataFrame,
Expand Down Expand Up @@ -278,6 +294,13 @@ def cumulative_max(self, *, skip_nulls: bool = True) -> PandasColumn[DType]:
def cumulative_min(self, *, skip_nulls: bool = True) -> PandasColumn[DType]:
return PandasColumn(self.column.cummin())

def to_array_object(self, dtype: str):
if dtype not in _ARRAY_API_DTYPES:
raise ValueError(
f"Invalid dtype {dtype}. Expected one of {_ARRAY_API_DTYPES}"
)
return self.column.to_numpy(dtype=dtype)


class PandasGroupBy(GroupBy):
def __init__(self, df: pd.DataFrame, keys: Sequence[str]) -> None:
Expand Down Expand Up @@ -707,3 +730,10 @@ def fill_null(
) # type: ignore[assignment]
df[column] = col
return PandasDataFrame(df)

def to_array_object(self, dtype: str) -> np.ndarray:
if dtype not in _ARRAY_API_DTYPES:
raise ValueError(
f"Invalid dtype {dtype}. Expected one of {_ARRAY_API_DTYPES}"
)
return self.dataframe.to_numpy(dtype=dtype)
30 changes: 30 additions & 0 deletions dataframe_api_compat/polars_standard/polars_standard.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
import numpy as np
import dataframe_api_compat.polars_standard
import collections

Expand All @@ -14,6 +15,21 @@
)
import polars as pl

_ARRAY_API_DTYPES = frozenset(
(
"bool",
"int8",
"int16",
"int32",
"int64",
"uint8",
"uint16",
"uint32",
"uint64",
"float32",
"float64",
)
)
DType = TypeVar("DType")

if TYPE_CHECKING:
Expand Down Expand Up @@ -282,6 +298,13 @@ def cumulative_max(self, *, skip_nulls: bool = True) -> PolarsColumn[DType]:
def cumulative_min(self, *, skip_nulls: bool = True) -> PolarsColumn[DType]:
return PolarsColumn(self.column.cummin())

def to_array_object(self, dtype: str) -> np.ndarray:
if dtype not in _ARRAY_API_DTYPES:
raise ValueError(
f"Invalid dtype {dtype}. Expected one of {_ARRAY_API_DTYPES}"
)
return self.column.to_numpy().astype(dtype)


class PolarsGroupBy(GroupBy):
def __init__(self, df: pl.DataFrame, keys: Sequence[str]) -> None:
Expand Down Expand Up @@ -639,3 +662,10 @@ def fill_null(
pl.col(col).fill_null(value) for col in column_names
)
return PolarsDataFrame(df)

def to_array_object(self, dtype: str) -> np.ndarray:
if dtype not in _ARRAY_API_DTYPES:
raise ValueError(
f"Invalid dtype {dtype}. Expected one of {_ARRAY_API_DTYPES}"
)
return self.dataframe.to_numpy().astype(dtype)

0 comments on commit 8e9cd9f

Please sign in to comment.