Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Explicitly implement Protocol for interchange classes #10688

Merged
merged 1 commit into from
Aug 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions py-polars/polars/interchange/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import TYPE_CHECKING

from polars.interchange.protocol import DlpackDeviceType, DtypeKind
from polars.interchange.protocol import Buffer, DlpackDeviceType, DtypeKind
from polars.interchange.utils import polars_dtype_to_dtype

if TYPE_CHECKING:
Expand All @@ -11,7 +11,7 @@
from polars import Series


class PolarsBuffer:
class PolarsBuffer(Buffer):
"""
A buffer object backed by a Polars Series consisting of a single chunk.

Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/interchange/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from polars.datatypes import Categorical
from polars.interchange.buffer import PolarsBuffer
from polars.interchange.protocol import ColumnNullType, DtypeKind, Endianness
from polars.interchange.protocol import Column, ColumnNullType, DtypeKind, Endianness
from polars.interchange.utils import polars_dtype_to_dtype
from polars.utils._wrap import wrap_s

Expand All @@ -16,7 +16,7 @@
from polars.interchange.protocol import CategoricalDescription, ColumnBuffers, Dtype


class PolarsColumn:
class PolarsColumn(Column):
"""
A column object backed by a Polars Series.

Expand Down
7 changes: 5 additions & 2 deletions py-polars/polars/interchange/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import TYPE_CHECKING

from polars.interchange.column import PolarsColumn
from polars.interchange.protocol import DataFrame as InterchangeDataFrame

if TYPE_CHECKING:
from collections.abc import Iterator
Expand All @@ -13,7 +14,7 @@
from polars import DataFrame


class PolarsDataFrame:
class PolarsDataFrame(InterchangeDataFrame):
"""
A dataframe object backed by a Polars DataFrame.

Expand All @@ -27,6 +28,8 @@ class PolarsDataFrame:

"""

version = 0

def __init__(self, df: DataFrame, *, allow_copy: bool = True):
self._df = df
self._allow_copy = allow_copy
Expand Down Expand Up @@ -124,7 +127,7 @@ def get_columns(self) -> Iterator[PolarsColumn]:

def select_columns(self, indices: Sequence[int]) -> PolarsDataFrame:
"""
Create a new DataFrame by selecting a subset of columns by index.
Create a new dataframe by selecting a subset of columns by index.

Parameters
----------
Expand Down
142 changes: 128 additions & 14 deletions py-polars/polars/interchange/protocol.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
from __future__ import annotations

from enum import IntEnum
from typing import TYPE_CHECKING, Literal, Tuple, TypedDict
from typing import (
TYPE_CHECKING,
Any,
Iterable,
Literal,
Protocol,
Sequence,
Tuple,
TypedDict,
)

if TYPE_CHECKING:
import sys
Expand All @@ -15,6 +24,19 @@
from typing_extensions import TypeAlias


class DlpackDeviceType(IntEnum):
    """
    Integer enum for device type codes matching DLPack.

    Members compare equal to their integer code because this is an ``IntEnum``.
    """

    CPU = 1
    CUDA = 2
    CPU_PINNED = 3
    OPENCL = 4
    # Codes 5 and 6 are not defined by this enum; the numbering resumes at 7.
    VULKAN = 7
    METAL = 8
    VPI = 9
    ROCM = 10


class DtypeKind(IntEnum):
"""
Integer enum for data types.
Expand Down Expand Up @@ -105,23 +127,115 @@ class CategoricalDescription(TypedDict):
categories: PolarsColumn


# NOTE(review): this repeats the DlpackDeviceType definition that appears
# earlier in this listing; it looks like the pre-move copy shown by the diff
# view — confirm before treating it as live code.
class DlpackDeviceType(IntEnum):
    """Integer enum for device type codes matching DLPack."""

    CPU = 1
    CUDA = 2
    CPU_PINNED = 3
    OPENCL = 4
    # Codes 5 and 6 are not defined here; the numbering resumes at 7.
    VULKAN = 7
    METAL = 8
    VPI = 9
    ROCM = 10


