Skip to content

Commit

Permalink
Squashed commit of the following:
Browse files Browse the repository at this point in the history
commit 56470c3
Author: Tom Augspurger <tom.w.augspurger@gmail.com>
Date:   Wed Dec 5 11:39:48 2018 -0600

    Fixups:

    * Ensure data generated OK.
    * Remove erroneous comments about alignment. That was user error.

commit c4604df
Author: Tom Augspurger <tom.w.augspurger@gmail.com>
Date:   Mon Dec 3 14:23:25 2018 -0600

    API: Added ExtensionArray.where

    We need some way to do `.where` on EA object for DatetimeArray. Adding it
    to the interface is, I think, the easiest way.

    Initially I started to write a version on ExtensionBlock, but it proved
    to be unwieldy. to write a version that performed well for all types.
    It *may* be possible to do using `_ndarray_values` but we'd need a few more
    things around that (missing values, converting an arbitrary array to the
    "same' ndarary_values, error handling, re-constructing). It seemed easier
    to push this down to the array.

    The implementation on ExtensionArray is readable, but likely slow since
    it'll involve a conversion to object-dtype.

    Closes pandas-dev#24077
  • Loading branch information
TomAugspurger committed Dec 5, 2018
1 parent 165f3fd commit 7ec7351
Show file tree
Hide file tree
Showing 17 changed files with 302 additions and 21 deletions.
3 changes: 3 additions & 0 deletions doc/source/whatsnew/v0.24.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -994,6 +994,7 @@ update the ``ExtensionDtype._metadata`` tuple to match the signature of your
- :meth:`Series.astype` and :meth:`DataFrame.astype` now dispatch to :meth:`ExtensionArray.astype` (:issue:`21185:`).
- Slicing a single row of a ``DataFrame`` with multiple ExtensionArrays of the same type now preserves the dtype, rather than coercing to object (:issue:`22784`)
- Added :meth:`pandas.api.types.register_extension_dtype` to register an extension type with pandas (:issue:`22664`)
- Added :meth:`pandas.api.extensions.ExtensionArray.where` (:issue:`24077`)
- Bug when concatenating multiple ``Series`` with different extension dtypes not casting to object dtype (:issue:`22994`)
- Series backed by an ``ExtensionArray`` now work with :func:`util.hash_pandas_object` (:issue:`23066`)
- Updated the ``.type`` attribute for ``PeriodDtype``, ``DatetimeTZDtype``, and ``IntervalDtype`` to be instances of the dtype (``Period``, ``Timestamp``, and ``Interval`` respectively) (:issue:`22938`)
Expand Down Expand Up @@ -1236,6 +1237,7 @@ Performance Improvements
- Improved performance of :meth:`DatetimeIndex.normalize` and :meth:`Timestamp.normalize` for timezone naive or UTC datetimes (:issue:`23634`)
- Improved performance of :meth:`DatetimeIndex.tz_localize` and various ``DatetimeIndex`` attributes with dateutil UTC timezone (:issue:`23772`)
- Improved performance of :class:`Categorical` constructor for `Series` objects (:issue:`23814`)
- Improved performance of :meth:`~DataFrame.where` for Categorical data (:issue:`24077`)

.. _whatsnew_0240.docs:

Expand All @@ -1262,6 +1264,7 @@ Categorical
- In meth:`Series.unstack`, specifying a ``fill_value`` not present in the categories now raises a ``TypeError`` rather than ignoring the ``fill_value`` (:issue:`23284`)
- Bug when resampling :meth:`Dataframe.resample()` and aggregating on categorical data, the categorical dtype was getting lost. (:issue:`23227`)
- Bug in many methods of the ``.str``-accessor, which always failed on calling the ``CategoricalIndex.str`` constructor (:issue:`23555`, :issue:`23556`)
- Bug in :meth:`Series.where` losing the categorical dtype for categorical data (:issue:`24077`)

Datetimelike
^^^^^^^^^^^^
Expand Down
35 changes: 35 additions & 0 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class ExtensionArray(object):
* unique
* factorize / _values_for_factorize
* argsort / _values_for_argsort
* where
The remaining methods implemented on this class should be performant,
as they only compose abstract methods. Still, a more efficient
Expand Down Expand Up @@ -661,6 +662,40 @@ def take(self, indices, allow_fill=False, fill_value=None):
# pandas.api.extensions.take
raise AbstractMethodError(self)

def where(self, cond, other):
"""
Replace values where the condition is False.
Parameters
----------
cond : ndarray or ExtensionArray
The mask indicating which values should be kept (True)
or replaced from `other` (False).
other : ndarray, ExtensionArray, or scalar
Entries where `cond` is False are replaced with
corresponding value from `other`.
Notes
-----
Note that `cond` and `other` *cannot* be a Series, Index, or callable.
When used from, e.g., :meth:`Series.where`, pandas will unbox
Series and Indexes, and will apply callables before they arrive here.
Returns
-------
ExtensionArray
Same dtype as the original.
See Also
--------
Series.where : Similar method for Series.
DataFrame.where : Similar method for DataFrame.
"""
return type(self)._from_sequence(np.where(cond, self, other),
dtype=self.dtype,
copy=False)

def copy(self, deep=False):
# type: (bool) -> ExtensionArray
"""
Expand Down
43 changes: 43 additions & 0 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# pylint: disable=E1101,W0232

import reprlib
import textwrap
from warnings import warn

Expand Down Expand Up @@ -1906,6 +1907,48 @@ def take_nd(self, indexer, allow_fill=None, fill_value=None):

take = take_nd

def where(self, cond, other):
# n.b. this now preserves the type
codes = self._codes
object_msg = (
"Implicitly converting categorical to object-dtype ndarray. "
"The values `{}' are not present in this categorical's "
"categories. A future version of pandas will raise a ValueError "
"when 'other' contains different categories.\n\n"
"To preserve the current behavior, add the new categories to "
"the categorical before calling 'where', or convert the "
"categorical to a different dtype."
)

if is_scalar(other) and isna(other):
other = -1
elif is_scalar(other):
item = self.categories.get_indexer([other]).item()

if item == -1:
# note: when removing this, also remove CategoricalBlock.where
warn(object_msg.format(other), FutureWarning, stacklevel=2)
return np.where(cond, self, other)

other = item

elif is_categorical_dtype(other):
if not is_dtype_equal(self, other):
extra = list(other.categories.difference(self.categories))
warn(object_msg.format(reprlib.repr(extra)), FutureWarning,
stacklevel=2)
return np.where(cond, self, other)
other = _get_codes_for_values(other, self.categories)
# get the codes from other that match our categories
pass
else:
other = np.where(isna(other), -1, other)

new_codes = np.where(cond, codes, other)
return type(self).from_codes(new_codes,
categories=self.categories,
ordered=self.ordered)

def _slice(self, slicer):
"""
Return a slice of myself.
Expand Down
11 changes: 11 additions & 0 deletions pandas/core/arrays/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,17 @@ def take(self, indices, allow_fill=False, fill_value=None, axis=None,

return self._shallow_copy(left_take, right_take)

def where(self, cond, other):
if is_scalar(other) and isna(other):
lother = rother = other
else:
self._check_closed_matches(other, name='other')
lother = other.left
rother = other.right
left = np.where(cond, self.left, lother)
right = np.where(cond, self.right, rother)
return self._shallow_copy(left, right)

def value_counts(self, dropna=True):
"""
Returns a Series containing counts of each interval.
Expand Down
27 changes: 17 additions & 10 deletions pandas/core/arrays/period.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np

from pandas._libs import lib
from pandas._libs.tslibs import NaT, iNaT, period as libperiod
from pandas._libs.tslibs.fields import isleapyear_arr
from pandas._libs.tslibs.period import (
Expand Down Expand Up @@ -242,16 +243,6 @@ def _generate_range(cls, start, end, periods, freq, fields):

return subarr, freq

# -----------------------------------------------------------------
# DatetimeLike Interface
def _unbox_scalar(self, value):
assert isinstance(value, self._scalar_type), value
return value.ordinal

def _scalar_from_string(self, value):
assert isinstance(value, self._scalar_type), value
return Period(value, freq=self.freq)

def _check_compatible_with(self, other):
if self.freqstr != other.freqstr:
msg = DIFFERENT_FREQ_INDEX.format(self.freqstr, other.freqstr)
Expand Down Expand Up @@ -357,6 +348,22 @@ def to_timestamp(self, freq=None, how='start'):
# --------------------------------------------------------------------
# Array-like / EA-Interface Methods

def where(self, cond, other):
# TODO(DatetimeArray): move to DatetimeLikeArrayMixin
# n.b. _ndarray_values candidate.
i8 = self.asi8
if lib.is_scalar(other):
if isna(other):
other = iNaT
elif isinstance(other, Period):
self._check_compatible_with(other)
other = other.ordinal
elif isinstance(other, type(self)):
self._check_compatible_with(other)
other = other.asi8
result = np.where(cond, i8, other)
return type(self)._simple_new(result, dtype=self.dtype)

def _formatter(self, boxed=False):
if boxed:
return str
Expand Down
14 changes: 14 additions & 0 deletions pandas/core/arrays/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -1063,6 +1063,20 @@ def take(self, indices, allow_fill=False, fill_value=None):
return type(self)(result, fill_value=self.fill_value, kind=self.kind,
**kwargs)

def where(self, cond, other):
if is_scalar(other):
result_dtype = np.result_type(self.dtype.subtype, other)
elif isinstance(other, type(self)):
result_dtype = np.result_type(self.dtype.subtype,
other.dtype.subtype)
else:
result_dtype = np.result_type(self.dtype.subtype, other.dtype)

dtype = self.dtype.update_dtype(result_dtype)
# TODO: avoid converting to dense.
values = np.where(cond, self, other)
return type(self)(values, dtype=dtype)

def _take_with_fill(self, indices, fill_value=None):
if fill_value is None:
fill_value = self.dtype.na_value
Expand Down
5 changes: 5 additions & 0 deletions pandas/core/dtypes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ class _DtypeOpsMixin(object):
na_value = np.nan
_metadata = ()

@property
def _ndarray_na_value(self):
"""Private method internal to pandas"""
raise AbstractMethodError(self)

def __eq__(self, other):
"""Check whether 'other' is equal to self.
Expand Down
6 changes: 1 addition & 5 deletions pandas/core/indexes/category.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,11 +501,7 @@ def _can_reindex(self, indexer):

@Appender(_index_shared_docs['where'])
def where(self, cond, other=None):
if other is None:
other = self._na_value
values = np.where(cond, self.values, other)

cat = Categorical(values, dtype=self.dtype)
cat = self.values.where(cond, other=other)
return self._shallow_copy(cat, **self._get_attributes_dict())

def reindex(self, target, method=None, level=None, limit=None,
Expand Down
38 changes: 37 additions & 1 deletion pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
from pandas.core.dtypes.dtypes import (
CategoricalDtype, DatetimeTZDtype, ExtensionDtype, PandasExtensionDtype)
from pandas.core.dtypes.generic import (
ABCDatetimeIndex, ABCExtensionArray, ABCIndexClass, ABCSeries)
ABCDataFrame, ABCDatetimeIndex, ABCExtensionArray, ABCIndexClass,
ABCSeries)
from pandas.core.dtypes.inference import is_scalar
from pandas.core.dtypes.missing import (
_isna_compat, array_equivalent, is_null_datelike_scalar, isna, notna)
Expand Down Expand Up @@ -1970,6 +1971,30 @@ def shift(self, periods, axis=0):
placement=self.mgr_locs,
ndim=self.ndim)]

