Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API: dispatch to EA.astype #22343

Merged
merged 10 commits into from
Aug 20, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v0.24.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ ExtensionType Changes
- Bug in :meth:`Series.get` for ``Series`` using ``ExtensionArray`` and integer index (:issue:`21257`)
- :meth:`Series.combine()` works correctly with :class:`~pandas.api.extensions.ExtensionArray` inside of :class:`Series` (:issue:`20825`)
- :meth:`Series.combine()` with scalar argument now works for any function type (:issue:`21248`)
-
- :meth:`Series.astype` and :meth:`DataFrame.astype` now dispatch to :meth:`ExtensionArray.astype` (:issue:`21185:`).

.. _whatsnew_0240.api.incompatibilities:

Expand Down
3 changes: 2 additions & 1 deletion pandas/core/arrays/integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pandas.compat import u, range
from pandas.compat import set_function_name

from pandas.core.dtypes.cast import astype_nansafe
from pandas.core.dtypes.generic import ABCSeries, ABCIndexClass
from pandas.core.dtypes.common import (
is_integer, is_scalar, is_float,
Expand Down Expand Up @@ -391,7 +392,7 @@ def astype(self, dtype, copy=True):

# coerce
data = self._coerce_to_ndarray()
return data.astype(dtype=dtype, copy=False)
return astype_nansafe(data, dtype, copy=None)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was necessary so that

In [8]: s.values
Out[8]: IntegerArray([0, nan, 2], dtype='Int8')

In [9]: s.astype('uint32')

still raises. Previously, it did

In [4]: pd.core.arrays.IntegerArray([1, None, 2], dtype='uint8').astype('uint32')
Out[4]: array([         1, 2415919104,          2], dtype=uint32)

which I don't think we want. This was only tested at the Series[IntegerArray].astype level, which never called EA.astype

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test for this?


@property
def _ndarray_values(self):
Expand Down
15 changes: 13 additions & 2 deletions pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,16 @@ def conv(r, dtype):

def astype_nansafe(arr, dtype, copy=True):
""" return a view if copy is False, but
need to be very careful as the result shape could change! """
need to be very careful as the result shape could change!

Parameters
----------
arr : ndarray
dtype : np.dtype
copy : bool, default True
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you specify here what happens with False? (will always do view (except for object), even if the itemsize is incompatible)

If False, a view will be attempted but may fail, if
e.g. the itemsizes don't align.
"""

# dispatch on extension dtype if needed
if is_extension_array_dtype(dtype):
Expand Down Expand Up @@ -733,8 +742,10 @@ def astype_nansafe(arr, dtype, copy=True):
FutureWarning, stacklevel=5)
dtype = np.dtype(dtype.name + "[ns]")

if copy:
if copy or is_object_dtype(arr) or is_object_dtype(dtype):
# Explicit copy, or required since NumPy can't view from / to object.
return arr.astype(dtype, copy=True)

return arr.view(dtype)


Expand Down
27 changes: 15 additions & 12 deletions pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,22 +637,25 @@ def _astype(self, dtype, copy=False, errors='raise', values=None,
# force the copy here
if values is None:

if issubclass(dtype.type,
(compat.text_type, compat.string_types)):
if self.is_extension:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@TomAugspurger are you really really sure this is needed?
all of this already works with IntegerArray so not t really sure what problem you are trying to solve

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needed for #22325.

Sparse has special semantics for .astypeing, Series[sparse].astype(np.dtype) is interpreted as astyping the underlying .values.sp_values, rather than densifying and asytping. I'd like to deprecate this, but that's another matter.

In general though, it seems like EAs should have a say in how they're astyped, rather than always going through .get_values.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we already do this is my point

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not on master? That hit's https://github.com/pandas-dev/pandas/blob/master/pandas/core/internals/blocks.py#L652, which converts to an ndarray, before ever calling the extension array's astype.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At least, my new tests fail on master without these changes. I'm not sure if / how IntegerArray is being handled differently.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you mean the current code in master or this PR?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this PR

there is a ton of code in astype to dispatch to extension types already
this adds yet another branch

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you point to this "ton of code"? I don't see any other dispatch to EAs.
What is done is the generic conversion to array and then the use of astype_nansafe (which has a dispatch to EAs, but for the target EA dtype, not for the calling EA dtype)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a 2 line change, this extra if condition. Which of those two lines is the convoluted one?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, I could swear I add an .is_extension branch already, that's why this looks weird. In any event this has lots of if/then condition. Please make a note / issue to clean this up.

values = self.values.astype(dtype)
else:
if issubclass(dtype.type,
(compat.text_type, compat.string_types)):

# use native type formatting for datetime/tz/timedelta
if self.is_datelike:
values = self.to_native_types()
# use native type formatting for datetime/tz/timedelta
if self.is_datelike:
values = self.to_native_types()

# astype formatting
else:
values = self.get_values()
# astype formatting
else:
values = self.get_values()

else:
values = self.get_values(dtype=dtype)
else:
values = self.get_values(dtype=dtype)

# _astype_nansafe works fine with 1-d only
values = astype_nansafe(values.ravel(), dtype, copy=True)
# _astype_nansafe works fine with 1-d only
values = astype_nansafe(values.ravel(), dtype, copy=True)

# TODO(extension)
# should we make this attribute?
Expand Down
28 changes: 24 additions & 4 deletions pandas/tests/extension/decimal/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,17 @@ class DecimalDtype(ExtensionDtype):
name = 'decimal'
na_value = decimal.Decimal('NaN')

def __init__(self, context=None):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seemed like a good idea to have a parameterized ExtensionDtype in the tests folder. For .astype, we're testing that Series[decimal].astype(DecimalDtype(new_context)) is passed all the way through.

self.context = context or decimal.getcontext()

def __eq__(self, other):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Our default eq is dangerous for parametrized dtypes, depending on whether the parameters appear in the name :/

if isinstance(other, type(self)):
return self.context == other.context
return super(DecimalDtype, self).__eq__(other)

def __repr__(self):
return 'DecimalDtype(context={})'.format(self.context)

@classmethod
def construct_array_type(cls):
"""Return the array type associated with this dtype
Expand All @@ -35,13 +46,12 @@ def construct_from_string(cls, string):


class DecimalArray(ExtensionArray, ExtensionScalarOpsMixin):
dtype = DecimalDtype()

def __init__(self, values, dtype=None, copy=False):
def __init__(self, values, dtype=None, copy=False, context=None):
for val in values:
if not isinstance(val, self.dtype.type):
if not isinstance(val, decimal.Decimal):
raise TypeError("All values must be of type " +
str(self.dtype.type))
str(decimal.Decimal))
values = np.asarray(values, dtype=object)