class Endianness:
    """
    Byte-order markers for a data type.

    This is a plain class used as a namespace of string constants (it does not
    derive from ``enum.Enum``), so each attribute is an ordinary ``str``.
    """

    LITTLE = "<"  # little-endian
    BIG = ">"  # big-endian
    NATIVE = "="  # native byte order
    NA = "|"  # presumably "not applicable" -- confirm against the protocol spec


class Buffer(Protocol):
    """
    Interchange buffer object.

    Structural type listing the members a buffer must expose. Every member is
    a stub: the bodies below carry no implementation.
    """

    @property
    def bufsize(self) -> int:
        """Buffer size in bytes."""
        ...

    @property
    def ptr(self) -> int:
        """Pointer to start of the buffer as an integer."""
        ...

    def __dlpack__(self) -> Any:
        """Represent this structure as DLPack interface."""
        ...

    def __dlpack_device__(self) -> tuple[DlpackDeviceType, int | None]:
        """Device type and device ID for where the data in the buffer resides."""
        ...


class Column(Protocol):
    """
    Interchange column object.

    Structural type listing the members a column must expose. Every member is
    a stub: the bodies below carry no implementation. Note that ``size``,
    ``num_chunks``, ``get_chunks`` and ``get_buffers`` are methods, while the
    remaining members are read-only properties.
    """

    def size(self) -> int:
        """Size of the column in elements."""
        ...

    @property
    def offset(self) -> int:
        """
        Offset of the first element.

        The offset is given with respect to the start of the underlying buffer.
        """
        ...

    @property
    def dtype(self) -> Dtype:
        """Data type of the column."""
        ...

    @property
    def describe_categorical(self) -> CategoricalDescription:
        """Description of the categorical data type of the column."""
        ...

    @property
    def describe_null(self) -> tuple[ColumnNullType, Any]:
        """Description of the null representation the column uses."""
        ...

    @property
    def null_count(self) -> int | None:
        """Number of null elements, if known."""
        ...

    @property
    def metadata(self) -> dict[str, Any]:
        """The metadata for the column."""
        ...

    def num_chunks(self) -> int:
        """Return the number of chunks the column consists of."""
        ...

    def get_chunks(self, n_chunks: int | None = None) -> Iterable[Column]:
        """Return an iterator yielding the column chunks."""
        ...

    def get_buffers(self) -> ColumnBuffers:
        """Return a dictionary containing the underlying buffers."""
        ...


class DataFrame(Protocol):
    """
    Interchange dataframe object.

    Structural type listing the members a dataframe must expose. Every member
    is a stub: the bodies below carry no implementation.
    """

    @property
    def version(self) -> int:
        """Version of the protocol."""

    def __dataframe__(
        self, nan_as_null: bool = False, allow_copy: bool = True
    ) -> DataFrame:
        """Construct a new dataframe object, potentially changing the parameters."""

    @property
    def metadata(self) -> dict[str, Any]:
        """The metadata for the dataframe."""

    def num_columns(self) -> int:
        """Return the number of columns in the dataframe."""

    def num_rows(self) -> int | None:
        """Return the number of rows in the dataframe, if available."""

    def num_chunks(self) -> int:
        """Return the number of chunks the dataframe consists of."""

    def column_names(self) -> Iterable[str]:
        """Return the column names."""

    def get_column(self, i: int) -> Column:
        """Return the column at the indicated position."""

    def get_column_by_name(self, name: str) -> Column:
        """Return the column with the given name."""

    def get_columns(self) -> Iterable[Column]:
        """Return an iterator yielding the columns."""

    def select_columns(self, indices: Sequence[int]) -> DataFrame:
        """Create a new dataframe by selecting a subset of columns by index."""

    def select_columns_by_name(self, names: Sequence[str]) -> DataFrame:
        """Create a new dataframe by selecting a subset of columns by name."""

    def get_chunks(self, n_chunks: int | None = None) -> Iterable[DataFrame]:
        """Return an iterator yielding the chunks of the dataframe."""
Loading