Skip to content

Commit

Permalink
Series.str.find fix for pd.ArrowDtype(pa.string()) (pandas-dev#56792)
Browse files Browse the repository at this point in the history
* fix find

* gh reference

* add test for Nones

* fix min version compat

* restore test

* improve test cases

* fix empty string

* inline

* improve tests

* fix

* Revert "fix"

This reverts commit 7fa21eb.

* fix

* merge

* inline

---------

Co-authored-by: Rohan Jain <rohanjain@microsoft.com>
  • Loading branch information
rohanjain101 and Rohan Jain authored Feb 2, 2024
1 parent 4663edd commit 4f0870e
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 18 deletions.
28 changes: 17 additions & 11 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2364,20 +2364,26 @@ def _str_fullmatch(
return self._str_match(pat, case, flags, na)

def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self:
    """
    Return the lowest index of ``sub`` within ``[start, end)`` of each element.

    Parameters
    ----------
    sub : str
        Substring to search for.
    start : int, default 0
        Left edge of the search window. ``None`` is treated as 0; a negative
        value is resolved per element against that element's length.
    end : int or None, default None
        Right edge of the search window, forwarded as ``stop`` to the slice.

    Returns
    -------
    Self
        Array of positions; elements where ``sub`` is not found hold -1.
    """
    if (start == 0 or start is None) and end is None:
        # No window requested: search the stored array directly.
        result = pc.find_substring(self._pa_array, sub)
    else:
        if sub == "":
            # GH 56792
            # An empty needle is delegated to Python's ``str.find`` element by
            # element instead of the slice-and-search path below.
            result = self._apply_elementwise(lambda val: val.find(sub, start, end))
            return type(self)(pa.chunked_array(result))
        if start is None:
            start_offset = 0
            start = 0
        elif start < 0:
            # Resolve a negative start per element, clamping at position 0.
            start_offset = pc.add(start, pc.utf8_length(self._pa_array))
            start_offset = pc.if_else(pc.less(start_offset, 0), 0, start_offset)
        else:
            start_offset = start
        slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
        result = pc.find_substring(slices, sub)
        # Positions are relative to the slice; add the slice's starting offset
        # back for hits only, so that misses stay at -1.
        found = pc.not_equal(result, pa.scalar(-1, type=result.type))
        offset_result = pc.add(result, start_offset)
        result = pc.if_else(found, offset_result, -1)
    return type(self)(result)

def _str_join(self, sep: str) -> Self:
Expand Down
80 changes: 73 additions & 7 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
BytesIO,
StringIO,
)
from itertools import combinations
import operator
import pickle
import re
Expand Down Expand Up @@ -1933,13 +1934,18 @@ def test_str_fullmatch(pat, case, na, exp):


@pytest.mark.parametrize(
    "sub, start, end, exp, exp_type",
    [
        ["ab", 0, None, [0, None], pa.int32()],
        ["bc", 1, 3, [1, None], pa.int64()],
        ["ab", 1, 3, [-1, None], pa.int64()],
        ["ab", -3, -3, [-1, None], pa.int64()],
    ],
)
def test_str_find(sub, start, end, exp, exp_type):
    """Check ``Series.str.find`` values and result dtype on an Arrow string Series."""
    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
    result = ser.str.find(sub, start=start, end=end)
    expected = pd.Series(exp, dtype=ArrowDtype(exp_type))
    tm.assert_series_equal(result, expected)


Expand All @@ -1951,10 +1957,70 @@ def test_str_find_negative_start():
tm.assert_series_equal(result, expected)


def test_str_find_no_end():
    """``find`` with a non-zero ``start`` and no ``end`` returns a result."""
    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
    if pa_version_under13p0:
        # https://github.com/apache/arrow/issues/36311
        with pytest.raises(pa.lib.ArrowInvalid, match="Negative buffer resize"):
            ser.str.find("ab", start=1)
    else:
        result = ser.str.find("ab", start=1)
        expected = pd.Series([-1, None], dtype="int64[pyarrow]")
        tm.assert_series_equal(result, expected)


def test_str_find_negative_start_negative_end():
    """With both bounds negative, a hit is reported as an index into the full string."""
    # GH 56791
    values = pd.Series(["abcdefg", None], dtype=ArrowDtype(pa.string()))
    found_at = values.str.find(sub="d", start=-6, end=-3)
    tm.assert_series_equal(
        found_at,
        pd.Series([3, None], dtype=ArrowDtype(pa.int64())),
    )


def test_str_find_large_start():
    """A ``start`` beyond the string length yields -1 (or raises on old pyarrow)."""
    # GH 56791
    values = pd.Series(["abcdefg", None], dtype=ArrowDtype(pa.string()))
    if not pa_version_under13p0:
        found_at = values.str.find(sub="d", start=16)
        tm.assert_series_equal(
            found_at,
            pd.Series([-1, None], dtype=ArrowDtype(pa.int64())),
        )
        return
    # https://github.com/apache/arrow/issues/36311
    with pytest.raises(pa.lib.ArrowInvalid, match="Negative buffer resize"):
        values.str.find(sub="d", start=16)


@pytest.mark.skipif(
    pa_version_under13p0, reason="https://github.com/apache/arrow/issues/36311"
)
@pytest.mark.parametrize("start", [*range(-15, 15), None])
@pytest.mark.parametrize("end", [*range(-15, 15), None])
@pytest.mark.parametrize(
    "sub",
    [
        # Every non-empty contiguous substring of "abcaadef" ...
        *(
            "abcaadef"[lo:hi]
            for lo, hi in combinations(range(len("abcaadef") + 1), r=2)
        ),
        # ... plus the empty string and two needles that never match fully.
        "",
        "az",
        "abce",
    ],
)
def test_str_find_e2e(start, end, sub):
    """Compare Arrow-backed ``str.find`` against the ``StringDtype`` result."""
    arrow_strings = pd.Series(
        ["abcaadef", "abc", "abcdeddefgj8292", "ab", "a", ""],
        dtype=ArrowDtype(pa.string()),
    )
    reference = arrow_strings.astype(pd.StringDtype())
    result = arrow_strings.str.find(sub, start, end)
    # Cast the reference answer so only the values are compared, not the dtype.
    expected = reference.str.find(sub, start, end).astype(result.dtype)
    tm.assert_series_equal(result, expected)


def test_str_find_negative_start_negative_end_no_match():
    """A negative window whose start lies after its end finds nothing (-1)."""
    # GH 56791
    values = pd.Series(["abcdefg", None], dtype=ArrowDtype(pa.string()))
    found_at = values.str.find(sub="d", start=-3, end=-6)
    tm.assert_series_equal(
        found_at,
        pd.Series([-1, None], dtype=ArrowDtype(pa.int64())),
    )


@pytest.mark.parametrize(
Expand Down

0 comments on commit 4f0870e

Please sign in to comment.