def where(self, other, cond, align=True, errors='raise',
try_cast=False, axis=0, transpose=False):
if isinstance(other, (ABCIndexClass, ABCSeries)):
other = other.array

if isinstance(cond, ABCDataFrame):
assert cond.shape[1] == 1
cond = cond.iloc[:, 0].array

if isinstance(other, ABCDataFrame):
assert other.shape[1] == 1
other = other.iloc[:, 0].array

if isinstance(cond, (ABCIndexClass, ABCSeries)):
cond = cond.array

if lib.is_scalar(other) and isna(other):
# The default `other` for Series / Frame is np.nan
# we want to replace that with the correct NA value
# for the type
other = self.dtype.na_value
result = self.values.where(cond, other)
return self.make_block_same_class(result, placement=self.mgr_locs)

@property
def _ftype(self):
return getattr(self.values, '_pandas_ftype', Block._ftype)
Expand Down Expand Up @@ -2675,6 +2700,17 @@ def concat_same_type(self, to_concat, placement=None):
values, placement=placement or slice(0, len(values), 1),
ndim=self.ndim)

def where(self, other, cond, align=True, errors='raise',
try_cast=False, axis=0, transpose=False):
result = super(CategoricalBlock, self).where(
other, cond, align, errors, try_cast, axis, transpose
)
if result.values.dtype != self.values.dtype:
# For backwards compatability, we allow upcasting to object.
# This fallback will be removed in the future.
result = result.astype(object)
return result