self._data = values
Expand All @@ -51,6 +61,11 @@ def __init__(self, values, dtype=None, copy=False):
# those aliases are currently not working due to assumptions
# in internal code (GH-20735)
# self._values = self.values = self.data
self._dtype = DecimalDtype(context)

@property
def dtype(self):
return self._dtype

@classmethod
def _from_sequence(cls, scalars, dtype=None, copy=False):
Expand Down Expand Up @@ -82,6 +97,11 @@ def copy(self, deep=False):
return type(self)(self._data.copy())
return type(self)(self)

def astype(self, dtype, copy=True):
if isinstance(dtype, type(self.dtype)):
return type(self)(self._data, context=dtype.context)
return super(DecimalArray, self).astype(dtype, copy)

def __setitem__(self, key, value):
if pd.api.types.is_list_like(value):
value = [decimal.Decimal(v) for v in value]
Expand Down
21 changes: 21 additions & 0 deletions pandas/tests/extension/decimal/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,27 @@ def test_dataframe_constructor_with_dtype():
tm.assert_frame_equal(result, expected)


@pytest.mark.parametrize("frame", [True, False])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why don't we add a test in base which raises NotImplementedError so that authors are forced to cover this?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But is it useful to force EA authors to do it? This is basically checking the pandas dispatch (which we can do here), not the actual EA.astype implementation

Copy link
Contributor Author

@TomAugspurger TomAugspurger Aug 16, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or how about something like

@pytest.mark.parametrize('typ, check', [
    ('category', 'is_categorical_dtype'),
    ...
])
def test_astype_category(self, data):
    assert check(data.astype(dtype))

Would that make sense as a base test? I think our default implementation would need to be updated to not fail that.

My main concern is that it would be difficult to override (e.g. skip) just some of the dtypes, so maybe we would have to write those as separate tests?

Copy link
Contributor Author

@TomAugspurger TomAugspurger Aug 16, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm mildly concerned that EA authors will not correctly handle extension types. Our base implementation currently fails to handle them. Although we document it as Cast to a NumPy array with 'dtype'.. Should we make the return type Union[ndarray, ExtensionArray]?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would leave such a test until we better figure out how to handle those cross EA/non-EA astype calls (the discussion we were having above in this PR )

Should we make the return type Union[ndarray, ExtensionArray]?

Yeah, probably yes in any case (since EAs authors that provide multiple dtypes already do that, like IntegerArray)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would add an issue to make this test, not blocking this PR on it.

def test_astype_dispatches(frame):
# This is a dtype-specific test that ensures Series[decimal].astype
# gets all the way through to ExtensionArray.astype
# Designing a reliable smoke test that works for arbitrary data types
# is difficult.
data = pd.Series(DecimalArray([decimal.Decimal(2)]), name='a')
ctx = decimal.Context()
ctx.prec = 5

if frame:
data = data.to_frame()

result = data.astype(DecimalDtype(ctx))

if frame:
result = result['a']

assert result.dtype.context.prec == ctx.prec


class TestArithmeticOps(BaseDecimal, base.BaseArithmeticOpsTests):

def check_opname(self, s, op_name, other, exc=None):
Expand Down
9 changes: 9 additions & 0 deletions pandas/tests/extension/integer/test_integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,15 @@ def test_cross_type_arithmetic():
tm.assert_series_equal(result, expected)


def test_astype_nansafe():
# https://github.com/pandas-dev/pandas/pull/22343
arr = IntegerArray([np.nan, 1, 2], dtype="Int8")

with tm.assert_raises_regex(
ValueError, 'cannot convert float NaN to integer'):
arr.astype('uint32')


# TODO(jreback) - these need testing / are broken

# shift
Expand Down