From 64be3180b407a994f5906faf6ab3233f0921ab26 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Wed, 23 Aug 2023 15:29:52 +0200 Subject: [PATCH] Explicitly implement protocol --- py-polars/polars/interchange/buffer.py | 4 +- py-polars/polars/interchange/column.py | 4 +- py-polars/polars/interchange/dataframe.py | 7 +- py-polars/polars/interchange/protocol.py | 142 +++++++++++++++++++--- 4 files changed, 137 insertions(+), 20 deletions(-) diff --git a/py-polars/polars/interchange/buffer.py b/py-polars/polars/interchange/buffer.py index 5ee3b55d7db6..46c6bf12dc8f 100644 --- a/py-polars/polars/interchange/buffer.py +++ b/py-polars/polars/interchange/buffer.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING -from polars.interchange.protocol import DlpackDeviceType, DtypeKind +from polars.interchange.protocol import Buffer, DlpackDeviceType, DtypeKind from polars.interchange.utils import polars_dtype_to_dtype if TYPE_CHECKING: @@ -11,7 +11,7 @@ from polars import Series -class PolarsBuffer: +class PolarsBuffer(Buffer): """ A buffer object backed by a Polars Series consisting of a single chunk. diff --git a/py-polars/polars/interchange/column.py b/py-polars/polars/interchange/column.py index 8cf81b0b33c8..c7a945b1977f 100644 --- a/py-polars/polars/interchange/column.py +++ b/py-polars/polars/interchange/column.py @@ -4,7 +4,7 @@ from polars.datatypes import Categorical from polars.interchange.buffer import PolarsBuffer -from polars.interchange.protocol import ColumnNullType, DtypeKind, Endianness +from polars.interchange.protocol import Column, ColumnNullType, DtypeKind, Endianness from polars.interchange.utils import polars_dtype_to_dtype from polars.utils._wrap import wrap_s @@ -16,7 +16,7 @@ from polars.interchange.protocol import CategoricalDescription, ColumnBuffers, Dtype -class PolarsColumn: +class PolarsColumn(Column): """ A column object backed by a Polars Series. 
diff --git a/py-polars/polars/interchange/dataframe.py b/py-polars/polars/interchange/dataframe.py index 2d43a1353901..56ed4337d6f0 100644 --- a/py-polars/polars/interchange/dataframe.py +++ b/py-polars/polars/interchange/dataframe.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING from polars.interchange.column import PolarsColumn +from polars.interchange.protocol import DataFrame as InterchangeDataFrame if TYPE_CHECKING: from collections.abc import Iterator @@ -13,7 +14,7 @@ from polars import DataFrame -class PolarsDataFrame: +class PolarsDataFrame(InterchangeDataFrame): """ A dataframe object backed by a Polars DataFrame. @@ -27,6 +28,8 @@ class PolarsDataFrame: """ + version = 0 + def __init__(self, df: DataFrame, *, allow_copy: bool = True): self._df = df self._allow_copy = allow_copy @@ -124,7 +127,7 @@ def get_columns(self) -> Iterator[PolarsColumn]: def select_columns(self, indices: Sequence[int]) -> PolarsDataFrame: """ - Create a new DataFrame by selecting a subset of columns by index. + Create a new dataframe by selecting a subset of columns by index. Parameters ---------- diff --git a/py-polars/polars/interchange/protocol.py b/py-polars/polars/interchange/protocol.py index de51804c2f63..4d7a85bfe83b 100644 --- a/py-polars/polars/interchange/protocol.py +++ b/py-polars/polars/interchange/protocol.py @@ -1,7 +1,16 @@ from __future__ import annotations from enum import IntEnum -from typing import TYPE_CHECKING, Literal, Tuple, TypedDict +from typing import ( + TYPE_CHECKING, + Any, + Iterable, + Literal, + Protocol, + Sequence, + Tuple, + TypedDict, +) if TYPE_CHECKING: import sys @@ -15,6 +24,19 @@ from typing_extensions import TypeAlias +class DlpackDeviceType(IntEnum): + """Integer enum for device type codes matching DLPack.""" + + CPU = 1 + CUDA = 2 + CPU_PINNED = 3 + OPENCL = 4 + VULKAN = 7 + METAL = 8 + VPI = 9 + ROCM = 10 + + class DtypeKind(IntEnum): """ Integer enum for data types. 
@@ -105,19 +127,6 @@ class CategoricalDescription(TypedDict): categories: PolarsColumn -class DlpackDeviceType(IntEnum): - """Integer enum for device type codes matching DLPack.""" - - CPU = 1 - CUDA = 2 - CPU_PINNED = 3 - OPENCL = 4 - VULKAN = 7 - METAL = 8 - VPI = 9 - ROCM = 10 - - class Endianness: """Enum indicating the byte-order of a data type.""" @@ -125,3 +134,108 @@ class Endianness: BIG = ">" NATIVE = "=" NA = "|" + + +class Buffer(Protocol): + """Interchange buffer object.""" + + @property + def bufsize(self) -> int: + """Buffer size in bytes.""" + + @property + def ptr(self) -> int: + """Pointer to start of the buffer as an integer.""" + + def __dlpack__(self) -> Any: + """Represent this structure as DLPack interface.""" + + def __dlpack_device__(self) -> tuple[DlpackDeviceType, int | None]: + """Device type and device ID for where the data in the buffer resides.""" + + +class Column(Protocol): + """Interchange column object.""" + + def size(self) -> int: + """Size of the column in elements.""" + + @property + def offset(self) -> int: + """Offset of the first element with respect to the start of the underlying buffer.""" # noqa: W505 + + @property + def dtype(self) -> Dtype: + """Data type of the column.""" + + @property + def describe_categorical(self) -> CategoricalDescription: + """Description of the categorical data type of the column.""" + + @property + def describe_null(self) -> tuple[ColumnNullType, Any]: + """Description of the null representation the column uses.""" + + @property + def null_count(self) -> int | None: + """Number of null elements, if known.""" + + @property + def metadata(self) -> dict[str, Any]: + """The metadata for the column.""" + + def num_chunks(self) -> int: + """Return the number of chunks the column consists of.""" + + def get_chunks(self, n_chunks: int | None = None) -> Iterable[Column]: + """Return an iterator yielding the column chunks.""" + + def get_buffers(self) -> ColumnBuffers: + """Return a dictionary 
containing the underlying buffers.""" + + +class DataFrame(Protocol): + """Interchange dataframe object.""" + + @property + def version(self) -> int: + """Version of the protocol.""" + + def __dataframe__( + self, nan_as_null: bool = False, allow_copy: bool = True + ) -> DataFrame: + """Construct a new dataframe object, potentially changing the parameters.""" + + @property + def metadata(self) -> dict[str, Any]: + """The metadata for the dataframe.""" + + def num_columns(self) -> int: + """Return the number of columns in the dataframe.""" + + def num_rows(self) -> int | None: + """Return the number of rows in the dataframe, if available.""" + + def num_chunks(self) -> int: + """Return the number of chunks the dataframe consists of.""" + + def column_names(self) -> Iterable[str]: + """Return the column names.""" + + def get_column(self, i: int) -> Column: + """Return the column at the indicated position.""" + + def get_column_by_name(self, name: str) -> Column: + """Return the column with the given name.""" + + def get_columns(self) -> Iterable[Column]: + """Return an iterator yielding the columns.""" + + def select_columns(self, indices: Sequence[int]) -> DataFrame: + """Create a new dataframe by selecting a subset of columns by index.""" + + def select_columns_by_name(self, names: Sequence[str]) -> DataFrame: + """Create a new dataframe by selecting a subset of columns by name.""" + + def get_chunks(self, n_chunks: int | None = None) -> Iterable[DataFrame]: + """Return an iterator yielding the chunks of the dataframe."""