Skip to content

Commit

Permalink
Explicitly implement protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Aug 23, 2023
1 parent 67d5328 commit 7389baf
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 20 deletions.
4 changes: 2 additions & 2 deletions py-polars/polars/interchange/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import TYPE_CHECKING

from polars.interchange.protocol import DlpackDeviceType, DtypeKind
from polars.interchange.protocol import BufferObject, DlpackDeviceType, DtypeKind
from polars.interchange.utils import polars_dtype_to_dtype

if TYPE_CHECKING:
Expand All @@ -11,7 +11,7 @@
from polars import Series


class PolarsBuffer:
class PolarsBuffer(BufferObject):
"""
A buffer object backed by a Polars Series consisting of a single chunk.
Expand Down
9 changes: 7 additions & 2 deletions py-polars/polars/interchange/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@

from polars.datatypes import Categorical
from polars.interchange.buffer import PolarsBuffer
from polars.interchange.protocol import ColumnNullType, DtypeKind, Endianness
from polars.interchange.protocol import (
ColumnNullType,
ColumnObject,
DtypeKind,
Endianness,
)
from polars.interchange.utils import polars_dtype_to_dtype
from polars.utils._wrap import wrap_s

Expand All @@ -16,7 +21,7 @@
from polars.interchange.protocol import CategoricalDescription, ColumnBuffers, Dtype


class PolarsColumn:
class PolarsColumn(ColumnObject):
"""
A column object backed by a Polars Series.
Expand Down
7 changes: 5 additions & 2 deletions py-polars/polars/interchange/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import TYPE_CHECKING

from polars.interchange.column import PolarsColumn
from polars.interchange.protocol import DataFrameObject

if TYPE_CHECKING:
from collections.abc import Iterator
Expand All @@ -13,7 +14,7 @@
from polars import DataFrame


class PolarsDataFrame:
class PolarsDataFrame(DataFrameObject):
"""
A dataframe object backed by a Polars DataFrame.
Expand All @@ -27,6 +28,8 @@ class PolarsDataFrame:
"""

version = 0

def __init__(self, df: DataFrame, *, allow_copy: bool = True):
self._df = df
self._allow_copy = allow_copy
Expand Down Expand Up @@ -124,7 +127,7 @@ def get_columns(self) -> Iterator[PolarsColumn]:

def select_columns(self, indices: Sequence[int]) -> PolarsDataFrame:
"""
Create a new DataFrame by selecting a subset of columns by index.
Create a new dataframe by selecting a subset of columns by index.
Parameters
----------
Expand Down
142 changes: 128 additions & 14 deletions py-polars/polars/interchange/protocol.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
from __future__ import annotations

from enum import IntEnum
from typing import TYPE_CHECKING, Literal, Tuple, TypedDict
from typing import (
TYPE_CHECKING,
Any,
Iterable,
Literal,
Protocol,
Sequence,
Tuple,
TypedDict,
)

if TYPE_CHECKING:
import sys
Expand All @@ -15,6 +24,19 @@
from typing_extensions import TypeAlias


class DlpackDeviceType(IntEnum):
"""Integer enum for device type codes matching DLPack."""

CPU = 1
CUDA = 2
CPU_PINNED = 3
OPENCL = 4
VULKAN = 7
METAL = 8
VPI = 9
ROCM = 10


class DtypeKind(IntEnum):
"""
Integer enum for data types.
Expand Down Expand Up @@ -105,23 +127,115 @@ class CategoricalDescription(TypedDict):
categories: PolarsColumn


class DlpackDeviceType(IntEnum):
"""Integer enum for device type codes matching DLPack."""

CPU = 1
CUDA = 2
CPU_PINNED = 3
OPENCL = 4
VULKAN = 7
METAL = 8
VPI = 9
ROCM = 10


class Endianness:
"""Enum indicating the byte-order of a data type."""

LITTLE = "<"
BIG = ">"
NATIVE = "="
NA = "|"


class BufferObject(Protocol):
"""Interchange buffer object."""

@property
def bufsize(self) -> int:
"""Buffer size in bytes."""

@property
def ptr(self) -> int:
"""Pointer to start of the buffer as an integer."""

def __dlpack__(self) -> Any:
"""Represent this structure as DLPack interface."""

def __dlpack_device__(self) -> tuple[DlpackDeviceType, int | None]:
"""Device type and device ID for where the data in the buffer resides."""


class ColumnObject(Protocol):
"""Interchange column object."""

def size(self) -> int:
"""Size of the column in elements."""

@property
def offset(self) -> int:
"""Offset of the first element with respect to the start of the underlying buffer.""" # noqa: W505

@property
def dtype(self) -> Dtype:
"""Data type of the column."""

@property
def describe_categorical(self) -> CategoricalDescription:
"""Description of the categorical data type of the column."""

@property
def describe_null(self) -> tuple[ColumnNullType, Any]:
"""Description of the null representation the column uses."""

@property
def null_count(self) -> int | None:
"""Number of null elements, if known."""

@property
def metadata(self) -> dict[str, Any]:
"""The metadata for the column."""

def num_chunks(self) -> int:
"""Return the number of chunks the column consists of."""

def get_chunks(self, n_chunks: int | None = None) -> Iterable[ColumnObject]:
"""Return an iterator yielding the column chunks."""

def get_buffers(self) -> ColumnBuffers:
"""Return a dictionary containing the underlying buffers."""


class DataFrameObject(Protocol):
"""Interchange dataframe object."""

@property
def version(self) -> int:
"""Version of the protocol."""

def __dataframe__(
self, nan_as_null: bool = False, allow_copy: bool = True
) -> DataFrameObject:
"""Construct a new dataframe object, potentially changing the parameters."""

@property
def metadata(self) -> dict[str, Any]:
"""The metadata for the dataframe."""

def num_columns(self) -> int:
"""Return the number of columns in the dataframe."""

def num_rows(self) -> int | None:
"""Return the number of rows in the dataframe, if available."""

def num_chunks(self) -> int:
"""Return the number of chunks the dataframe consists of.."""

def column_names(self) -> Iterable[str]:
"""Return the column names."""

def get_column(self, i: int) -> ColumnObject:
"""Return the column at the indicated position."""

def get_column_by_name(self, name: str) -> ColumnObject:
"""Return the column with the given name."""

def get_columns(self) -> Iterable[ColumnObject]:
"""Return an iterator yielding the columns."""

def select_columns(self, indices: Sequence[int]) -> DataFrameObject:
"""Create a new dataframe by selecting a subset of columns by index."""

def select_columns_by_name(self, names: Sequence[str]) -> DataFrameObject:
"""Create a new dataframe by selecting a subset of columns by name."""

def get_chunks(self, n_chunks: int | None = None) -> Iterable[DataFrameObject]:
"""Return an iterator yielding the chunks of the dataframe."""

0 comments on commit 7389baf

Please sign in to comment.