diff --git a/pandas/core/arrays/sparse/array.py b/pandas/core/arrays/sparse/array.py index 538d4e7e4a7aa8..5b728c4274f07c 100644 --- a/pandas/core/arrays/sparse/array.py +++ b/pandas/core/arrays/sparse/array.py @@ -1680,13 +1680,14 @@ def _cmp_method(self, other, op) -> SparseArray: op_name = op.__name__.strip("_") return _sparse_array_op(self, other, op, op_name) else: + # scalar with np.errstate(all="ignore"): fill_value = op(self.fill_value, other) - result = op(self.sp_values, other) + mask = np.full(len(self), fill_value, dtype=np.bool_) + mask[self.sp_index.indices] = op(self.sp_values, other) return type(self)( - result, - sparse_index=self.sp_index, + mask, fill_value=fill_value, dtype=np.bool_, ) diff --git a/pandas/tests/arrays/sparse/test_arithmetics.py b/pandas/tests/arrays/sparse/test_arithmetics.py index 012fe61fdba05f..787b1f4830629a 100644 --- a/pandas/tests/arrays/sparse/test_arithmetics.py +++ b/pandas/tests/arrays/sparse/test_arithmetics.py @@ -26,6 +26,8 @@ def mix(request): return request.param +# FIXME: There are not SparseArray tests. There are numpy array tests. +# We don't check indices. See GH #45110, #44956, XXX class TestSparseArrayArithmetics: _base = np.array diff --git a/pandas/tests/arrays/sparse/test_array.py b/pandas/tests/arrays/sparse/test_array.py index 2c3dcdeeaf8dc3..f645679399d672 100644 --- a/pandas/tests/arrays/sparse/test_array.py +++ b/pandas/tests/arrays/sparse/test_array.py @@ -266,6 +266,13 @@ def test_getitem_bool_sparse_array(self): exp = SparseArray([np.nan, 3, 5]) tm.assert_sp_array_equal(res, exp) + # GH 45110 + def test_getitem_bool_sparse_array_as_comparison(self): + arr = SparseArray([1, 2, 3, 4, np.nan, np.nan], fill_value=np.nan) + res = arr[arr > 2] + exp = SparseArray([3.0, 4.0], fill_value=np.nan) + tm.assert_sp_array_equal(res, exp) + def test_get_item(self): assert np.isnan(self.arr[1]) diff --git a/pandas/tests/extension/test_sparse.py b/pandas/tests/extension/test_sparse.py index f7809dc2e42175..558b8d7b37a3a2 100644 --- a/pandas/tests/extension/test_sparse.py +++ b/pandas/tests/extension/test_sparse.py @@ -100,6 +100,11 @@ def data_for_grouping(request): return SparseArray([1, 1, np.nan, np.nan, 2, 2, 1, 3], fill_value=request.param) +@pytest.fixture(params=[0, np.nan]) +def data_for_compare(request): + return SparseArray([0, 0, np.nan, -2, -1, 4, 2, 3, 0, 0], fill_value=request.param) + + class BaseSparseTests: def _check_unsupported(self, data): if data.dtype == SparseDtype(int, 0): @@ -432,32 +437,48 @@ def _check_divmod_op(self, ser, op, other, exc=NotImplementedError): super()._check_divmod_op(ser, op, other, exc=None) -class TestComparisonOps(BaseSparseTests, base.BaseComparisonOpsTests): - def _compare_other(self, s, data, comparison_op, other): +class TestComparisonOps(BaseSparseTests): + def _compare_other(self, data_for_compare: SparseArray, comparison_op, other): op = comparison_op - # array - result = pd.Series(op(data, other)) - # hard to test the fill value, since we don't know what expected - # is in general. - # Rely on tests in `tests/sparse` to validate that. - assert isinstance(result.dtype, SparseDtype) - assert result.dtype.subtype == np.dtype("bool") - - with np.errstate(all="ignore"): - expected = pd.Series( - SparseArray( - op(np.asarray(data), np.asarray(other)), - fill_value=result.values.fill_value, - ) + result = op(data_for_compare, other) + assert isinstance(result, SparseArray) + assert result.dtype.subtype == np.bool_ + + if isinstance(other, SparseArray): + expected = SparseArray( + op(data_for_compare.to_dense(), np.asarray(other)), + fill_value=op(data_for_compare.fill_value, other.fill_value), + dtype=np.bool_, + ) + else: + expected = SparseArray( + op(data_for_compare.to_dense(), np.asarray(other)), + fill_value=np.all( + op(np.asarray(data_for_compare.fill_value), np.asarray(other)) + ), + dtype=np.bool_, ) - tm.assert_series_equal(result, expected) + tm.assert_sp_array_equal(result, expected) - # series - ser = pd.Series(data) - result = op(ser, other) - tm.assert_series_equal(result, expected) + def test_scalar(self, data_for_compare: SparseArray, comparison_op): + self._compare_other(data_for_compare, comparison_op, 0) + self._compare_other(data_for_compare, comparison_op, 1) + self._compare_other(data_for_compare, comparison_op, -1) + self._compare_other(data_for_compare, comparison_op, np.nan) + + @pytest.mark.xfail(reason="Wrong indices") + def test_array(self, data_for_compare: SparseArray, comparison_op): + arr = np.linspace(-4, 5, 10) + self._compare_other(data_for_compare, comparison_op, arr) + + @pytest.mark.xfail(reason="Wrong indices") + def test_sparse_array(self, data_for_compare: SparseArray, comparison_op): + arr = data_for_compare + 1 + self._compare_other(data_for_compare, comparison_op, arr) + arr = data_for_compare * 2 + self._compare_other(data_for_compare, comparison_op, arr) class TestPrinting(BaseSparseTests, base.BasePrintingTests):