Skip to content

Commit

Permalink
API: dispatch to EA.astype
Browse files Browse the repository at this point in the history
  • Loading branch information
TomAugspurger committed Aug 14, 2018
1 parent f7f266c commit a7ba8f6
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 20 deletions.
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v0.24.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ ExtensionType Changes
- Bug in :meth:`Series.get` for ``Series`` using ``ExtensionArray`` and integer index (:issue:`21257`)
- :meth:`Series.combine()` works correctly with :class:`~pandas.api.extensions.ExtensionArray` inside of :class:`Series` (:issue:`20825`)
- :meth:`Series.combine()` with scalar argument now works for any function type (:issue:`21248`)
-
- :meth:`Series.astype` and :meth:`DataFrame.astype` now dispatch to :meth:`ExtensionArray.astype` (:issue:`21185:`).

.. _whatsnew_0240.api.incompatibilities:

Expand Down
3 changes: 2 additions & 1 deletion pandas/core/arrays/integer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pandas.compat import u, range
from pandas.compat import set_function_name

from pandas.core.dtypes.cast import astype_nansafe
from pandas.core.dtypes.generic import ABCSeries, ABCIndexClass
from pandas.core.dtypes.common import (
is_integer, is_scalar, is_float,
Expand Down Expand Up @@ -391,7 +392,7 @@ def astype(self, dtype, copy=True):

# coerce
data = self._coerce_to_ndarray()
return data.astype(dtype=dtype, copy=False)
return astype_nansafe(data, dtype, copy=None)

@property
def _ndarray_values(self):
Expand Down
23 changes: 21 additions & 2 deletions pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,17 @@ def conv(r, dtype):

def astype_nansafe(arr, dtype, copy=True):
""" return a view if copy is False, but
need to be very careful as the result shape could change! """
need to be very careful as the result shape could change!
Parameters
----------
arr : ndarray
dtype : np.dtype
copy : bool or None, default True
Whether to copy during the `.astype` (True) or
just return a view (False). Passing `copy=None` will
attempt to return a view, but will copy if necessary.
"""

# dispatch on extension dtype if needed
if is_extension_array_dtype(dtype):
Expand Down Expand Up @@ -735,7 +745,16 @@ def astype_nansafe(arr, dtype, copy=True):

if copy:
return arr.astype(dtype, copy=True)
return arr.view(dtype)
else:
try:
return arr.view(dtype)
except TypeError:
if copy is None:
# allowed to copy if necessary (e.g. object)
return arr.astype(dtype, copy=True)
else:
raise



def maybe_convert_objects(values, convert_dates=True, convert_numeric=True,
Expand Down
27 changes: 15 additions & 12 deletions pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,22 +637,25 @@ def _astype(self, dtype, copy=False, errors='raise', values=None,
# force the copy here
if values is None:

if issubclass(dtype.type,
(compat.text_type, compat.string_types)):
if self.is_extension:
values = self.values.astype(dtype)
else:
if issubclass(dtype.type,
(compat.text_type, compat.string_types)):

# use native type formatting for datetime/tz/timedelta
if self.is_datelike:
values = self.to_native_types()
# use native type formatting for datetime/tz/timedelta
if self.is_datelike:
values = self.to_native_types()

# astype formatting
else:
values = self.get_values()
# astype formatting
else:
values = self.get_values()

else:
values = self.get_values(dtype=dtype)
else:
values = self.get_values(dtype=dtype)

# _astype_nansafe works fine with 1-d only
values = astype_nansafe(values.ravel(), dtype, copy=True)
# _astype_nansafe works fine with 1-d only
values = astype_nansafe(values.ravel(), dtype, copy=True)

# TODO(extension)
# should we make this attribute?
Expand Down
28 changes: 24 additions & 4 deletions pandas/tests/extension/decimal/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,17 @@ class DecimalDtype(ExtensionDtype):
name = 'decimal'
na_value = decimal.Decimal('NaN')

def __init__(self, context=None):
self.context = context or decimal.getcontext()

def __eq__(self, other):
if isinstance(other, type(self)):
return self.context == other.context
return super(DecimalDtype, self).__eq__(other)

def __repr__(self):
return 'DecimalDtype(context={})'.format(self.context)

@classmethod
def construct_array_type(cls):
"""Return the array type associated with this dtype
Expand All @@ -35,13 +46,12 @@ def construct_from_string(cls, string):


class DecimalArray(ExtensionArray, ExtensionScalarOpsMixin):
dtype = DecimalDtype()

def __init__(self, values, dtype=None, copy=False):
def __init__(self, values, dtype=None, copy=False, context=None):
for val in values:
if not isinstance(val, self.dtype.type):
if not isinstance(val, decimal.Decimal):
raise TypeError("All values must be of type " +
str(self.dtype.type))
str(decimal.Decimal))
values = np.asarray(values, dtype=object)

self._data = values
Expand All @@ -51,6 +61,11 @@ def __init__(self, values, dtype=None, copy=False):
# those aliases are currently not working due to assumptions
# in internal code (GH-20735)
# self._values = self.values = self.data
self._dtype = DecimalDtype(context)

@property
def dtype(self):
return self._dtype

@classmethod
def _from_sequence(cls, scalars, dtype=None, copy=False):
Expand Down Expand Up @@ -82,6 +97,11 @@ def copy(self, deep=False):
return type(self)(self._data.copy())
return type(self)(self)

def astype(self, dtype, copy=True):
if isinstance(dtype, type(self.dtype)):
return type(self)(self._data, context=dtype.context)
return super().astype(dtype, copy)

def __setitem__(self, key, value):
if pd.api.types.is_list_like(value):
value = [decimal.Decimal(v) for v in value]
Expand Down
18 changes: 18 additions & 0 deletions pandas/tests/extension/decimal/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,24 @@ def test_dataframe_constructor_with_dtype():
tm.assert_frame_equal(result, expected)


@pytest.mark.parametrize("frame", [True, False])
def test_astype_dispatches(frame):
data = pd.Series(DecimalArray([decimal.Decimal(2)]), name='a')
ctx = decimal.Context()
ctx.prec = 5

if frame:
data = data.to_frame()

result = data.astype(DecimalDtype(ctx))

if frame:
result = result['a']

assert result.dtype.context.prec == ctx.prec



class TestArithmeticOps(BaseDecimal, base.BaseArithmeticOpsTests):

def check_opname(self, s, op_name, other, exc=None):
Expand Down

0 comments on commit a7ba8f6

Please sign in to comment.