Skip to content

Commit

Permalink
BUG: SparseArray doesn't recalc indices. (pandas-dev#44956, pandas-de…
Browse files Browse the repository at this point in the history
  • Loading branch information
bdrum committed Dec 30, 2021
1 parent e892d46 commit 6f50d87
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 24 deletions.
7 changes: 4 additions & 3 deletions pandas/core/arrays/sparse/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1680,13 +1680,14 @@ def _cmp_method(self, other, op) -> SparseArray:
op_name = op.__name__.strip("_")
return _sparse_array_op(self, other, op, op_name)
else:
# scalar
with np.errstate(all="ignore"):
fill_value = op(self.fill_value, other)
result = op(self.sp_values, other)
mask = np.full(len(self), fill_value, dtype=np.bool_)
mask[self.sp_index.indices] = op(self.sp_values, other)

return type(self)(
result,
sparse_index=self.sp_index,
mask,
fill_value=fill_value,
dtype=np.bool_,
)
Expand Down
2 changes: 2 additions & 0 deletions pandas/tests/arrays/sparse/test_arithmetics.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ def mix(request):
return request.param


# FIXME: There are not SparseArray tests. There are numpy array tests.
# We don't check indices. See GH #45110, #44956, XXX
class TestSparseArrayArithmetics:

_base = np.array
Expand Down
6 changes: 6 additions & 0 deletions pandas/tests/arrays/sparse/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,12 @@ def test_getitem_bool_sparse_array(self):
exp = SparseArray([np.nan, 3, 5])
tm.assert_sp_array_equal(res, exp)

# GH 45110
arr = SparseArray([1, 2, 3, 4, np.nan, np.nan], fill_value=np.nan)
res = arr[arr > 2]
exp = SparseArray([3.0, 4.0], fill_value=np.nan)
tm.assert_sp_array_equal(res, exp)

def test_get_item(self):

assert np.isnan(self.arr[1])
Expand Down
63 changes: 42 additions & 21 deletions pandas/tests/extension/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ def data_for_grouping(request):
return SparseArray([1, 1, np.nan, np.nan, 2, 2, 1, 3], fill_value=request.param)


@pytest.fixture(params=[0, np.nan])
def data_for_compare(request):
return SparseArray([0, 0, np.nan, -2, -1, 4, 2, 3, 0, 0], fill_value=request.param)


class BaseSparseTests:
def _check_unsupported(self, data):
if data.dtype == SparseDtype(int, 0):
Expand Down Expand Up @@ -432,32 +437,48 @@ def _check_divmod_op(self, ser, op, other, exc=NotImplementedError):
super()._check_divmod_op(ser, op, other, exc=None)


class TestComparisonOps(BaseSparseTests, base.BaseComparisonOpsTests):
def _compare_other(self, s, data, comparison_op, other):
class TestComparisonOps(BaseSparseTests):
def _compare_other(self, data_for_compare: SparseArray, comparison_op, other):
op = comparison_op

# array
result = pd.Series(op(data, other))
# hard to test the fill value, since we don't know what expected
# is in general.
# Rely on tests in `tests/sparse` to validate that.
assert isinstance(result.dtype, SparseDtype)
assert result.dtype.subtype == np.dtype("bool")

with np.errstate(all="ignore"):
expected = pd.Series(
SparseArray(
op(np.asarray(data), np.asarray(other)),
fill_value=result.values.fill_value,
)
result = op(data_for_compare, other)
assert isinstance(result, SparseArray)
assert result.dtype.subtype == np.bool_

if isinstance(other, SparseArray):
expected = SparseArray(
op(data_for_compare.to_dense(), np.asarray(other)),
fill_value=op(data_for_compare.fill_value, other.fill_value),
dtype=np.bool_,
)
else:
expected = SparseArray(
op(data_for_compare.to_dense(), np.asarray(other)),
fill_value=np.all(
op(np.asarray(data_for_compare.fill_value), np.asarray(other))
),
dtype=np.bool_,
)

tm.assert_series_equal(result, expected)
tm.assert_sp_array_equal(result, expected)

# series
ser = pd.Series(data)
result = op(ser, other)
tm.assert_series_equal(result, expected)
def test_scalar(self, data_for_compare: SparseArray, comparison_op):
self._compare_other(data_for_compare, comparison_op, 0)
self._compare_other(data_for_compare, comparison_op, 1)
self._compare_other(data_for_compare, comparison_op, -1)
self._compare_other(data_for_compare, comparison_op, np.nan)

@pytest.mark.xfail(reason="Wrong indices")
def test_array(self, data_for_compare: SparseArray, comparison_op):
arr = np.linspace(-4, 5, 10)
self._compare_other(data_for_compare, comparison_op, arr)

@pytest.mark.xfail(reason="Wrong indices")
def test_sparse_array(self, data_for_compare: SparseArray, comparison_op):
arr = data_for_compare + 1
self._compare_other(data_for_compare, comparison_op, arr)
arr = data_for_compare * 2
self._compare_other(data_for_compare, comparison_op, arr)


class TestPrinting(BaseSparseTests, base.BasePrintingTests):
Expand Down

0 comments on commit 6f50d87

Please sign in to comment.