From 4f0870efecff989c211dd8cfe975ef2127cc86b8 Mon Sep 17 00:00:00 2001 From: rohanjain101 <38412262+rohanjain101@users.noreply.github.com> Date: Fri, 2 Feb 2024 12:57:54 -0500 Subject: [PATCH] Series.str.find fix for pd.ArrowDtype(pa.string()) (#56792) * fix find * gh reference * add test for Nones * fix min version compat * restore test * improve test cases * fix empty string * inline * improve tests * fix * Revert "fix" This reverts commit 7fa21eb24682ae587a0b3033942fbe1247f98921. * fix * merge * inline --------- Co-authored-by: Rohan Jain --- pandas/core/arrays/arrow/array.py | 28 ++++++---- pandas/tests/extension/test_arrow.py | 80 +++++++++++++++++++++++++--- 2 files changed, 90 insertions(+), 18 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 392b4e3cc616a..7bab8c9395ac6 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -2364,20 +2364,26 @@ def _str_fullmatch( return self._str_match(pat, case, flags, na) def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self: - if start != 0 and end is not None: + if (start == 0 or start is None) and end is None: + result = pc.find_substring(self._pa_array, sub) + else: + if sub == "": + # GH 56792 + result = self._apply_elementwise(lambda val: val.find(sub, start, end)) + return type(self)(pa.chunked_array(result)) + if start is None: + start_offset = 0 + start = 0 + elif start < 0: + start_offset = pc.add(start, pc.utf8_length(self._pa_array)) + start_offset = pc.if_else(pc.less(start_offset, 0), 0, start_offset) + else: + start_offset = start slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end) result = pc.find_substring(slices, sub) - not_found = pc.equal(result, -1) - start_offset = max(0, start) + found = pc.not_equal(result, pa.scalar(-1, type=result.type)) offset_result = pc.add(result, start_offset) - result = pc.if_else(not_found, result, offset_result) - elif start == 0 and end is None: - slices = self._pa_array - result = pc.find_substring(slices, sub) - else: - raise NotImplementedError( - f"find not implemented with {sub=}, {start=}, {end=}" - ) + result = pc.if_else(found, offset_result, -1) return type(self)(result) def _str_join(self, sep: str) -> Self: diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 62e4629ca7cb7..3ce2b38bf8644 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -23,6 +23,7 @@ BytesIO, StringIO, ) +from itertools import combinations import operator import pickle import re @@ -1933,13 +1934,18 @@ def test_str_fullmatch(pat, case, na, exp): @pytest.mark.parametrize( - "sub, start, end, exp, exp_typ", - [["ab", 0, None, [0, None], pa.int32()], ["bc", 1, 3, [1, None], pa.int64()]], + "sub, start, end, exp, exp_type", + [ + ["ab", 0, None, [0, None], pa.int32()], + ["bc", 1, 3, [1, None], pa.int64()], + ["ab", 1, 3, [-1, None], pa.int64()], + ["ab", -3, -3, [-1, None], pa.int64()], + ], ) -def test_str_find(sub, start, end, exp, exp_typ): +def test_str_find(sub, start, end, exp, exp_type): ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) result = ser.str.find(sub, start=start, end=end) - expected = pd.Series(exp, dtype=ArrowDtype(exp_typ)) + expected = pd.Series(exp, dtype=ArrowDtype(exp_type)) tm.assert_series_equal(result, expected) @@ -1951,10 +1957,70 @@ def test_str_find_negative_start(): tm.assert_series_equal(result, expected) -def test_str_find_notimplemented(): +def test_str_find_no_end(): ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) - with pytest.raises(NotImplementedError, match="find not implemented"): - ser.str.find("ab", start=1) + if pa_version_under13p0: + # https://github.com/apache/arrow/issues/36311 + with pytest.raises(pa.lib.ArrowInvalid, match="Negative buffer resize"): + ser.str.find("ab", start=1) + else: + result = ser.str.find("ab", start=1) + expected = pd.Series([-1, None], dtype="int64[pyarrow]") + tm.assert_series_equal(result, expected) + + +def test_str_find_negative_start_negative_end(): + # GH 56791 + ser = pd.Series(["abcdefg", None], dtype=ArrowDtype(pa.string())) + result = ser.str.find(sub="d", start=-6, end=-3) + expected = pd.Series([3, None], dtype=ArrowDtype(pa.int64())) + tm.assert_series_equal(result, expected) + + +def test_str_find_large_start(): + # GH 56791 + ser = pd.Series(["abcdefg", None], dtype=ArrowDtype(pa.string())) + if pa_version_under13p0: + # https://github.com/apache/arrow/issues/36311 + with pytest.raises(pa.lib.ArrowInvalid, match="Negative buffer resize"): + ser.str.find(sub="d", start=16) + else: + result = ser.str.find(sub="d", start=16) + expected = pd.Series([-1, None], dtype=ArrowDtype(pa.int64())) + tm.assert_series_equal(result, expected) + + +@pytest.mark.skipif( + pa_version_under13p0, reason="https://github.com/apache/arrow/issues/36311" +) +@pytest.mark.parametrize("start", list(range(-15, 15)) + [None]) +@pytest.mark.parametrize("end", list(range(-15, 15)) + [None]) +@pytest.mark.parametrize( + "sub", + ["abcaadef"[x:y] for x, y in combinations(range(len("abcaadef") + 1), r=2)] + + [ + "", + "az", + "abce", + ], +) +def test_str_find_e2e(start, end, sub): + s = pd.Series( + ["abcaadef", "abc", "abcdeddefgj8292", "ab", "a", ""], + dtype=ArrowDtype(pa.string()), + ) + object_series = s.astype(pd.StringDtype()) + result = s.str.find(sub, start, end) + expected = object_series.str.find(sub, start, end).astype(result.dtype) + tm.assert_series_equal(result, expected) + + +def test_str_find_negative_start_negative_end_no_match(): + # GH 56791 + ser = pd.Series(["abcdefg", None], dtype=ArrowDtype(pa.string())) + result = ser.str.find(sub="d", start=-3, end=-6) + expected = pd.Series([-1, None], dtype=ArrowDtype(pa.int64())) + tm.assert_series_equal(result, expected) @pytest.mark.parametrize(