Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make StructColumn.__init__ strict #16467

Merged
merged 6 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -1635,22 +1635,23 @@ def build_column(
)
elif isinstance(dtype, IntervalDtype):
return cudf.core.column.IntervalColumn(
data=None,
size=size, # type: ignore[arg-type]
vyasr marked this conversation as resolved.
Show resolved Hide resolved
dtype=dtype,
mask=mask,
size=size,
offset=offset,
children=children,
null_count=null_count,
children=children, # type: ignore[arg-type]
)
elif isinstance(dtype, StructDtype):
return cudf.core.column.StructColumn(
data=data,
dtype=dtype,
data=None,
size=size, # type: ignore[arg-type]
offset=offset,
dtype=dtype,
mask=mask,
offset=offset,
null_count=null_count,
children=children,
children=children, # type: ignore[arg-type]
)
elif isinstance(dtype, cudf.Decimal64Dtype):
return cudf.core.column.Decimal64Column(
Expand Down
71 changes: 47 additions & 24 deletions python/cudf/cudf/core/column/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,46 @@
from cudf.core.dtypes import IntervalDtype

if TYPE_CHECKING:
from typing_extensions import Self

from cudf._typing import ScalarLike
from cudf.core.buffer import Buffer
from cudf.core.column import ColumnBase


class IntervalColumn(StructColumn):
def __init__(
self,
dtype,
mask=None,
size=None,
offset=0,
null_count=None,
children=(),
data: None,
size: int,
dtype: IntervalDtype,
mask: Buffer | None = None,
offset: int = 0,
null_count: int | None = None,
children: tuple[ColumnBase, ColumnBase] = (), # type: ignore[assignment]
):
if len(children) != 2:
raise ValueError(
"children must be a tuple of two columns (left edges, right edges)."
)
super().__init__(
data=None,
data=data,
size=size,
dtype=dtype,
mask=mask,
size=size,
offset=offset,
null_count=null_count,
children=children,
)

@staticmethod
def _validate_dtype_instance(dtype: IntervalDtype) -> IntervalDtype:
    """
    Check that ``dtype`` is an ``IntervalDtype`` and hand it back unchanged.

    Parameters
    ----------
    dtype : IntervalDtype
        The dtype object to check.

    Returns
    -------
    IntervalDtype
        The same object that was passed in.

    Raises
    ------
    ValueError
        If ``dtype`` is not an instance of ``IntervalDtype``.
    """
    if isinstance(dtype, IntervalDtype):
        return dtype
    raise ValueError("dtype must be a IntervalDtype.")

@classmethod
def from_arrow(cls, data):
def from_arrow(cls, data: pa.Array) -> Self:
new_col = super().from_arrow(data.storage)
size = len(data)
dtype = IntervalDtype.from_arrow(data.type)
Expand All @@ -48,16 +62,17 @@ def from_arrow(cls, data):
null_count = data.null_count
children = new_col.children

return IntervalColumn(
return cls(
data=None,
size=size,
dtype=dtype,
mask=mask,
offset=offset,
null_count=null_count,
children=children,
children=children, # type: ignore[arg-type]
)

def to_arrow(self):
def to_arrow(self) -> pa.Array:
typ = self.dtype.to_arrow()
struct_arrow = super().to_arrow()
if len(struct_arrow) == 0:
Expand All @@ -67,30 +82,36 @@ def to_arrow(self):
return pa.ExtensionArray.from_storage(typ, struct_arrow)

@classmethod
def from_struct_column(cls, struct_column: StructColumn, closed="right"):
def from_struct_column(
cls,
struct_column: StructColumn,
closed: Literal["left", "right", "both", "neither"] = "right",
) -> Self:
first_field_name = next(iter(struct_column.dtype.fields.keys()))
return IntervalColumn(
return cls(
data=None,
size=struct_column.size,
dtype=IntervalDtype(
struct_column.dtype.fields[first_field_name], closed
),
mask=struct_column.base_mask,
offset=struct_column.offset,
null_count=struct_column.null_count,
children=struct_column.base_children,
children=struct_column.base_children, # type: ignore[arg-type]
)

def copy(self, deep=True):
def copy(self, deep: bool = True) -> Self:
struct_copy = super().copy(deep=deep)
return IntervalColumn(
return IntervalColumn( # type: ignore[return-value]
data=None,
size=struct_copy.size,
dtype=IntervalDtype(
struct_copy.dtype.fields["left"], self.dtype.closed
),
mask=struct_copy.base_mask,
offset=struct_copy.offset,
null_count=struct_copy.null_count,
children=struct_copy.base_children,
children=struct_copy.base_children, # type: ignore[arg-type]
)

@property
Expand Down Expand Up @@ -138,25 +159,27 @@ def overlaps(other) -> ColumnBase:

def set_closed(
self, closed: Literal["left", "right", "both", "neither"]
) -> IntervalColumn:
return IntervalColumn(
) -> Self:
return IntervalColumn( # type: ignore[return-value]
data=None,
size=self.size,
dtype=IntervalDtype(self.dtype.fields["left"], closed),
mask=self.base_mask,
offset=self.offset,
null_count=self.null_count,
children=self.base_children,
children=self.base_children, # type: ignore[arg-type]
)

def as_interval_column(self, dtype):
def as_interval_column(self, dtype: IntervalDtype) -> Self: # type: ignore[override]
if isinstance(dtype, IntervalDtype):
return IntervalColumn(
return IntervalColumn( # type: ignore[return-value]
data=None,
size=self.size,
dtype=dtype,
mask=self.mask,
offset=self.offset,
null_count=self.null_count,
children=tuple(
children=tuple( # type: ignore[arg-type]
child.astype(dtype.subtype) for child in self.children
),
)
Expand Down
50 changes: 41 additions & 9 deletions python/cudf/cudf/core/column/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
from cudf.core.missing import NA

if TYPE_CHECKING:
from typing_extensions import Self

from cudf._typing import Dtype
from cudf.core.buffer import Buffer


class StructColumn(ColumnBase):
Expand All @@ -23,10 +26,39 @@ class StructColumn(ColumnBase):

Every column has n children, where n is
the number of fields in the Struct Dtype.

"""

dtype: StructDtype
def __init__(
    self,
    data: None,
    size: int,
    dtype: StructDtype,
    mask: Buffer | None = None,
    offset: int = 0,
    null_count: int | None = None,
    children: tuple[ColumnBase, ...] = (),
):
    """
    Validate the arguments strictly, then defer to the parent constructor.

    Parameters
    ----------
    data : None
        Must be ``None``; any other value is rejected.
    size : int
        Forwarded unchanged to the parent constructor.
    dtype : StructDtype
        Checked with ``_validate_dtype_instance`` before being forwarded.
    mask : Buffer or None, default None
        Forwarded unchanged to the parent constructor.
    offset : int, default 0
        Forwarded unchanged to the parent constructor.
    null_count : int or None, default None
        Forwarded unchanged to the parent constructor.
    children : tuple of ColumnBase, default ()
        Forwarded unchanged to the parent constructor.

    Raises
    ------
    ValueError
        If ``data`` is not ``None``, or if ``_validate_dtype_instance``
        rejects ``dtype``.
    """
    # The data check comes first so that, when both arguments are bad,
    # the caller sees the "data" error rather than the dtype one.
    if data is not None:
        raise ValueError("data must be None.")
    # Looked up on ``self`` so that a subclass overriding the validator
    # gets its own dtype check applied here.
    checked_dtype = self._validate_dtype_instance(dtype)
    super().__init__(
        data=None,  # guaranteed by the guard above
        size=size,
        dtype=checked_dtype,
        mask=mask,
        offset=offset,
        null_count=null_count,
        children=children,
    )

@staticmethod
def _validate_dtype_instance(dtype: StructDtype) -> StructDtype:
    """
    Check that ``dtype`` is exactly a ``StructDtype`` and return it as-is.

    Parameters
    ----------
    dtype : StructDtype
        The dtype object to check.

    Returns
    -------
    StructDtype
        The same object that was passed in.

    Raises
    ------
    ValueError
        If ``type(dtype)`` is anything other than ``StructDtype`` itself;
        instances of subclasses are rejected too.
    """
    # An exact type comparison is used instead of isinstance because
    # IntervalDtype is a subclass of StructDtype and must not pass here.
    if type(dtype) is StructDtype:
        return dtype
    raise ValueError(
        f"{type(dtype).__name__} must be a StructDtype exactly."
    )

@property
def base_size(self):
Expand All @@ -35,7 +67,7 @@ def base_size(self):
else:
return self.size + self.offset

def to_arrow(self):
def to_arrow(self) -> pa.Array:
children = [
pa.nulls(len(child))
if len(child) == child.null_count
Expand All @@ -50,7 +82,7 @@ def to_arrow(self):
}
)

if self.nullable:
if self.mask is not None:
buffers = (pa.py_buffer(self.mask.memoryview()),)
else:
buffers = (None,)
Expand All @@ -73,7 +105,7 @@ def to_pandas(
return pd.Index(self.to_arrow().tolist(), dtype="object")

@cached_property
def memory_usage(self):
def memory_usage(self) -> int:
n = 0
if self.nullable:
n += cudf._lib.null_mask.bitmask_allocation_size_bytes(self.size)
Expand All @@ -99,23 +131,23 @@ def __setitem__(self, key, value):
value = cudf.Scalar(value, self.dtype)
super().__setitem__(key, value)

def copy(self, deep=True):
def copy(self, deep: bool = True) -> Self:
# Since struct columns are immutable, both deep and
# shallow copies share the underlying device data and mask.
result = super().copy(deep=False)
if deep:
result = result._rename_fields(self.dtype.fields.keys())
return result

def _rename_fields(self, names):
def _rename_fields(self, names) -> Self:
"""
Return a StructColumn with the same field values as this StructColumn,
but with the field names equal to `names`.
"""
dtype = cudf.core.dtypes.StructDtype(
dtype = StructDtype(
{name: col.dtype for name, col in zip(names, self.children)}
)
return StructColumn(
return StructColumn( # type: ignore[return-value]
data=None,
size=self.size,
dtype=dtype,
Expand Down
6 changes: 5 additions & 1 deletion python/cudf/cudf/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -3354,6 +3354,7 @@ def interval_range(
return IntervalIndex(data, closed=closed, name=name)

interval_col = IntervalColumn(
data=None,
dtype=IntervalDtype(left_col.dtype, closed),
size=len(left_col),
children=(left_col, right_col),
Expand Down Expand Up @@ -3425,6 +3426,7 @@ def __init__(
elif isinstance(data.dtype, (pd.IntervalDtype, IntervalDtype)):
data = np.array([], dtype=data.dtype.subtype)
interval_col = IntervalColumn(
None,
dtype=IntervalDtype(data.dtype, closed),
size=len(data),
children=(as_column(data), as_column(data)),
Expand All @@ -3436,12 +3438,13 @@ def __init__(
if copy:
col = col.copy()
interval_col = IntervalColumn(
data=None,
dtype=IntervalDtype(col.dtype.subtype, closed),
mask=col.mask,
size=col.size,
offset=col.offset,
null_count=col.null_count,
children=col.children,
children=col.children, # type: ignore[arg-type]
)

if dtype:
Expand Down Expand Up @@ -3517,6 +3520,7 @@ def from_breaks(
)

interval_col = IntervalColumn(
data=None,
dtype=IntervalDtype(left_col.dtype, closed),
size=len(left_col),
children=(left_col, right_col),
Expand Down
Loading