diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst index b41931a803053..c7a1e006e8b73 100644 --- a/doc/source/whatsnew/v1.3.0.rst +++ b/doc/source/whatsnew/v1.3.0.rst @@ -293,7 +293,7 @@ Sparse ExtensionArray ^^^^^^^^^^^^^^ - +- Bug in :meth:`DataFrame.where` when ``other`` is a :class:`Series` with ExtensionArray dtype (:issue:`38729`) - - diff --git a/pandas/core/generic.py b/pandas/core/generic.py index bdb28c10a0ad2..9f84447b7476d 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -89,9 +89,10 @@ import pandas as pd from pandas.core import arraylike, indexing, missing, nanops import pandas.core.algorithms as algos +from pandas.core.arrays import ExtensionArray from pandas.core.base import PandasObject, SelectionMixin import pandas.core.common as com -from pandas.core.construction import create_series_with_explicit_dtype +from pandas.core.construction import create_series_with_explicit_dtype, extract_array from pandas.core.flags import Flags from pandas.core.indexes import base as ibase from pandas.core.indexes.api import ( @@ -8786,6 +8787,9 @@ def _where( """ inplace = validate_bool_kwarg(inplace, "inplace") + if axis is not None: + axis = self._get_axis_number(axis) + # align the cond to same shape as myself cond = com.apply_if_callable(cond, self) if isinstance(cond, NDFrame): @@ -8825,22 +8829,39 @@ def _where( if other.ndim <= self.ndim: _, other = self.align( - other, join="left", axis=axis, level=level, fill_value=np.nan + other, + join="left", + axis=axis, + level=level, + fill_value=np.nan, + copy=False, ) # if we are NOT aligned, raise as we cannot where index - if axis is None and not all( - other._get_axis(i).equals(ax) for i, ax in enumerate(self.axes) - ): + if axis is None and not other._indexed_same(self): raise InvalidIndexError + elif other.ndim < self.ndim: + # TODO(EA2D): avoid object-dtype cast in EA case GH#38729 + other = other._values + if axis == 0: + other = np.reshape(other, (-1, 1)) + elif axis == 1: + other = np.reshape(other, (1, -1)) + + other = np.broadcast_to(other, self.shape) + # slice me out of the other else: raise NotImplementedError( "cannot align with a higher dimensional NDFrame" ) - if isinstance(other, np.ndarray): + if not isinstance(other, (MultiIndex, NDFrame)): + # mainly just catching Index here + other = extract_array(other, extract_numpy=True) + + if isinstance(other, (np.ndarray, ExtensionArray)): if other.shape != self.shape: @@ -8885,10 +8906,10 @@ def _where( else: align = self._get_axis_number(axis) == 1 - if align and isinstance(other, NDFrame): - other = other.reindex(self._info_axis, axis=self._info_axis_number) if isinstance(cond, NDFrame): - cond = cond.reindex(self._info_axis, axis=self._info_axis_number) + cond = cond.reindex( + self._info_axis, axis=self._info_axis_number, copy=False + ) block_axis = self._get_block_manager_axis(axis) diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py index ea1b8259eeadd..d42039e710666 100644 --- a/pandas/core/internals/blocks.py +++ b/pandas/core/internals/blocks.py @@ -1064,9 +1064,7 @@ def putmask(self, mask, new, axis: int = 0) -> List["Block"]: # If the default repeat behavior in np.putmask would go in the # wrong direction, then explicitly repeat and reshape new instead if getattr(new, "ndim", 0) >= 1: - if self.ndim - 1 == new.ndim and axis == 1: - new = np.repeat(new, new_values.shape[-1]).reshape(self.shape) - new = new.astype(new_values.dtype) + new = new.astype(new_values.dtype, copy=False) # we require exact matches between the len of the # values we are setting (or is compat). np.putmask @@ -1104,13 +1102,6 @@ def putmask(self, mask, new, axis: int = 0) -> List["Block"]: new = new.T axis = new_values.ndim - axis - 1 - # Pseudo-broadcast - if getattr(new, "ndim", 0) >= 1: - if self.ndim - 1 == new.ndim: - new_shape = list(new.shape) - new_shape.insert(axis, 1) - new = new.reshape(tuple(new_shape)) - # operate column-by-column def f(mask, val, idx): diff --git a/pandas/tests/frame/indexing/test_where.py b/pandas/tests/frame/indexing/test_where.py index acdb5726e4adb..87d2fd37ab023 100644 --- a/pandas/tests/frame/indexing/test_where.py +++ b/pandas/tests/frame/indexing/test_where.py @@ -653,3 +653,22 @@ def test_where_categorical_filtering(self): expected.loc[0, :] = np.nan tm.assert_equal(result, expected) + + def test_where_ea_other(self): + # GH#38729/GH#38742 + df = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + arr = pd.array([7, pd.NA, 9]) + ser = Series(arr) + mask = np.ones(df.shape, dtype=bool) + mask[1, :] = False + + # TODO: ideally we would get Int64 instead of object + result = df.where(mask, ser, axis=0) + expected = DataFrame({"A": [1, pd.NA, 3], "B": [4, pd.NA, 6]}).astype(object) + tm.assert_frame_equal(result, expected) + + ser2 = Series(arr[:2], index=["A", "B"]) + expected = DataFrame({"A": [1, 7, 3], "B": [4, pd.NA, 6]}) + expected["B"] = expected["B"].astype(object) + result = df.where(mask, ser2, axis=1) + tm.assert_frame_equal(result, expected)