diff --git a/doc/source/whatsnew/v0.24.0.rst b/doc/source/whatsnew/v0.24.0.rst
index eab5956735f12..5b2a5314108e9 100644
--- a/doc/source/whatsnew/v0.24.0.rst
+++ b/doc/source/whatsnew/v0.24.0.rst
@@ -994,6 +994,7 @@ update the ``ExtensionDtype._metadata`` tuple to match the signature of your
 - :meth:`Series.astype` and :meth:`DataFrame.astype` now dispatch to :meth:`ExtensionArray.astype` (:issue:`21185:`).
 - Slicing a single row of a ``DataFrame`` with multiple ExtensionArrays of the same type now preserves the dtype, rather than coercing to object (:issue:`22784`)
 - Added :meth:`pandas.api.types.register_extension_dtype` to register an extension type with pandas (:issue:`22664`)
+- Added :meth:`pandas.api.extensions.ExtensionArray.where` (:issue:`24077`)
 - Bug when concatenating multiple ``Series`` with different extension dtypes not casting to object dtype (:issue:`22994`)
 - Series backed by an ``ExtensionArray`` now work with :func:`util.hash_pandas_object` (:issue:`23066`)
 - Updated the ``.type`` attribute for ``PeriodDtype``, ``DatetimeTZDtype``, and ``IntervalDtype`` to be instances of the dtype (``Period``, ``Timestamp``, and ``Interval`` respectively) (:issue:`22938`)
@@ -1236,6 +1237,7 @@ Performance Improvements
 - Improved performance of :meth:`DatetimeIndex.normalize` and :meth:`Timestamp.normalize` for timezone naive or UTC datetimes (:issue:`23634`)
 - Improved performance of :meth:`DatetimeIndex.tz_localize` and various ``DatetimeIndex`` attributes with dateutil UTC timezone (:issue:`23772`)
 - Improved performance of :class:`Categorical` constructor for `Series` objects (:issue:`23814`)
+- Improved performance of :meth:`~DataFrame.where` for Categorical data (:issue:`24077`)
 
 .. _whatsnew_0240.docs:
 
@@ -1262,6 +1264,7 @@ Categorical
 - In meth:`Series.unstack`, specifying a ``fill_value`` not present in the categories now raises a ``TypeError`` rather than ignoring the ``fill_value`` (:issue:`23284`)
 - Bug when resampling :meth:`Dataframe.resample()` and aggregating on categorical data, the categorical dtype was getting lost. (:issue:`23227`)
 - Bug in many methods of the ``.str``-accessor, which always failed on calling the ``CategoricalIndex.str`` constructor (:issue:`23555`, :issue:`23556`)
+- Bug in :meth:`Series.where` losing the categorical dtype for categorical data (:issue:`24077`)
 
 Datetimelike
 ^^^^^^^^^^^^
diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py
index 9c6aa4a12923f..294c5e99d66f4 100644
--- a/pandas/core/arrays/base.py
+++ b/pandas/core/arrays/base.py
@@ -64,6 +64,7 @@ class ExtensionArray(object):
     * unique
     * factorize / _values_for_factorize
     * argsort / _values_for_argsort
+    * where
 
     The remaining methods implemented on this class should be performant,
     as they only compose abstract methods. Still, a more efficient
@@ -661,6 +662,39 @@ def take(self, indices, allow_fill=False, fill_value=None):
         # pandas.api.extensions.take
         raise AbstractMethodError(self)
 
