Skip to content

Commit

Permalink
ENH: ExtensionArray.searchsorted (#24350)
Browse files Browse the repository at this point in the history
  • Loading branch information
TomAugspurger authored and jreback committed Dec 28, 2018
1 parent c1af4f5 commit 7617ed1
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 1 deletion.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.24.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -997,6 +997,7 @@ update the ``ExtensionDtype._metadata`` tuple to match the signature of your
- :meth:`~pandas.api.extensions.ExtensionArray.repeat` has been added (:issue:`24349`)
- ``ExtensionDtype`` has gained the ability to instantiate from string dtypes, e.g. ``decimal`` would instantiate a registered ``DecimalDtype``; furthermore
the ``ExtensionDtype`` has gained the method ``construct_array_type`` (:issue:`21185`)
- :meth:`~pandas.api.extensions.ExtensionArray.searchsorted` has been added (:issue:`24350`)
- An ``ExtensionArray`` with a boolean dtype now works correctly as a boolean indexer. :meth:`pandas.api.types.is_bool_dtype` now properly considers them boolean (:issue:`22326`)
- Added ``ExtensionDtype._is_numeric`` for controlling whether an extension dtype is considered numeric (:issue:`22290`).
- The ``ExtensionArray`` constructor, ``_from_sequence``, now takes the keyword arg ``copy=False`` (:issue:`21185`)
Expand Down
49 changes: 49 additions & 0 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class ExtensionArray(object):
* unique
* factorize / _values_for_factorize
* argsort / _values_for_argsort
* searchsorted
The remaining methods implemented on this class should be performant,
as they only compose abstract methods. Still, a more efficient
Expand Down Expand Up @@ -518,6 +519,54 @@ def unique(self):
uniques = unique(self.astype(object))
return self._from_sequence(uniques, dtype=self.dtype)

def searchsorted(self, value, side="left", sorter=None):
    """
    Find indices where elements should be inserted to maintain order.

    .. versionadded:: 0.24.0

    Find the indices into a sorted array `self` such that, if the
    corresponding elements in `value` were inserted before the indices,
    the order of `self` would be preserved.

    Assuming that `self` is sorted:

    ======  ================================
    `side`  returned index `i` satisfies
    ======  ================================
    left    ``self[i-1] < value <= self[i]``
    right   ``self[i-1] <= value < self[i]``
    ======  ================================

    Parameters
    ----------
    value : array_like
        Values to insert into `self`.
    side : {'left', 'right'}, optional
        If 'left', the index of the first suitable location found is given.
        If 'right', return the last such index.  If there is no suitable
        index, return either 0 or N (where N is the length of `self`).
    sorter : 1-D array_like, optional
        Optional array of integer indices that sort `self` into ascending
        order. They are typically the result of argsort.

    Returns
    -------
    indices : array of ints
        Array of insertion points with the same shape as `value`.

    See Also
    --------
    numpy.searchsorted : Similar method from NumPy.
    """
    # Note: the base tests provided by pandas only test the basics.
    # We do not test
    # 1. Values outside the range of the `data_for_sorting` fixture
    # 2. Values between the values in the `data_for_sorting` fixture
    # 3. Missing values.

    # Generic fallback: convert to object dtype and delegate to the
    # ``searchsorted`` of whatever ``astype(object)`` returns.
    arr = self.astype(object)
    return arr.searchsorted(value, side=side, sorter=sorter)

def _values_for_factorize(self):
# type: () -> Tuple[ndarray, Any]
"""
Expand Down
10 changes: 10 additions & 0 deletions pandas/core/arrays/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -1169,6 +1169,16 @@ def _take_without_fill(self, indices):

return taken

def searchsorted(self, v, side="left", sorter=None):
    """
    Find insertion indices for `v` by searching a dense copy of the array.

    The array is converted with ``np.asarray`` to ``self.dtype.subtype``
    before searching, and a ``PerformanceWarning`` is emitted on every call.

    Parameters
    ----------
    v : scalar or array_like
        Value(s) to locate.
    side : {'left', 'right'}, default 'left'
        Forwarded to the ndarray ``searchsorted``.
    sorter : 1-D array_like, optional
        Forwarded to the ndarray ``searchsorted``.

    Returns
    -------
    Whatever the ndarray ``searchsorted`` returns for `v`.
    """
    msg = "searchsorted requires high memory usage."
    warnings.warn(msg, PerformanceWarning, stacklevel=2)
    # ``np.asarray`` handles scalars and array-likes uniformly, so no
    # separate non-scalar branch is needed before this conversion.
    v = np.asarray(v)
    dense = np.asarray(self, dtype=self.dtype.subtype)
    return dense.searchsorted(v, side, sorter)

def copy(self, deep=False):
if deep:
values = self.sp_values.copy()
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1464,7 +1464,7 @@ def factorize(self, sort=False, na_sentinel=-1):
@Appender(_shared_docs['searchsorted'])
def searchsorted(self, value, side='left', sorter=None):
    # needs coercion on the key (DatetimeIndex does already)
    # Delegate to ``_values`` rather than ``values``.
    # NOTE(review): presumably this lets ExtensionArray-backed objects reach
    # their own ``searchsorted`` -- confirm against the callers.
    return self._values.searchsorted(value, side=side, sorter=sorter)

def drop_duplicates(self, keep='first', inplace=False):
inplace = validate_bool_kwarg(inplace, 'inplace')
Expand Down
25 changes: 25 additions & 0 deletions pandas/tests/extension/base/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,31 @@ def test_hash_pandas_object_works(self, data, as_frame):
b = pd.util.hash_pandas_object(data)
self.assert_equal(a, b)

@pytest.mark.parametrize("as_series", [True, False])
def test_searchsorted(self, data_for_sorting, as_series):
    """
    Check scalar, array-like and ``sorter`` lookups of ``searchsorted``.

    ``data_for_sorting`` is unpacked as ``[b, c, a]``; the assertions below
    require ``a < b < c``. When ``as_series`` is true the sorted array is
    wrapped in a ``pd.Series`` first.
    """
    b, c, a = data_for_sorting
    arr = type(data_for_sorting)._from_sequence([a, b, c])

    if as_series:
        arr = pd.Series(arr)
    assert arr.searchsorted(a) == 0
    assert arr.searchsorted(a, side="right") == 1

    assert arr.searchsorted(b) == 1
    assert arr.searchsorted(b, side="right") == 2

    assert arr.searchsorted(c) == 2
    assert arr.searchsorted(c, side="right") == 3

    result = arr.searchsorted(arr.take([0, 2]))
    expected = np.array([0, 2], dtype=np.intp)

    tm.assert_numpy_array_equal(result, expected)

    # sorter
    # ``data_for_sorting`` is laid out as [b, c, a], so the permutation that
    # sorts it is [2, 0, 1]; a non-sorting permutation would make the result
    # of the binary search meaningless.
    sorter = np.array([2, 0, 1])
    assert data_for_sorting.searchsorted(a, sorter=sorter) == 0

@pytest.mark.parametrize("as_frame", [True, False])
def test_where_series(self, data, na_value, as_frame):
assert data[0] != data[1]
Expand Down
4 changes: 4 additions & 0 deletions pandas/tests/extension/json/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,10 @@ def test_where_series(self, data, na_value):
# with shapes (4,) (4,) (0,)
super().test_where_series(data, na_value)

@pytest.mark.skip(reason="Can't compare dicts.")
def test_searchsorted(self, data_for_sorting):
    """Skipped unconditionally; the body only forwards to the parent test."""
    parent_test = super().test_searchsorted
    parent_test(data_for_sorting)


class TestCasting(BaseJSON, base.BaseCastingTests):
@pytest.mark.skip(reason="failing on np.array(self, dtype=str)")
Expand Down
4 changes: 4 additions & 0 deletions pandas/tests/extension/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,10 @@ def test_combine_add(self, data_repeated):
def test_fillna_length_mismatch(self, data_missing):
super().test_fillna_length_mismatch(data_missing)

def test_searchsorted(self, data_for_sorting):
    """Skip the searchsorted check when the fixture data is not ordered."""
    # NOTE(review): when the data *is* ordered this override runs no
    # assertions (the parent test is not invoked) -- confirm that is intended.
    if not data_for_sorting.ordered:
        # ``pytest.skip`` raises by itself, so no ``raise`` is needed here.
        pytest.skip(reason="searchsorted requires ordered data.")


class TestCasting(base.BaseCastingTests):
pass
Expand Down
6 changes: 6 additions & 0 deletions pandas/tests/extension/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,12 @@ def test_combine_first(self, data):
pytest.skip("TODO(SparseArray.__setitem__ will preserve dtype.")
super(TestMethods, self).test_combine_first(data)

@pytest.mark.parametrize("as_series", [True, False])
def test_searchsorted(self, data_for_sorting, as_series):
    """Run the parent searchsorted test, expecting a PerformanceWarning."""
    parent_test = super(TestMethods, self).test_searchsorted
    with tm.assert_produces_warning(PerformanceWarning):
        parent_test(data_for_sorting, as_series=as_series)


class TestCasting(BaseSparseTests, base.BaseCastingTests):
pass
Expand Down

0 comments on commit 7617ed1

Please sign in to comment.