Skip to content

Commit

Permalink
Use argument dtype to inform coercion (#17779)
Browse files Browse the repository at this point in the history
* Use argument dtype to inform coercion

Master:

```python
>>> import dask.dataframe as dd

>>> s = dd.core.Scalar({('s', 0): 10}, 's', 'i8')
>>> pdf = pd.DataFrame({'a': [1, 2, 3, 4, 5, 6, 7],
...                     'b': [7, 6, 5, 4, 3, 2, 1]})
>>> (pdf + s).dtypes
a    object
b    object
dtype: object
```

Head:

```python
>>> (pdf + s).dtypes
a    int64
b    int64
dtype: object
```

This is more consistent with 0.20.3, while still keeping most of the changes in
#16821

Closes #17767

* Compat for older numpy where bool(dtype) is False

* Added timedelta
  • Loading branch information
TomAugspurger authored Oct 5, 2017
1 parent d099f88 commit 6773694
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 27 deletions.
33 changes: 33 additions & 0 deletions pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,39 @@ def infer_dtype_from_array(arr, pandas_dtype=False):
return arr.dtype, arr


def maybe_infer_dtype_type(element):
    """Try to infer an object's dtype, for use in arithmetic ops

    Uses `element.dtype` if that's available.
    Objects implementing the iterator protocol are cast to a NumPy array,
    and from there the array's dtype is used.

    Parameters
    ----------
    element : object
        Possibly has a `.dtype` attribute, and possibly the iterator
        protocol.

    Returns
    -------
    tipo : dtype object or None
        ``element.dtype`` (returned as-is, without validation) when the
        attribute exists; otherwise the dtype of ``np.asarray(element)``
        when `element` is list-like; ``None`` when neither applies.
        Callers read ``tipo.type`` (and ``tipo.itemsize``) from the result.

    Examples
    --------
    >>> from collections import namedtuple
    >>> Foo = namedtuple("Foo", "dtype")
    >>> maybe_infer_dtype_type(Foo(np.dtype("i8")))
    dtype('int64')
    """
    tipo = None
    if hasattr(element, 'dtype'):
        # Trust a declared dtype over materialising the object; this is what
        # lets lazy / duck-typed operands take part in dtype checks.
        tipo = element.dtype
    elif is_list_like(element):
        tipo = np.asarray(element).dtype
    return tipo


def maybe_upcast(values, fill_value=np.nan, dtype=None, copy=False):
""" provide explict type promotion and coercion
Expand Down
54 changes: 27 additions & 27 deletions pandas/core/internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@
soft_convert_objects,
maybe_convert_objects,
astype_nansafe,
find_common_type)
find_common_type,
maybe_infer_dtype_type)
from pandas.core.dtypes.missing import (
isna, notna, array_equivalent,
_isna_compat,
Expand Down Expand Up @@ -629,10 +630,9 @@ def convert(self, copy=True, **kwargs):
def _can_hold_element(self, element):
    """ require the same dtype as ourselves

    If a dtype can be inferred for `element` (it exposes ``.dtype`` or is
    list-like), the check is made on that dtype's scalar type; otherwise
    `element` itself is tested as a scalar.

    Returns
    -------
    bool
    """
    dtype = self.values.dtype.type
    tipo = maybe_infer_dtype_type(element)
    if tipo is not None:
        return issubclass(tipo.type, dtype)
    return isinstance(element, dtype)

def _try_cast_result(self, result, dtype=None):
Expand Down Expand Up @@ -1806,11 +1806,10 @@ class FloatBlock(FloatOrComplexBlock):
_downcast_dtype = 'int64'

def _can_hold_element(self, element):
    """Return whether `element` can be stored in this float block.

    When a dtype can be inferred for `element`, floating and integer
    dtypes are accepted, excluding datetime64/timedelta64. Otherwise
    `element` is treated as a scalar: floats and ints are accepted, while
    bools, datetimes and timedeltas are rejected.

    Returns
    -------
    bool
    """
    tipo = maybe_infer_dtype_type(element)
    if tipo is not None:
        return (issubclass(tipo.type, (np.floating, np.integer)) and
                not issubclass(tipo.type, (np.datetime64, np.timedelta64)))
    return (isinstance(element, (float, int, np.floating, np.int_)) and
            not isinstance(element, (bool, np.bool_, datetime, timedelta,
                                     np.datetime64, np.timedelta64)))
Expand Down Expand Up @@ -1856,9 +1855,9 @@ class ComplexBlock(FloatOrComplexBlock):
is_complex = True

def _can_hold_element(self, element):
if is_list_like(element):
element = np.array(element)
return issubclass(element.dtype.type,
tipo = maybe_infer_dtype_type(element)
if tipo is not None:
return issubclass(tipo.type,
(np.floating, np.integer, np.complexfloating))
return (isinstance(element,
(float, int, complex, np.float_, np.int_)) and
Expand All @@ -1874,12 +1873,12 @@ class IntBlock(NumericBlock):
_can_hold_na = False

def _can_hold_element(self, element):
    """Return whether `element` can be stored in this integer block.

    When a dtype can be inferred for `element`, it must be an integer
    dtype (not datetime64/timedelta64) whose itemsize does not exceed this
    block's itemsize. Otherwise `element` must be an integer scalar.

    Returns
    -------
    bool
    """
    tipo = maybe_infer_dtype_type(element)
    if tipo is not None:
        # itemsize is read from the inferred dtype, not from `element`,
        # so objects that only expose `.dtype` are handled too.
        return (issubclass(tipo.type, np.integer) and
                not issubclass(tipo.type, (np.datetime64,
                                           np.timedelta64)) and
                self.dtype.itemsize >= tipo.itemsize)
    return is_integer(element)

def should_store(self, value):
Expand Down Expand Up @@ -1917,10 +1916,9 @@ def _box_func(self):
return lambda x: tslib.Timedelta(x, unit='ns')

def _can_hold_element(self, element):
    """Return whether `element` is timedelta-like.

    When a dtype can be inferred for `element`, it must be a timedelta64
    dtype. Otherwise `element` must be a ``datetime.timedelta`` or
    ``np.timedelta64`` scalar.

    Returns
    -------
    bool
    """
    tipo = maybe_infer_dtype_type(element)
    if tipo is not None:
        return issubclass(tipo.type, np.timedelta64)
    return isinstance(element, (timedelta, np.timedelta64))

def fillna(self, value, **kwargs):
Expand Down Expand Up @@ -2018,9 +2016,9 @@ class BoolBlock(NumericBlock):
_can_hold_na = False

def _can_hold_element(self, element):
    """Return whether `element` is boolean.

    When a dtype can be inferred for `element`, it must be a boolean
    dtype. Otherwise `element` must be a ``bool`` or ``np.bool_`` scalar.

    Returns
    -------
    bool
    """
    tipo = maybe_infer_dtype_type(element)
    if tipo is not None:
        return issubclass(tipo.type, np.bool_)
    return isinstance(element, (bool, np.bool_))

def should_store(self, value):
Expand Down Expand Up @@ -2450,7 +2448,9 @@ def _astype(self, dtype, mgr=None, **kwargs):
return super(DatetimeBlock, self)._astype(dtype=dtype, **kwargs)

def _can_hold_element(self, element):
if is_list_like(element):
tipo = maybe_infer_dtype_type(element)
if tipo is not None:
# TODO: this still uses asarray, instead of dtype.type
element = np.array(element)
return element.dtype == _NS_DTYPE or element.dtype == np.int64
return (is_integer(element) or isinstance(element, datetime) or
Expand Down
62 changes: 62 additions & 0 deletions pandas/tests/internals/test_internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# pylint: disable=W0102

from datetime import datetime, date
import operator
import sys
import pytest
import numpy as np
Expand Down Expand Up @@ -1213,3 +1214,64 @@ def assert_add_equals(val, inc, result):

with pytest.raises(ValueError):
BlockPlacement(slice(2, None, -1)).add(-1)


class DummyElement(object):
    """Minimal stand-in for a non-pandas operand that carries a dtype.

    Holds a single ``value`` plus a ``dtype`` attribute, and implements
    just enough of the array interface (``__array__``, ``astype``,
    ``view``, ``any``) to be used as the right-hand operand in the
    binary-op tests.
    """

    def __init__(self, value, dtype):
        self.value = value
        self.dtype = np.dtype(dtype)

    def __array__(self, dtype=None, copy=None):
        # NumPy's ``__array__`` protocol may pass ``dtype`` and ``copy``;
        # honour a requested dtype, otherwise use our own. ``np.array``
        # always builds a new array here, so ``copy`` needs no handling.
        target = self.dtype if dtype is None else dtype
        return np.array(self.value, dtype=target)

    def __str__(self):
        return "DummyElement({}, {})".format(self.value, self.dtype)

    def __repr__(self):
        return str(self)

    def astype(self, dtype, copy=False):
        """Relabel this element's dtype in place and return ``self``."""
        # Normalise like ``__init__`` does, so ``self.dtype`` is always an
        # ``np.dtype`` (callers may pass strings such as 'i8') and
        # ``self.dtype.type`` keeps working afterwards.
        self.dtype = np.dtype(dtype)
        return self

    def view(self, dtype):
        """Return a new element wrapping ``self.value.view(dtype)``."""
        return type(self)(self.value.view(dtype), dtype)

    def any(self, axis=None):
        """Truthiness of the wrapped value; ``axis`` is ignored."""
        return bool(self.value)


class TestCanHoldElement(object):
    """Binary ops between a DataFrame and a dtype-carrying dummy operand."""

    @pytest.mark.parametrize('value, dtype', [
        (1, 'i8'),
        (1.0, 'f8'),
        (1j, 'complex128'),
        (True, 'bool'),
        (np.timedelta64(20, 'ns'), '<m8[ns]'),
        (np.datetime64(20, 'ns'), '<M8[ns]'),
    ])
    @pytest.mark.parametrize('op', [
        operator.add,
        operator.sub,
        operator.mul,
        operator.truediv,
        operator.mod,
        operator.pow,
    ], ids=lambda x: x.__name__)
    def test_binop_other(self, op, value, dtype):
        """Result dtypes for ``op(frame, DummyElement)`` must match those
        for ``op(frame, value)`` with the plain scalar."""
        # (operator, dtype) pairs that are skipped as invalid combinations.
        bool_ops = (operator.add, operator.sub, operator.mul,
                    operator.truediv, operator.pow)
        mod_dtypes = ('i8', 'complex128', '<M8[ns]', '<m8[ns]')
        invalid = {(bad_op, 'bool') for bad_op in bool_ops}
        invalid |= {(operator.mod, bad_dtype) for bad_dtype in mod_dtypes}

        if (op, dtype) in invalid:
            pytest.skip("Invalid combination {},{}".format(op, dtype))

        elem = DummyElement(value, dtype)
        frame = pd.DataFrame({"A": [elem.value, elem.value]},
                             dtype=elem.dtype)
        result = op(frame, elem).dtypes
        expected = op(frame, value).dtypes
        assert_series_equal(result, expected)

0 comments on commit 6773694

Please sign in to comment.