+    def where(self, cond, other):
+        """
+        Replace values where the condition is False.
+
+        Parameters
+        ----------
+        cond : ndarray or ExtensionArray
+            The mask indicating which values should be kept (True)
+            or replaced from `other` (False).
+        other : ndarray, ExtensionArray, or scalar
+            Entries where `cond` is False are replaced with
+            corresponding value from `other`.
+
+        Returns
+        -------
+        ExtensionArray
+            Same dtype as the original.
+
+        See Also
+        --------
+        Series.where : Similar method for Series.
+        DataFrame.where : Similar method for DataFrame.
+
+        Notes
+        -----
+        Note that `cond` and `other` *cannot* be a Series, Index, or callable.
+        When used from, e.g., :meth:`Series.where`, pandas will unbox
+        Series and Indexes, and will apply callables before they arrive here.
+        """
+        return type(self)._from_sequence(np.where(cond, self, other),
+                                         dtype=self.dtype,
+                                         copy=False)
+
     def copy(self, deep=False):
         # type: (bool) -> ExtensionArray
         """
diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py
index 938ca53b5fdce..76d956861a9b6 100644
--- a/pandas/core/arrays/categorical.py
+++ b/pandas/core/arrays/categorical.py
@@ -1906,6 +1906,41 @@ def take_nd(self, indexer, allow_fill=None, fill_value=None):
 
     take = take_nd
 
+    def where(self, cond, other):
+        """
+        Replace values where `cond` is False, preserving the dtype.
+
+        Raises ValueError if a scalar `other` is not one of the
+        categories, and TypeError if a categorical `other` has a
+        different dtype.
+        """
+        codes = self._codes
+
+        if is_scalar(other) and isna(other):
+            # a missing value is written as the code -1
+            other = -1
+        elif is_scalar(other):
+            item = self.categories.get_indexer([other]).item()
+
+            if item == -1:
+                raise ValueError("The value '{}' is not present in "
+                                 "this Categorical's categories".format(other))
+            other = item
+
+        elif is_categorical_dtype(other):
+            if not is_dtype_equal(self, other):
+                raise TypeError("The type of 'other' does not match.")
+            # get the codes from other that match our categories
+            other = _get_codes_for_values(other, self.categories)
+        else:
+            # NOTE(review): this branch treats `other` as codes -- confirm
+            other = np.where(isna(other), -1, other)
+
+        new_codes = np.where(cond, codes, other)
+        return type(self).from_codes(new_codes,
+                                     categories=self.categories,
+                                     ordered=self.ordered)
+
     def _slice(self, slicer):
         """
         Return a slice of myself.
diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py
index 785fb02c4d95d..fd35f01641750 100644
--- a/pandas/core/arrays/interval.py
+++ b/pandas/core/arrays/interval.py
@@ -777,6 +777,24 @@ def take(self, indices, allow_fill=False, fill_value=None, axis=None,
 
         return self._shallow_copy(left_take, right_take)
 
+    def where(self, cond, other):
+        """
+        Replace values where `cond` is False with `other`.
+
+        `other` is either a scalar NA, or an object with ``left`` and
+        ``right`` attributes whose ``closed`` is checked against ours.
+        """
+        if is_scalar(other) and isna(other):
+            lother = other
+            rother = other
+        else:
+            self._check_closed_matches(other, name='other')
+            lother = other.left
+            rother = other.right
+        left = np.where(cond, self.left, lother)
+        right = np.where(cond, self.right, rother)
+        return self._shallow_copy(left, right)
+
     def value_counts(self, dropna=True):
         """
         Returns a Series containing counts of each interval.
diff --git a/pandas/core/arrays/period.py b/pandas/core/arrays/period.py
index 4d466ef7281b7..0dff368fcd5f0 100644
--- a/pandas/core/arrays/period.py
+++ b/pandas/core/arrays/period.py
@@ -4,6 +4,7 @@
 
 import numpy as np
 
+from pandas._libs import lib
 from pandas._libs.tslibs import NaT, iNaT, period as libperiod
 from pandas._libs.tslibs.fields import isleapyear_arr
 from pandas._libs.tslibs.period import (
@@ -241,6 +242,12 @@ def _generate_range(cls, start, end, periods, freq, fields):
 
         return subarr, freq
 
+    def _check_compatible_with(self, other):
+        """Raise IncompatibleFrequency if `other` has a different freq."""
+        if self.freqstr != other.freqstr:
+            msg = DIFFERENT_FREQ_INDEX.format(self.freqstr, other.freqstr)
+            raise IncompatibleFrequency(msg)
+
     # --------------------------------------------------------------------
     # Data / Attributes
 
@@ -341,6 +348,34 @@ def to_timestamp(self, freq=None, how='start'):
     # --------------------------------------------------------------------
     # Array-like / EA-Interface Methods
 
+    def where(self, cond, other):
+        """
+        Replace values where `cond` is False with `other`.
+
+        A Period or PeriodArray `other` must have the same freq, otherwise
+        IncompatibleFrequency is raised. Any other non-NA scalar raises
+        TypeError.
+        """
+        # TODO(DatetimeArray): move to DatetimeLikeArrayMixin
+        # n.b. _ndarray_values candidate.
+        i8 = self.asi8
+        if lib.is_scalar(other):
+            if isna(other):
+                other = iNaT
+            elif isinstance(other, Period):
+                self._check_compatible_with(other)
+                other = other.ordinal
+            else:
+                # e.g. an integer would silently be used as an ordinal
+                msg = ("'other' must be NA or a Period, "
+                       "got '{}'".format(type(other).__name__))
+                raise TypeError(msg)
+        elif isinstance(other, type(self)):
+            self._check_compatible_with(other)
+            other = other.asi8
+        result = np.where(cond, i8, other)
+        return type(self)._simple_new(result, dtype=self.dtype)
+
     def _formatter(self, boxed=False):
         if boxed:
             return str
diff --git a/pandas/core/arrays/sparse.py b/pandas/core/arrays/sparse.py
index 134466d769ada..3897b4efc480b 100644
--- a/pandas/core/arrays/sparse.py
+++ b/pandas/core/arrays/sparse.py
@@ -1063,6 +1063,26 @@ def take(self, indices, allow_fill=False, fill_value=None):
         return type(self)(result, fill_value=self.fill_value, kind=self.kind,
                           **kwargs)
 
+    def where(self, cond, other):
+        """
+        Replace values where `cond` is False with `other`.
+
+        The result's subtype is the NumPy result type of this array's
+        subtype and `other`, so the dtype may change.
+        """
+        if is_scalar(other):
+            result_dtype = np.result_type(self.dtype.subtype, other)
+        elif isinstance(other, type(self)):
+            result_dtype = np.result_type(self.dtype.subtype,
+                                          other.dtype.subtype)
+        else:
+            result_dtype = np.result_type(self.dtype.subtype, other.dtype)
+
+        dtype = self.dtype.update_dtype(result_dtype)
+        # TODO: avoid converting to dense.
+        values = np.where(cond, self, other)
+        return type(self)(values, dtype=dtype)
+
     def _take_with_fill(self, indices, fill_value=None):
         if fill_value is None:
             fill_value = self.dtype.na_value
diff --git a/pandas/core/dtypes/base.py b/pandas/core/dtypes/base.py
index aa81e88abf28e..e271e11398678 100644
--- a/pandas/core/dtypes/base.py
+++ b/pandas/core/dtypes/base.py
@@ -26,6 +26,11 @@ class _DtypeOpsMixin(object):
     na_value = np.nan
     _metadata = ()
 
+    @property
+    def _ndarray_na_value(self):
+        """Private property internal to pandas."""
+        raise AbstractMethodError(self)
+
     def __eq__(self, other):
         """Check whether 'other' is equal to self.
 
diff --git a/pandas/core/indexes/category.py b/pandas/core/indexes/category.py
index 6d26894514a9c..94f932d5e8123 100644
--- a/pandas/core/indexes/category.py
+++ b/pandas/core/indexes/category.py
@@ -501,11 +501,7 @@ def _can_reindex(self, indexer):
 
     @Appender(_index_shared_docs['where'])
     def where(self, cond, other=None):
-        if other is None:
-            other = self._na_value
-        values = np.where(cond, self.values, other)
-
-        cat = Categorical(values, dtype=self.dtype)
+        cat = self.values.where(cond, other=other)
         return self._shallow_copy(cat, **self._get_attributes_dict())
 
     def reindex(self, target, method=None, level=None, limit=None,
diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py
index 198e832ca4603..40952c0ae0688 100644
--- a/pandas/core/internals/blocks.py
+++ b/pandas/core/internals/blocks.py
@@ -28,7 +28,8 @@
 from pandas.core.dtypes.dtypes import (
     CategoricalDtype, DatetimeTZDtype, ExtensionDtype, PandasExtensionDtype)
 from pandas.core.dtypes.generic import (
-    ABCDatetimeIndex, ABCExtensionArray, ABCIndexClass, ABCSeries)
+    ABCDataFrame, ABCDatetimeIndex, ABCExtensionArray, ABCIndexClass,
+    ABCSeries)
 from pandas.core.dtypes.missing import (
     _isna_compat, array_equivalent, is_null_datelike_scalar, isna, notna)
 
@@ -1967,6 +1968,32 @@ def shift(self, periods, axis=0):
                                            placement=self.mgr_locs,
                                            ndim=self.ndim)]
 
+    def where(self, other, cond, align=True, errors='raise',
+              try_cast=False, axis=0, transpose=False):
+        """Unbox `other` and `cond`, then dispatch to ``self.values.where``."""
+        # align, errors, try_cast, axis and transpose are not used here
+        if isinstance(other, (ABCIndexClass, ABCSeries)):
+            other = other.array
+
+        if isinstance(cond, ABCDataFrame):
+            assert cond.shape[1] == 1
+            cond = cond.iloc[:, 0].array
+
+        if isinstance(other, ABCDataFrame):
+            assert other.shape[1] == 1
+            other = other.iloc[:, 0].array
+
+        if isinstance(cond, (ABCIndexClass, ABCSeries)):
+            cond = cond.array
+
+        if lib.is_scalar(other) and isna(other):
+            # The default `other` for Series / Frame is np.nan
+            # we want to replace that with the correct NA value
+            # for the type
+            other = self.dtype.na_value
+        result = self.values.where(cond, other)
+        return self.make_block_same_class(result, placement=self.mgr_locs)
+
     @property
     def _ftype(self):
         return getattr(self.values, '_pandas_ftype', Block._ftype)
