BUG/Perf: Support ExtensionArrays in where (#24114)
Closes #24077
TomAugspurger authored and jreback committed Dec 10, 2018
1 parent c5a4711 commit baad046
Showing 15 changed files with 323 additions and 11 deletions.
6 changes: 6 additions & 0 deletions doc/source/whatsnew/v0.24.0.rst
@@ -675,6 +675,7 @@ changes were made:
- ``SparseDataFrame.combine`` and ``DataFrame.combine_first`` no longer support combining a sparse column with a dense column while preserving the sparse subtype. The result will be an object-dtype SparseArray.
- Setting :attr:`SparseArray.fill_value` to a fill value with a different dtype is now allowed.
- ``DataFrame[column]`` is now a :class:`Series` with sparse values, rather than a :class:`SparseSeries`, when slicing a single column with sparse values (:issue:`23559`).
- The result of :meth:`Series.where` is now a ``Series`` with sparse values, like with other extension arrays (:issue:`24077`)

Some new warnings are issued for operations that require or are likely to materialize a large dense array:

@@ -1113,6 +1114,8 @@ Deprecations
- :func:`pandas.types.is_datetimetz` is deprecated in favor of `pandas.types.is_datetime64tz` (:issue:`23917`)
- Creating a :class:`TimedeltaIndex` or :class:`DatetimeIndex` by passing range arguments `start`, `end`, and `periods` is deprecated in favor of :func:`timedelta_range` and :func:`date_range` (:issue:`23919`)
- Passing a string alias like ``'datetime64[ns, UTC]'`` as the `unit` parameter to :class:`DatetimeTZDtype` is deprecated. Use :class:`DatetimeTZDtype.construct_from_string` instead (:issue:`23990`).
- In :meth:`Series.where` with Categorical data, providing an ``other`` that is not present in the categories is deprecated. Convert the categorical to a different dtype or add the ``other`` to the categories first (:issue:`24077`).
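The second remedy named in the deprecation note (adding ``other`` to the categories first) can be sketched as follows. This is a minimal illustration, assuming a pandas installation in which ``Series.cat.add_categories`` and ``Series.where`` behave as described in this entry; the variable names are illustrative only.

```python
import numpy as np
import pandas as pd

ser = pd.Series(pd.Categorical(["a", "b", "c"]))
cond = np.array([True, False, True])

# Add the replacement value to the categories first, so that
# `where` can keep the categorical dtype instead of falling back
# to object dtype.
widened = ser.cat.add_categories(["d"])
result = widened.where(cond, "d")

print(result.tolist())
print(result.dtype)
```

The alternative remedy is to convert explicitly, for example with ``ser.astype(object)``, before calling ``where``.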


.. _whatsnew_0240.deprecations.datetimelike_int_ops:

@@ -1223,6 +1226,7 @@ Performance Improvements
- Improved performance of :meth:`DatetimeIndex.tz_localize` and various ``DatetimeIndex`` attributes with dateutil UTC timezone (:issue:`23772`)
- Fixed a performance regression on Windows with Python 3.7 of :func:`pd.read_csv` (:issue:`23516`)
- Improved performance of :class:`Categorical` constructor for `Series` objects (:issue:`23814`)
- Improved performance of :meth:`~DataFrame.where` for Categorical data (:issue:`24077`)

.. _whatsnew_0240.docs:

@@ -1249,6 +1253,7 @@ Categorical
- In :meth:`Series.unstack`, specifying a ``fill_value`` not present in the categories now raises a ``TypeError`` rather than ignoring the ``fill_value`` (:issue:`23284`)
- Bug in :meth:`DataFrame.resample` where aggregating on categorical data lost the categorical dtype (:issue:`23227`)
- Bug in many methods of the ``.str``-accessor, which always failed on calling the ``CategoricalIndex.str`` constructor (:issue:`23555`, :issue:`23556`)
- Bug in :meth:`Series.where` losing the categorical dtype for categorical data (:issue:`24077`)

Datetimelike
^^^^^^^^^^^^
@@ -1285,6 +1290,7 @@ Datetimelike
- Bug in :class:`DatetimeIndex` where calling ``np.array(dtindex, dtype=object)`` would incorrectly return an array of ``long`` objects (:issue:`23524`)
- Bug in :class:`Index` where passing a timezone-aware :class:`DatetimeIndex` and `dtype=object` would incorrectly raise a ``ValueError`` (:issue:`23524`)
- Bug in :class:`Index` where calling ``np.array(dtindex, dtype=object)`` on a timezone-naive :class:`DatetimeIndex` would return an array of ``datetime`` objects instead of :class:`Timestamp` objects, potentially losing nanosecond portions of the timestamps (:issue:`23524`)
- Bug in :meth:`Categorical.__setitem__` not allowing setting with another ``Categorical`` when both are unordered and have the same categories, but in a different order (:issue:`24142`)
- Bug in :func:`date_range` where using dates with millisecond resolution or higher could return incorrect values or the wrong number of values in the index (:issue:`24110`)

Timedelta
2 changes: 2 additions & 0 deletions pandas/core/arrays/base.py
@@ -220,6 +220,8 @@ def __setitem__(self, key, value):
# example, a string like '2018-01-01' is coerced to a datetime
# when setting on a datetime64ns array. In general, if the
# __init__ method coerces that value, then so should __setitem__
# Note, also, that Series/DataFrame.where internally use __setitem__
# on a copy of the data.
raise NotImplementedError(_not_implemented_message.format(
type(self), '__setitem__')
)
12 changes: 11 additions & 1 deletion pandas/core/arrays/categorical.py
@@ -2078,11 +2078,21 @@ def __setitem__(self, key, value):
`Categorical` does not have the same categories
"""

if isinstance(value, (ABCIndexClass, ABCSeries)):
value = value.array

# require identical categories set
if isinstance(value, Categorical):
if not value.categories.equals(self.categories):
if not is_dtype_equal(self, value):
raise ValueError("Cannot set a Categorical with another, "
"without identical categories")
if not self.categories.equals(value.categories):
new_codes = _recode_for_categories(
value.codes, value.categories, self.categories
)
value = Categorical.from_codes(new_codes,
categories=self.categories,
ordered=self.ordered)

rvalue = value if is_list_like(value) else [value]
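The recoding step in the hunk above handles a ``value`` whose categories match the target's as a set but are listed in a different order. The idea can be sketched in plain Python as follows. ``recode_for_categories`` here is a hypothetical, simplified stand-in written for illustration, not the pandas helper ``_recode_for_categories`` itself; it assumes codes are integer positions into the category list and that ``-1`` marks a missing value.

```python
def recode_for_categories(codes, old_categories, new_categories):
    """Translate codes that index `old_categories` into codes that
    index `new_categories`. Missing values (-1) stay -1, and a
    category absent from `new_categories` also maps to -1."""
    position = {cat: i for i, cat in enumerate(new_categories)}
    return [
        -1 if code == -1 else position.get(old_categories[code], -1)
        for code in codes
    ]


# The values ['b', 'a'] stored against categories ['b', 'a'] have codes [0, 1].
# Against the target's categories ['a', 'b'] the same values need new codes.
new_codes = recode_for_categories([0, 1], ["b", "a"], ["a", "b"])
print(new_codes)
```

After recoding, both arrays share one category order, so assigning codes position by position puts the intended values in place.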

2 changes: 2 additions & 0 deletions pandas/core/arrays/sparse.py
@@ -706,6 +706,8 @@ def __array__(self, dtype=None, copy=True):

def __setitem__(self, key, value):
# I suppose we could allow setting of non-fill_value elements.
# TODO(SparseArray.__setitem__): remove special cases in
# ExtensionBlock.where
msg = "SparseArray does not support item assignment via setitem"
raise TypeError(msg)

5 changes: 4 additions & 1 deletion pandas/core/indexes/category.py
@@ -501,10 +501,13 @@ def _can_reindex(self, indexer):

@Appender(_index_shared_docs['where'])
def where(self, cond, other=None):
# TODO: Investigate an alternative implementation with
# 1. copy the underlying Categorical
# 2. setitem with `cond` and `other`
# 3. Rebuild CategoricalIndex.
if other is None:
other = self._na_value
values = np.where(cond, self.values, other)

cat = Categorical(values, dtype=self.dtype)
return self._shallow_copy(cat, **self._get_attributes_dict())

82 changes: 80 additions & 2 deletions pandas/core/internals/blocks.py
@@ -28,7 +28,8 @@
from pandas.core.dtypes.dtypes import (
CategoricalDtype, DatetimeTZDtype, ExtensionDtype, PandasExtensionDtype)
from pandas.core.dtypes.generic import (
ABCDatetimeIndex, ABCExtensionArray, ABCIndexClass, ABCSeries)
ABCDataFrame, ABCDatetimeIndex, ABCExtensionArray, ABCIndexClass,
ABCSeries)
from pandas.core.dtypes.missing import (
_isna_compat, array_equivalent, is_null_datelike_scalar, isna, notna)

@@ -1886,7 +1887,6 @@ def take_nd(self, indexer, axis=0, new_mgr_locs=None, fill_tuple=None):
new_values = self.values.take(indexer, fill_value=fill_value,
allow_fill=True)

# if we are a 1-dim object, then always place at 0
if self.ndim == 1 and new_mgr_locs is None:
new_mgr_locs = [0]
else:
@@ -1967,6 +1967,57 @@ def shift(self, periods, axis=0):
placement=self.mgr_locs,
ndim=self.ndim)]

    def where(self, other, cond, align=True, errors='raise',
              try_cast=False, axis=0, transpose=False):
        # Extract the underlying arrays.
        if isinstance(other, (ABCIndexClass, ABCSeries)):
            other = other.array

        elif isinstance(other, ABCDataFrame):
            # ExtensionArrays are 1-D, so if we get here then
            # `other` should be a DataFrame with a single column.
            assert other.shape[1] == 1
            other = other.iloc[:, 0].array

        if isinstance(cond, ABCDataFrame):
            assert cond.shape[1] == 1
            cond = cond.iloc[:, 0].array

        elif isinstance(cond, (ABCIndexClass, ABCSeries)):
            cond = cond.array

        if lib.is_scalar(other) and isna(other):
            # The default `other` for Series / Frame is np.nan
            # we want to replace that with the correct NA value
            # for the type
            other = self.dtype.na_value

        if is_sparse(self.values):
            # TODO(SparseArray.__setitem__): remove this if condition
            # We need to re-infer the type of the data after doing the
            # where, for cases where the subtypes don't match
            dtype = None
        else:
            dtype = self.dtype

        try:
            result = self.values.copy()
            icond = ~cond
            if lib.is_scalar(other):
                result[icond] = other
            else:
                result[icond] = other[icond]
        except (NotImplementedError, TypeError):
            # NotImplementedError for class not implementing `__setitem__`
            # TypeError for SparseArray, which implements just to raise
            # a TypeError
            result = self._holder._from_sequence(
                np.where(cond, self.values, other),
                dtype=dtype,
            )

        return self.make_block_same_class(result, placement=self.mgr_locs)

@property
def _ftype(self):
return getattr(self.values, '_pandas_ftype', Block._ftype)
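The control flow of the new ``ExtensionBlock.where`` (copy the values, assign ``other`` where ``cond`` is false, and rebuild from a sequence if assignment is unsupported) can be sketched without pandas. Everything below is a toy model under stated assumptions: ``ReadOnlyArray`` and ``where`` are hypothetical names, ``other`` is limited to a scalar, and the fallback simply rebuilds the container instead of calling ``_from_sequence``.

```python
class ReadOnlyArray:
    """Toy container whose __setitem__ always raises, mimicking an
    array type that does not support item assignment."""

    def __init__(self, data):
        self.data = list(data)

    def copy(self):
        return ReadOnlyArray(self.data)

    def __iter__(self):
        return iter(self.data)

    def __setitem__(self, key, value):
        raise TypeError("ReadOnlyArray does not support item assignment")


def where(values, cond, other):
    """Keep entries of `values` where `cond` is true; use the scalar
    `other` elsewhere."""
    try:
        # Fast path: work on a copy and assign in place.
        result = values.copy()
        for i, keep in enumerate(cond):
            if not keep:
                result[i] = other
    except (NotImplementedError, TypeError):
        # Fallback: build a new container from the selected elements.
        result = type(values)(
            v if keep else other for v, keep in zip(values, cond)
        )
    return result


print(where([1, 2, 3], [True, False, True], 0))
print(where(ReadOnlyArray([1, 2, 3]), [True, False, True], 0).data)
```

Both calls yield the same values. Only the path differs: the fast path keeps the container's own type and dtype handling, which is why the diff prefers ``__setitem__`` and treats reconstruction as a fallback.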
@@ -2658,6 +2709,33 @@ def concat_same_type(self, to_concat, placement=None):
values, placement=placement or slice(0, len(values), 1),
ndim=self.ndim)

    def where(self, other, cond, align=True, errors='raise',
              try_cast=False, axis=0, transpose=False):
        # TODO(CategoricalBlock.where):
        # This can all be deleted in favor of ExtensionBlock.where once
        # we enforce the deprecation.
        object_msg = (
            "Implicitly converting categorical to object-dtype ndarray. "
            "One or more of the values in 'other' are not present in this "
            "categorical's categories. A future version of pandas will raise "
            "a ValueError when 'other' contains different categories.\n\n"
            "To preserve the current behavior, add the new categories to "
            "the categorical before calling 'where', or convert the "
            "categorical to a different dtype."
        )
        try:
            # Attempt to preserve the categorical dtype.
            result = super(CategoricalBlock, self).where(
                other, cond, align, errors, try_cast, axis, transpose
            )
        except (TypeError, ValueError):
            warnings.warn(object_msg, FutureWarning, stacklevel=6)
            result = self.astype(object).where(other, cond, align=align,
                                               errors=errors,
                                               try_cast=try_cast,
                                               axis=axis, transpose=transpose)
        return result


class DatetimeBlock(DatetimeLikeBlockMixin, Block):
__slots__ = ()
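The deprecation path in ``CategoricalBlock.where`` follows a try-strict, then warn-and-fall-back pattern. A minimal sketch of that pattern, using only the standard library, is given below. ``strict_where`` and ``where_with_fallback`` are hypothetical helpers, and the returned ``"category"`` / ``"object"`` strings merely stand in for the resulting dtype.

```python
import warnings


def strict_where(values, categories, cond, other):
    """Replace entries where `cond` is false, but only with a value
    that is already a category (or None for missing)."""
    if other is not None and other not in categories:
        raise ValueError("'other' is not present in the categories")
    return [v if keep else other for v, keep in zip(values, cond)], "category"


def where_with_fallback(values, categories, cond, other):
    try:
        # Preferred path: the result stays categorical.
        return strict_where(values, categories, cond, other)
    except (TypeError, ValueError):
        # Deprecated path: warn, then behave like an object-dtype array.
        warnings.warn(
            "Implicitly converting categorical to object dtype.",
            FutureWarning,
            stacklevel=2,
        )
        result = [v if keep else other for v, keep in zip(values, cond)]
        return result, "object"


# 'b' is an existing category, so no warning is emitted here.
print(where_with_fallback(["a", "b", "c"], ["a", "b", "c"],
                          [True, True, False], "b"))
```

Once the deprecation is enforced, the ``except`` branch would disappear and the error would propagate, which matches the TODO comment at the top of the method.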
94 changes: 94 additions & 0 deletions pandas/tests/arrays/categorical/test_indexing.py
@@ -3,6 +3,7 @@
import numpy as np
import pytest

import pandas as pd
from pandas import Categorical, CategoricalIndex, Index, PeriodIndex, Series
import pandas.core.common as com
from pandas.tests.arrays.categorical.common import TestCategorical
@@ -43,6 +44,45 @@ def test_setitem(self):

tm.assert_categorical_equal(c, expected)

@pytest.mark.parametrize('other', [
pd.Categorical(['b', 'a']),
pd.Categorical(['b', 'a'], categories=['b', 'a']),
])
def test_setitem_same_but_unordered(self, other):
# GH-24142
target = pd.Categorical(['a', 'b'], categories=['a', 'b'])
mask = np.array([True, False])
target[mask] = other[mask]
expected = pd.Categorical(['b', 'b'], categories=['a', 'b'])
tm.assert_categorical_equal(target, expected)

@pytest.mark.parametrize('other', [
pd.Categorical(['b', 'a'], categories=['b', 'a', 'c']),
pd.Categorical(['b', 'a'], categories=['a', 'b', 'c']),
pd.Categorical(['a', 'a'], categories=['a']),
pd.Categorical(['b', 'b'], categories=['b']),
])
def test_setitem_different_unordered_raises(self, other):
# GH-24142
target = pd.Categorical(['a', 'b'], categories=['a', 'b'])
mask = np.array([True, False])
with pytest.raises(ValueError):
target[mask] = other[mask]

@pytest.mark.parametrize('other', [
pd.Categorical(['b', 'a']),
pd.Categorical(['b', 'a'], categories=['b', 'a'], ordered=True),
pd.Categorical(['b', 'a'], categories=['a', 'b', 'c'], ordered=True),
])
def test_setitem_same_ordered_raises(self, other):
# GH-24142
target = pd.Categorical(['a', 'b'], categories=['a', 'b'],
ordered=True)
mask = np.array([True, False])

with pytest.raises(ValueError):
target[mask] = other[mask]


class TestCategoricalIndexing(object):

@@ -122,6 +162,60 @@ def test_get_indexer_non_unique(self, idx_values, key_values, key_class):
tm.assert_numpy_array_equal(expected, result)
tm.assert_numpy_array_equal(exp_miss, res_miss)

def test_where_unobserved_nan(self):
ser = pd.Series(pd.Categorical(['a', 'b']))
result = ser.where([True, False])
expected = pd.Series(pd.Categorical(['a', None],
categories=['a', 'b']))
tm.assert_series_equal(result, expected)

# all NA
ser = pd.Series(pd.Categorical(['a', 'b']))
result = ser.where([False, False])
expected = pd.Series(pd.Categorical([None, None],
categories=['a', 'b']))
tm.assert_series_equal(result, expected)

def test_where_unobserved_categories(self):
ser = pd.Series(
Categorical(['a', 'b', 'c'], categories=['d', 'c', 'b', 'a'])
)
result = ser.where([True, True, False], other='b')
expected = pd.Series(
Categorical(['a', 'b', 'b'], categories=ser.cat.categories)
)
tm.assert_series_equal(result, expected)

def test_where_other_categorical(self):
ser = pd.Series(
Categorical(['a', 'b', 'c'], categories=['d', 'c', 'b', 'a'])
)
other = Categorical(['b', 'c', 'a'], categories=['a', 'c', 'b', 'd'])
result = ser.where([True, False, True], other)
expected = pd.Series(Categorical(['a', 'c', 'c'], dtype=ser.dtype))
tm.assert_series_equal(result, expected)

def test_where_warns(self):
ser = pd.Series(Categorical(['a', 'b', 'c']))
with tm.assert_produces_warning(FutureWarning):
result = ser.where([True, False, True], 'd')

expected = pd.Series(np.array(['a', 'd', 'c'], dtype='object'))
tm.assert_series_equal(result, expected)

def test_where_ordered_differs_raises(self):
ser = pd.Series(
Categorical(['a', 'b', 'c'], categories=['d', 'c', 'b', 'a'],
ordered=True)
)
other = Categorical(['b', 'c', 'a'], categories=['a', 'c', 'b', 'd'],
ordered=True)
with tm.assert_produces_warning(FutureWarning):
result = ser.where([True, False, True], other)

expected = pd.Series(np.array(['a', 'c', 'c'], dtype=object))
tm.assert_series_equal(result, expected)


@pytest.mark.parametrize("index", [True, False])
def test_mask_with_boolean(index):
14 changes: 13 additions & 1 deletion pandas/tests/arrays/interval/test_interval.py
@@ -2,7 +2,8 @@
import numpy as np
import pytest

from pandas import Index, IntervalIndex, date_range, timedelta_range
import pandas as pd
from pandas import Index, Interval, IntervalIndex, date_range, timedelta_range
from pandas.core.arrays import IntervalArray
import pandas.util.testing as tm

@@ -50,6 +51,17 @@ def test_set_closed(self, closed, new_closed):
expected = IntervalArray.from_breaks(range(10), closed=new_closed)
tm.assert_extension_array_equal(result, expected)

@pytest.mark.parametrize('other', [
Interval(0, 1, closed='right'),
IntervalArray.from_breaks([1, 2, 3, 4], closed='right'),
])
def test_where_raises(self, other):
ser = pd.Series(IntervalArray.from_breaks([1, 2, 3, 4],
closed='left'))
match = "'value.closed' is 'right', expected 'left'."
with pytest.raises(ValueError, match=match):
ser.where([True, False, True], other=other)


class TestSetitem(object):

4 changes: 2 additions & 2 deletions pandas/tests/arrays/sparse/test_array.py
@@ -357,10 +357,10 @@ def setitem():
def setslice():
self.arr[1:5] = 2

with pytest.raises(TypeError, match="item assignment"):
with pytest.raises(TypeError, match="assignment via setitem"):
setitem()

with pytest.raises(TypeError, match="item assignment"):
with pytest.raises(TypeError, match="assignment via setitem"):
setslice()

def test_constructor_from_too_large_array(self):
15 changes: 15 additions & 0 deletions pandas/tests/arrays/test_period.py
@@ -197,6 +197,21 @@ def test_sub_period():
arr - other


# ----------------------------------------------------------------------------
# Methods

@pytest.mark.parametrize('other', [
pd.Period('2000', freq='H'),
period_array(['2000', '2001', '2000'], freq='H')
])
def test_where_different_freq_raises(other):
ser = pd.Series(period_array(['2000', '2001', '2002'], freq='D'))
cond = np.array([True, False, True])
with pytest.raises(IncompatibleFrequency,
match="Input has different freq=H"):
ser.where(cond, other)


# ----------------------------------------------------------------------------
# Printing