class DatetimeBlock(DatetimeLikeBlockMixin, Block):
__slots__ = ()
Expand Down
32 changes: 32 additions & 0 deletions pandas/tests/arrays/categorical/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,38 @@ def test_get_indexer_non_unique(self, idx_values, key_values, key_class):
tm.assert_numpy_array_equal(expected, result)
tm.assert_numpy_array_equal(exp_miss, res_miss)

def test_where_unobserved_categories(self):
arr = Categorical(['a', 'b', 'c'], categories=['d', 'c', 'b', 'a'])
result = arr.where([True, True, False], other='b')
expected = Categorical(['a', 'b', 'b'], categories=arr.categories)
tm.assert_categorical_equal(result, expected)

def test_where_other_categorical(self):
arr = Categorical(['a', 'b', 'c'], categories=['d', 'c', 'b', 'a'])
other = Categorical(['b', 'c', 'a'], categories=['a', 'c', 'b', 'd'])
result = arr.where([True, False, True], other)
expected = Categorical(['a', 'c', 'c'], dtype=arr.dtype)
tm.assert_categorical_equal(result, expected)

def test_where_warns(self):
arr = Categorical(['a', 'b', 'c'])
with tm.assert_produces_warning(FutureWarning):
result = arr.where([True, False, True], 'd')