diff --git a/pandas/tests/arrays/categorical/test_indexing.py b/pandas/tests/arrays/categorical/test_indexing.py
index 8df5728f7d895..73a09b9a67e71 100644
--- a/pandas/tests/arrays/categorical/test_indexing.py
+++ b/pandas/tests/arrays/categorical/test_indexing.py
@@ -122,6 +122,32 @@ def test_get_indexer_non_unique(self, idx_values, key_values, key_class):
         tm.assert_numpy_array_equal(expected, result)
         tm.assert_numpy_array_equal(exp_miss, res_miss)
 
+    def test_where_raises(self):
+        arr = Categorical(['a', 'b', 'c'])
+        with pytest.raises(ValueError, match="The value 'd'"):
+            arr.where([True, False, True], 'd')
+
+    def test_where_unobserved_categories(self):
+        arr = Categorical(['a', 'b', 'c'], categories=['d', 'c', 'b', 'a'])
+        result = arr.where([True, True, False], other='b')
+        expected = Categorical(['a', 'b', 'b'], categories=arr.categories)
+        tm.assert_categorical_equal(result, expected)
+
+    def test_where_other_categorical(self):
+        arr = Categorical(['a', 'b', 'c'], categories=['d', 'c', 'b', 'a'])
+        other = Categorical(['b', 'c', 'a'], categories=['a', 'c', 'b', 'd'])
+        result = arr.where([True, False, True], other)
+        expected = Categorical(['a', 'c', 'c'], dtype=arr.dtype)
+        tm.assert_categorical_equal(result, expected)
+
+    def test_where_ordered_differs_raises(self):
+        arr = Categorical(['a', 'b', 'c'], categories=['d', 'c', 'b', 'a'],
+                          ordered=True)
+        other = Categorical(['b', 'c', 'a'], categories=['a', 'c', 'b', 'd'],
+                            ordered=True)
+        with pytest.raises(TypeError, match="The type of"):
+            arr.where([True, False, True], other)
+
 
 @pytest.mark.parametrize("index", [True, False])
 def test_mask_with_boolean(index):
diff --git a/pandas/tests/arrays/interval/test_interval.py b/pandas/tests/arrays/interval/test_interval.py
index a04579dbbb6b1..1bc8f7087e54e 100644
--- a/pandas/tests/arrays/interval/test_interval.py
+++ b/pandas/tests/arrays/interval/test_interval.py
@@ -2,7 +2,7 @@
 import numpy as np
 import pytest
 
-from pandas import Index, IntervalIndex, date_range, timedelta_range
+from pandas import Index, Interval, IntervalIndex, date_range, timedelta_range
 from pandas.core.arrays import IntervalArray
 import pandas.util.testing as tm
 
@@ -50,6 +50,16 @@ def test_set_closed(self, closed, new_closed):
         expected = IntervalArray.from_breaks(range(10), closed=new_closed)
         tm.assert_extension_array_equal(result, expected)
 
+    @pytest.mark.parametrize('other', [
+        Interval(0, 1, closed='right'),
+        IntervalArray.from_breaks([1, 2, 3, 4], closed='right'),
+    ])
+    def test_where_raises(self, other):
+        arr = IntervalArray.from_breaks([1, 2, 3, 4], closed='left')
+        match = "'other.closed' is 'right', expected 'left'."
+        with pytest.raises(ValueError, match=match):
+            arr.where([True, False, True], other=other)
+
 
 class TestSetitem(object):
 
diff --git a/pandas/tests/arrays/test_period.py b/pandas/tests/arrays/test_period.py
index bf139bb0ce616..f439a268d08ed 100644
--- a/pandas/tests/arrays/test_period.py
+++ b/pandas/tests/arrays/test_period.py
@@ -197,6 +197,21 @@ def test_sub_period():
         arr - other
 
 
+# ----------------------------------------------------------------------------
+# Methods
+
+@pytest.mark.parametrize('other', [
+    pd.Period('2000', freq='H'),
+    period_array(['2000', '2001', '2000'], freq='H')
+])
+def test_where_different_freq_raises(other):
+    arr = period_array(['2000', '2001', '2002'], freq='D')
+    cond = np.array([True, False, True])
+    with pytest.raises(IncompatibleFrequency,
+                       match="Input has different freq=H"):
+        arr.where(cond, other)
+
+
 # ----------------------------------------------------------------------------
 # Printing
 
