Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use argument dtype to inform coercion #17779

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,39 @@ def infer_dtype_from_array(arr, pandas_dtype=False):
return arr.dtype, arr


def maybe_infer_dtype_type(element):
"""Try to infer an object's dtype, for use in arithmetic ops

Uses `element.dtype` if that's available.
Objects implementing the iterator protocol are cast to a NumPy array,
and from there the array's type is used.

Parameters
----------
element : object
Possibly has a `.dtype` attribute, and possibly the iterator
protocol.

Returns
-------
tipo : type

Examples
--------
>>> from collections import namedtuple
>>> Foo = namedtuple("Foo", "dtype")
>>> maybe_infer_dtype_type(Foo(np.dtype("i8")))
numpy.int64
"""
tipo = None
if hasattr(element, 'dtype'):
tipo = element.dtype
elif is_list_like(element):
element = np.asarray(element)
tipo = element.dtype
return tipo


def maybe_upcast(values, fill_value=np.nan, dtype=None, copy=False):
""" provide explict type promotion and coercion

Expand Down
54 changes: 27 additions & 27 deletions pandas/core/internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@
soft_convert_objects,
maybe_convert_objects,
astype_nansafe,
find_common_type)
find_common_type,
maybe_infer_dtype_type)
from pandas.core.dtypes.missing import (
isna, notna, array_equivalent,
_isna_compat,
Expand Down Expand Up @@ -629,10 +630,9 @@ def convert(self, copy=True, **kwargs):
def _can_hold_element(self, element):
""" require the same dtype as ourselves """
dtype = self.values.dtype.type
if is_list_like(element):
element = np.asarray(element)
tipo = element.dtype.type
return issubclass(tipo, dtype)
tipo = maybe_infer_dtype_type(element)
if tipo is not None:
return issubclass(tipo.type, dtype)
return isinstance(element, dtype)

def _try_cast_result(self, result, dtype=None):
Expand Down Expand Up @@ -1806,11 +1806,10 @@ class FloatBlock(FloatOrComplexBlock):
_downcast_dtype = 'int64'

def _can_hold_element(self, element):
if is_list_like(element):
element = np.asarray(element)
tipo = element.dtype.type
return (issubclass(tipo, (np.floating, np.integer)) and
not issubclass(tipo, (np.datetime64, np.timedelta64)))
tipo = maybe_infer_dtype_type(element)
if tipo is not None:
return (issubclass(tipo.type, (np.floating, np.integer)) and
not issubclass(tipo.type, (np.datetime64, np.timedelta64)))
return (isinstance(element, (float, int, np.floating, np.int_)) and
not isinstance(element, (bool, np.bool_, datetime, timedelta,
np.datetime64, np.timedelta64)))
Expand Down Expand Up @@ -1856,9 +1855,9 @@ class ComplexBlock(FloatOrComplexBlock):
is_complex = True

def _can_hold_element(self, element):
if is_list_like(element):
element = np.array(element)
return issubclass(element.dtype.type,
tipo = maybe_infer_dtype_type(element)
if tipo is not None:
return issubclass(tipo.type,
(np.floating, np.integer, np.complexfloating))
return (isinstance(element,
(float, int, complex, np.float_, np.int_)) and
Expand All @@ -1874,12 +1873,12 @@ class IntBlock(NumericBlock):
_can_hold_na = False

def _can_hold_element(self, element):
if is_list_like(element):
element = np.array(element)
tipo = element.dtype.type
return (issubclass(tipo, np.integer) and
not issubclass(tipo, (np.datetime64, np.timedelta64)) and
self.dtype.itemsize >= element.dtype.itemsize)
tipo = maybe_infer_dtype_type(element)
if tipo is not None:
return (issubclass(tipo.type, np.integer) and
not issubclass(tipo.type, (np.datetime64,
np.timedelta64)) and
self.dtype.itemsize >= tipo.itemsize)
return is_integer(element)

def should_store(self, value):
Expand Down Expand Up @@ -1917,10 +1916,9 @@ def _box_func(self):
return lambda x: tslib.Timedelta(x, unit='ns')

def _can_hold_element(self, element):
if is_list_like(element):
element = np.array(element)
tipo = element.dtype.type
return issubclass(tipo, np.timedelta64)
tipo = maybe_infer_dtype_type(element)
if tipo is not None:
return issubclass(tipo.type, np.timedelta64)
return isinstance(element, (timedelta, np.timedelta64))

def fillna(self, value, **kwargs):
Expand Down Expand Up @@ -2018,9 +2016,9 @@ class BoolBlock(NumericBlock):
_can_hold_na = False

def _can_hold_element(self, element):
if is_list_like(element):
element = np.asarray(element)
return issubclass(element.dtype.type, np.bool_)
tipo = maybe_infer_dtype_type(element)
if tipo is not None:
return issubclass(tipo.type, np.bool_)
return isinstance(element, (bool, np.bool_))

def should_store(self, value):
Expand Down Expand Up @@ -2450,7 +2448,9 @@ def _astype(self, dtype, mgr=None, **kwargs):
return super(DatetimeBlock, self)._astype(dtype=dtype, **kwargs)

def _can_hold_element(self, element):
if is_list_like(element):
tipo = maybe_infer_dtype_type(element)
if tipo is not None:
# TODO: this still uses asarray, instead of dtype.type
element = np.array(element)
return element.dtype == _NS_DTYPE or element.dtype == np.int64
return (is_integer(element) or isinstance(element, datetime) or
Expand Down
62 changes: 62 additions & 0 deletions pandas/tests/internals/test_internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# pylint: disable=W0102

from datetime import datetime, date
import operator
import sys
import pytest
import numpy as np
Expand Down Expand Up @@ -1213,3 +1214,64 @@ def assert_add_equals(val, inc, result):

with pytest.raises(ValueError):
BlockPlacement(slice(2, None, -1)).add(-1)


class DummyElement(object):
def __init__(self, value, dtype):
self.value = value
self.dtype = np.dtype(dtype)

def __array__(self):
return np.array(self.value, dtype=self.dtype)

def __str__(self):
return "DummyElement({}, {})".format(self.value, self.dtype)

def __repr__(self):
return str(self)

def astype(self, dtype, copy=False):
self.dtype = dtype
return self

def view(self, dtype):
return type(self)(self.value.view(dtype), dtype)

def any(self, axis=None):
return bool(self.value)


class TestCanHoldElement(object):
@pytest.mark.parametrize('value, dtype', [
(1, 'i8'),
(1.0, 'f8'),
(1j, 'complex128'),
(True, 'bool'),
(np.timedelta64(20, 'ns'), '<m8[ns]'),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually ought to add Timedelta/Timestamp, NaT, np.nan, Period

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The missing data ones are already covered in test_nanops. Timedelta / timestamps in handled in tests/scalar/. Period doesn't support binops. So we should be good.

(np.datetime64(20, 'ns'), '<M8[ns]'),
])
@pytest.mark.parametrize('op', [
operator.add,
operator.sub,
operator.mul,
operator.truediv,
operator.mod,
operator.pow,
], ids=lambda x: x.__name__)
def test_binop_other(self, op, value, dtype):
skip = {(operator.add, 'bool'),
(operator.sub, 'bool'),
(operator.mul, 'bool'),
(operator.truediv, 'bool'),
(operator.mod, 'i8'),
(operator.mod, 'complex128'),
(operator.mod, '<M8[ns]'),
(operator.mod, '<m8[ns]'),
(operator.pow, 'bool')}
if (op, dtype) in skip:
pytest.skip("Invalid combination {},{}".format(op, dtype))
e = DummyElement(value, dtype)
s = pd.DataFrame({"A": [e.value, e.value]}, dtype=e.dtype)
result = op(s, e).dtypes
expected = op(s, value).dtypes
assert_series_equal(result, expected)