Skip to content

Commit

Permalink
Ensure cudf objects can astype to any type when empty (rapidsai#16106)
Browse files Browse the repository at this point in the history
pandas allows objects to `astype` to any other type if the object is empty. The PR mirrors that behavior for cudf.

This PR also more consistently uses `astype` instead of `as_*_column` and fixes a bug in `IntervalDtype.__eq__` discovered when writing a unit test for this bug.

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)
  - GALI PREM SAGAR (https://github.com/galipremsagar)

Approvers:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

URL: rapidsai#16106
  • Loading branch information
mroeschke authored Jul 1, 2024
1 parent 51fb873 commit 5efd72f
Show file tree
Hide file tree
Showing 13 changed files with 121 additions and 63 deletions.
9 changes: 9 additions & 0 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,6 +959,15 @@ def can_cast_safely(self, to_dtype: Dtype) -> bool:
raise NotImplementedError()

def astype(self, dtype: Dtype, copy: bool = False) -> ColumnBase:
if len(self) == 0:
dtype = cudf.dtype(dtype)
if self.dtype == dtype:
if copy:
return self.copy()
else:
return self
else:
return column_empty(0, dtype=dtype, masked=self.nullable)
if copy:
col = self.copy()
else:
Expand Down
36 changes: 19 additions & 17 deletions python/cudf/cudf/core/column/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,8 @@ def __contains__(self, item: ScalarLike) -> bool:
return False
elif ts.tzinfo is not None:
ts = ts.tz_convert(None)
return ts.to_numpy().astype("int64") in self.as_numerical_column(
"int64"
return ts.to_numpy().astype("int64") in cast(
"cudf.core.column.NumericalColumn", self.astype("int64")
)

@functools.cached_property
Expand Down Expand Up @@ -503,9 +503,9 @@ def mean(
self, skipna=None, min_count: int = 0, dtype=np.float64
) -> ScalarLike:
return pd.Timestamp(
self.as_numerical_column("int64").mean(
skipna=skipna, min_count=min_count, dtype=dtype
),
cast(
"cudf.core.column.NumericalColumn", self.astype("int64")
).mean(skipna=skipna, min_count=min_count, dtype=dtype),
unit=self.time_unit,
).as_unit(self.time_unit)

Expand All @@ -517,15 +517,17 @@ def std(
ddof: int = 1,
) -> pd.Timedelta:
return pd.Timedelta(
self.as_numerical_column("int64").std(
cast("cudf.core.column.NumericalColumn", self.astype("int64")).std(
skipna=skipna, min_count=min_count, dtype=dtype, ddof=ddof
)
* _unit_to_nanoseconds_conversion[self.time_unit],
).as_unit(self.time_unit)

def median(self, skipna: bool | None = None) -> pd.Timestamp:
return pd.Timestamp(
self.as_numerical_column("int64").median(skipna=skipna),
cast(
"cudf.core.column.NumericalColumn", self.astype("int64")
).median(skipna=skipna),
unit=self.time_unit,
).as_unit(self.time_unit)

Expand All @@ -534,18 +536,18 @@ def cov(self, other: DatetimeColumn) -> float:
raise TypeError(
f"cannot perform cov with types {self.dtype}, {other.dtype}"
)
return self.as_numerical_column("int64").cov(
other.as_numerical_column("int64")
)
return cast(
"cudf.core.column.NumericalColumn", self.astype("int64")
).cov(cast("cudf.core.column.NumericalColumn", other.astype("int64")))

def corr(self, other: DatetimeColumn) -> float:
if not isinstance(other, DatetimeColumn):
raise TypeError(
f"cannot perform corr with types {self.dtype}, {other.dtype}"
)
return self.as_numerical_column("int64").corr(
other.as_numerical_column("int64")
)
return cast(
"cudf.core.column.NumericalColumn", self.astype("int64")
).corr(cast("cudf.core.column.NumericalColumn", other.astype("int64")))

def quantile(
self,
Expand All @@ -554,7 +556,7 @@ def quantile(
exact: bool,
return_scalar: bool,
) -> ColumnBase:
result = self.as_numerical_column("int64").quantile(
result = self.astype("int64").quantile(
q=q,
interpolation=interpolation,
exact=exact,
Expand Down Expand Up @@ -645,12 +647,12 @@ def indices_of(
) -> cudf.core.column.NumericalColumn:
value = column.as_column(
pd.to_datetime(value), dtype=self.dtype
).as_numerical_column("int64")
return self.as_numerical_column("int64").indices_of(value)
).astype("int64")
return self.astype("int64").indices_of(value)

@property
def is_unique(self) -> bool:
return self.as_numerical_column("int64").is_unique
return self.astype("int64").is_unique

def isin(self, values: Sequence) -> ColumnBase:
return cudf.core.tools.datetimes._isin_datetimelike(self, values)
Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/column/decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def normalize_binop_value(self, other):
"Decimal columns only support binary operations with "
"integer numerical columns."
)
other = other.as_decimal_column(
other = other.astype(
self.dtype.__class__(self.dtype.__class__.MAX_PRECISION, 0)
)
elif not isinstance(other, DecimalBaseColumn):
Expand Down
26 changes: 11 additions & 15 deletions python/cudf/cudf/core/column/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import cudf
from cudf.core.column import StructColumn
from cudf.core.dtypes import CategoricalDtype, IntervalDtype
from cudf.core.dtypes import IntervalDtype


class IntervalColumn(StructColumn):
Expand Down Expand Up @@ -87,20 +87,16 @@ def copy(self, deep=True):

def as_interval_column(self, dtype):
if isinstance(dtype, IntervalDtype):
if isinstance(self.dtype, CategoricalDtype):
new_struct = self._get_decategorized_column()
return IntervalColumn.from_struct_column(new_struct)
else:
return IntervalColumn(
size=self.size,
dtype=dtype,
mask=self.mask,
offset=self.offset,
null_count=self.null_count,
children=tuple(
child.astype(dtype.subtype) for child in self.children
),
)
return IntervalColumn(
size=self.size,
dtype=dtype,
mask=self.mask,
offset=self.offset,
null_count=self.null_count,
children=tuple(
child.astype(dtype.subtype) for child in self.children
),
)
else:
raise ValueError("dtype must be IntervalDtype")

Expand Down
34 changes: 19 additions & 15 deletions python/cudf/cudf/core/column/timedelta.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ def __contains__(self, item: DatetimeLikeScalar) -> bool:
# np.timedelta64 raises ValueError, hence `item`
# cannot exist in `self`.
return False
return item.view("int64") in self.as_numerical_column("int64")
return item.view("int64") in cast(
"cudf.core.column.NumericalColumn", self.astype("int64")
)

@property
def values(self):
Expand All @@ -132,9 +134,7 @@ def to_arrow(self) -> pa.Array:
self.mask_array_view(mode="read").copy_to_host()
)
data = pa.py_buffer(
self.as_numerical_column("int64")
.data_array_view(mode="read")
.copy_to_host()
self.astype("int64").data_array_view(mode="read").copy_to_host()
)
pa_dtype = np_to_pa_dtype(self.dtype)
return pa.Array.from_buffers(
Expand Down Expand Up @@ -295,13 +295,17 @@ def as_timedelta_column(

def mean(self, skipna=None, dtype: Dtype = np.float64) -> pd.Timedelta:
return pd.Timedelta(
self.as_numerical_column("int64").mean(skipna=skipna, dtype=dtype),
cast(
"cudf.core.column.NumericalColumn", self.astype("int64")
).mean(skipna=skipna, dtype=dtype),
unit=self.time_unit,
).as_unit(self.time_unit)

def median(self, skipna: bool | None = None) -> pd.Timedelta:
return pd.Timedelta(
self.as_numerical_column("int64").median(skipna=skipna),
cast(
"cudf.core.column.NumericalColumn", self.astype("int64")
).median(skipna=skipna),
unit=self.time_unit,
).as_unit(self.time_unit)

Expand All @@ -315,7 +319,7 @@ def quantile(
exact: bool,
return_scalar: bool,
) -> ColumnBase:
result = self.as_numerical_column("int64").quantile(
result = self.astype("int64").quantile(
q=q,
interpolation=interpolation,
exact=exact,
Expand All @@ -337,7 +341,7 @@ def sum(
# Since sum isn't overridden in Numerical[Base]Column, mypy only
# sees the signature from Reducible (which doesn't have the extra
# parameters from ColumnBase._reduce) so we have to ignore this.
self.as_numerical_column("int64").sum( # type: ignore
self.astype("int64").sum( # type: ignore
skipna=skipna, min_count=min_count, dtype=dtype
),
unit=self.time_unit,
Expand All @@ -351,7 +355,7 @@ def std(
ddof: int = 1,
) -> pd.Timedelta:
return pd.Timedelta(
self.as_numerical_column("int64").std(
cast("cudf.core.column.NumericalColumn", self.astype("int64")).std(
skipna=skipna, min_count=min_count, ddof=ddof, dtype=dtype
),
unit=self.time_unit,
Expand All @@ -362,18 +366,18 @@ def cov(self, other: TimeDeltaColumn) -> float:
raise TypeError(
f"cannot perform cov with types {self.dtype}, {other.dtype}"
)
return self.as_numerical_column("int64").cov(
other.as_numerical_column("int64")
)
return cast(
"cudf.core.column.NumericalColumn", self.astype("int64")
).cov(cast("cudf.core.column.NumericalColumn", other.astype("int64")))

def corr(self, other: TimeDeltaColumn) -> float:
if not isinstance(other, TimeDeltaColumn):
raise TypeError(
f"cannot perform corr with types {self.dtype}, {other.dtype}"
)
return self.as_numerical_column("int64").corr(
other.as_numerical_column("int64")
)
return cast(
"cudf.core.column.NumericalColumn", self.astype("int64")
).corr(cast("cudf.core.column.NumericalColumn", other.astype("int64")))

def components(self) -> dict[str, ColumnBase]:
"""
Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2404,7 +2404,7 @@ def scatter_by_map(
if isinstance(map_index, cudf.core.column.StringColumn):
cat_index = cast(
cudf.core.column.CategoricalColumn,
map_index.as_categorical_column("category"),
map_index.astype("category"),
)
map_index = cat_index.codes
warnings.warn(
Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,7 +937,7 @@ def to_pandas(self) -> pd.IntervalDtype:
def __eq__(self, other):
if isinstance(other, str):
# This means equality isn't transitive but mimics pandas
return other == self.name
return other in (self.name, str(self))
return (
type(self) == type(other)
and self.subtype == other.subtype
Expand Down
4 changes: 2 additions & 2 deletions python/cudf/cudf/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,7 @@ def from_arrow(cls, data: pa.Table) -> Self:
# of column is 0 (i.e., empty) then we will have an
# int8 column in result._data[name] returned by libcudf,
# which needs to be type-casted to 'category' dtype.
result[name] = result[name].as_categorical_column("category")
result[name] = result[name].astype("category")
elif (
pandas_dtypes.get(name) == "empty"
and np_dtypes.get(name) == "object"
Expand All @@ -936,7 +936,7 @@ def from_arrow(cls, data: pa.Table) -> Self:
# is specified as 'empty' and np_dtypes as 'object',
# hence handling this special case to type-cast the empty
# float column to str column.
result[name] = result[name].as_string_column(cudf.dtype("str"))
result[name] = result[name].astype(cudf.dtype("str"))
elif name in data.column_names and isinstance(
data[name].type,
(
Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/indexing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def parse_row_iloc_indexer(key: Any, n: int) -> IndexingSpec:
else:
key = cudf.core.column.as_column(key)
if isinstance(key, cudf.core.column.CategoricalColumn):
key = key.as_numerical_column(key.codes.dtype)
key = key.astype(key.codes.dtype)
if is_bool_dtype(key.dtype):
return MaskIndexer(BooleanMask(key, n))
elif len(key) == 0:
Expand Down
8 changes: 5 additions & 3 deletions python/cudf/cudf/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3107,10 +3107,12 @@ def value_counts(
# Pandas returns an IntervalIndex as the index of res
# this condition makes sure we do too if bins is given
if bins is not None and len(res) == len(res.index.categories):
int_index = IntervalColumn.as_interval_column(
res.index._column, res.index.categories.dtype
interval_col = IntervalColumn.from_struct_column(
res.index._column._get_decategorized_column()
)
res.index = cudf.IntervalIndex._from_data(
{res.index.name: interval_col}
)
res.index = int_index
res.name = result_name
return res

Expand Down
14 changes: 7 additions & 7 deletions python/cudf/cudf/core/tools/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,11 @@ def to_numeric(arg, errors="raise", downcast=None):
dtype = col.dtype

if is_datetime_dtype(dtype) or is_timedelta_dtype(dtype):
col = col.as_numerical_column(cudf.dtype("int64"))
col = col.astype(cudf.dtype("int64"))
elif isinstance(dtype, CategoricalDtype):
cat_dtype = col.dtype.type
if _is_non_decimal_numeric_dtype(cat_dtype):
col = col.as_numerical_column(cat_dtype)
col = col.astype(cat_dtype)
else:
try:
col = _convert_str_col(
Expand All @@ -146,8 +146,8 @@ def to_numeric(arg, errors="raise", downcast=None):
raise ValueError("Unrecognized datatype")

# str->float conversion may require lower precision
if col.dtype == cudf.dtype("f"):
col = col.as_numerical_column("d")
if col.dtype == cudf.dtype("float32"):
col = col.astype("float64")

if downcast:
if downcast == "float":
Expand Down Expand Up @@ -205,7 +205,7 @@ def _convert_str_col(col, errors, _downcast=None):

is_integer = libstrings.is_integer(col)
if is_integer.all():
return col.as_numerical_column(dtype=cudf.dtype("i8"))
return col.astype(dtype=cudf.dtype("i8"))

col = _proc_inf_empty_strings(col)

Expand All @@ -218,9 +218,9 @@ def _convert_str_col(col, errors, _downcast=None):
"limited by float32 precision."
)
)
return col.as_numerical_column(dtype=cudf.dtype("f"))
return col.astype(dtype=cudf.dtype("float32"))
else:
return col.as_numerical_column(dtype=cudf.dtype("d"))
return col.astype(dtype=cudf.dtype("float64"))
else:
if errors == "coerce":
col = libcudf.string_casting.stod(col)
Expand Down
6 changes: 6 additions & 0 deletions python/cudf/cudf/tests/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,9 @@ def test_from_pandas_intervaldtype():
result = cudf.from_pandas(dtype)
expected = cudf.IntervalDtype("int64", closed="left")
assert_eq(result, expected)


def test_intervaldtype_eq_string_with_attributes():
dtype = cudf.IntervalDtype("int64", closed="left")
assert dtype == "interval"
assert dtype == "interval[int64, left]"
Loading

0 comments on commit 5efd72f

Please sign in to comment.