diff --git a/pandas/tests/extension/base/methods.py b/pandas/tests/extension/base/methods.py
index e9a89c1af2f22..c3654ffbd64dc 100644
--- a/pandas/tests/extension/base/methods.py
+++ b/pandas/tests/extension/base/methods.py
@@ -198,3 +198,40 @@ def test_hash_pandas_object_works(self, data, as_frame):
         a = pd.util.hash_pandas_object(data)
         b = pd.util.hash_pandas_object(data)
         self.assert_equal(a, b)
+
+    @pytest.mark.parametrize("as_frame", [True, False])
+    def test_where_series(self, data, na_value, as_frame):
+        assert data[0] != data[1]
+        cls = type(data)
+        a, b = data[:2]
+
+        ser = pd.Series(cls._from_sequence([a, a, b, b], dtype=data.dtype))
+        cond = np.array([True, True, False, False])
+
+        if as_frame:
+            ser = ser.to_frame(name='a')
+            # TODO: alignment is broken for ndarray `cond`
+            cond = pd.DataFrame({"a": cond})
+
+        result = ser.where(cond)
+        expected = pd.Series(cls._from_sequence([a, a, na_value, na_value],
+                                                dtype=data.dtype))
+
+        if as_frame:
+            expected = expected.to_frame(name='a')
+        self.assert_equal(result, expected)
+
+        # array other
+        cond = np.array([True, False, True, True])
+        other = cls._from_sequence([a, b, a, b], dtype=data.dtype)
+        if as_frame:
+            # TODO: alignment is broken for array `other`
+            other = pd.DataFrame({"a": other})
+            # TODO: alignment is broken for ndarray `cond`
+            cond = pd.DataFrame({"a": cond})
+        result = ser.where(cond, other)
+        expected = pd.Series(cls._from_sequence([a, b, b, b],
+                                                dtype=data.dtype))
+        if as_frame:
+            expected = expected.to_frame(name='a')
+        self.assert_equal(result, expected)
diff --git a/pandas/tests/extension/json/test_json.py b/pandas/tests/extension/json/test_json.py
index a941b562fe1ec..4571f3f6d4040 100644
--- a/pandas/tests/extension/json/test_json.py
+++ b/pandas/tests/extension/json/test_json.py
@@ -221,6 +221,13 @@ def test_combine_add(self, data_repeated):
     def test_hash_pandas_object_works(self, data, kind):
         super().test_hash_pandas_object_works(data, kind)
 
+    @pytest.mark.skip(reason="broadcasting error")
+    def test_where_series(self, data, na_value):
+        # Fails with
+        # *** ValueError: operands could not be broadcast together
+        # with shapes (4,) (4,) (0,)
+        super().test_where_series(data, na_value)
+
 
 class TestCasting(BaseJSON, base.BaseCastingTests):
     @pytest.mark.skip(reason="failing on np.array(self, dtype=str)")
diff --git a/pandas/tests/extension/test_sparse.py b/pandas/tests/extension/test_sparse.py
index 891e5f4dd9a95..75327a8b9affe 100644
--- a/pandas/tests/extension/test_sparse.py
+++ b/pandas/tests/extension/test_sparse.py
@@ -255,6 +255,28 @@ def test_fillna_copy_series(self, data_missing):
     def test_fillna_length_mismatch(self, data_missing):
         pass
 
+    def test_where_series(self, data, na_value):
+        assert data[0] != data[1]
+        cls = type(data)
+        a, b = data[:2]
+
+        ser = pd.Series(cls._from_sequence([a, a, b, b], dtype=data.dtype))
+
+        cond = np.array([True, True, False, False])
+        result = ser.where(cond)
+        # new_dtype is the only difference
+        new_dtype = SparseDtype('float64', 0.0)
+        expected = pd.Series(cls._from_sequence([a, a, na_value, na_value],
+                                                dtype=new_dtype))
+        self.assert_series_equal(result, expected)
+
+        other = cls._from_sequence([a, b, a, b])
+        cond = np.array([True, False, True, True])
+        result = ser.where(cond, other)
+        expected = pd.Series(cls._from_sequence([a, b, b, b],
+                                                dtype=data.dtype))
+        self.assert_series_equal(result, expected)
+
 
 class TestCasting(BaseSparseTests, base.BaseCastingTests):
     pass