expected = np.array(['a', 'd', 'c'], dtype='object')
tm.assert_numpy_array_equal(result, expected)

def test_where_ordered_differs_rasies(self):
arr = Categorical(['a', 'b', 'c'], categories=['d', 'c', 'b', 'a'],
ordered=True)
other = Categorical(['b', 'c', 'a'], categories=['a', 'c', 'b', 'd'],
ordered=True)
with tm.assert_produces_warning(FutureWarning):
result = arr.where([True, False, True], other)

expected = np.array(['a', 'c', 'c'], dtype=object)
tm.assert_numpy_array_equal(result, expected)


@pytest.mark.parametrize("index", [True, False])
def test_mask_with_boolean(index):
Expand Down
12 changes: 11 additions & 1 deletion pandas/tests/arrays/interval/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
import pytest

from pandas import Index, IntervalIndex, date_range, timedelta_range
from pandas import Index, Interval, IntervalIndex, date_range, timedelta_range
from pandas.core.arrays import IntervalArray
import pandas.util.testing as tm

Expand Down Expand Up @@ -50,6 +50,16 @@ def test_set_closed(self, closed, new_closed):
expected = IntervalArray.from_breaks(range(10), closed=new_closed)
tm.assert_extension_array_equal(result, expected)

@pytest.mark.parametrize('other', [
Interval(0, 1, closed='right'),
IntervalArray.from_breaks([1, 2, 3, 4], closed='right'),
])
def test_where_raises(self, other):
arr = IntervalArray.from_breaks([1, 2, 3, 4], closed='left')
match = "'other.closed' is 'right', expected 'left'."
with pytest.raises(ValueError, match=match):
arr.where([True, False, True], other=other)


class TestSetitem(object):

Expand Down
15 changes: 15 additions & 0 deletions pandas/tests/arrays/test_period.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,21 @@ def test_sub_period():
arr - other


# ----------------------------------------------------------------------------
# Methods

@pytest.mark.parametrize('other', [
pd.Period('2000', freq='H'),
period_array(['2000', '2001', '2000'], freq='H')
])
def test_where_different_freq_raises(other):
arr = period_array(['2000', '2001', '2002'], freq='D')
cond = np.array([True, False, True])
with pytest.raises(IncompatibleFrequency,
match="Input has different freq=H"):
arr.where(cond, other)


# ----------------------------------------------------------------------------
# Printing

Expand Down
Loading

0 comments on commit 7ec7351

Please sign in to comment.