Skip to content

Commit

Permalink
REF (string): de-duplicate ArrowStringArray methods (#59555)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel committed Sep 11, 2024
1 parent 16b7288 commit 4444e52
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 172 deletions.
83 changes: 83 additions & 0 deletions pandas/core/arrays/_arrow_string_mixins.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from functools import partial
import re
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -48,6 +49,37 @@ def _convert_int_result(self, result):
def _apply_elementwise(self, func: Callable) -> list[list[Any]]:
raise NotImplementedError

def _str_len(self):
    """
    Compute per-element string lengths with ``pc.utf8_length``.

    The raw result is handed to ``self._convert_int_result`` and whatever
    that returns is returned unchanged.
    """
    return self._convert_int_result(pc.utf8_length(self._pa_array))

def _str_lower(self) -> Self:
    """Return a new array of the same type with ``pc.utf8_lower`` applied."""
    lowered = pc.utf8_lower(self._pa_array)
    return type(self)(lowered)

def _str_upper(self) -> Self:
    """Return a new array of the same type with ``pc.utf8_upper`` applied."""
    uppered = pc.utf8_upper(self._pa_array)
    return type(self)(uppered)

def _str_strip(self, to_strip=None) -> Self:
    """
    Trim both ends of every element.

    Parameters
    ----------
    to_strip : optional
        Forwarded as ``characters=`` to ``pc.utf8_trim``. When ``None``,
        ``pc.utf8_trim_whitespace`` is used instead.

    Returns
    -------
    A new array of the same type as ``self`` wrapping the trimmed result.
    """
    if to_strip is not None:
        trimmed = pc.utf8_trim(self._pa_array, characters=to_strip)
    else:
        trimmed = pc.utf8_trim_whitespace(self._pa_array)
    return type(self)(trimmed)

def _str_lstrip(self, to_strip=None) -> Self:
    """
    Trim the left end of every element.

    Parameters
    ----------
    to_strip : optional
        Forwarded as ``characters=`` to ``pc.utf8_ltrim``. When ``None``,
        ``pc.utf8_ltrim_whitespace`` is used instead.

    Returns
    -------
    A new array of the same type as ``self`` wrapping the trimmed result.
    """
    if to_strip is not None:
        trimmed = pc.utf8_ltrim(self._pa_array, characters=to_strip)
    else:
        trimmed = pc.utf8_ltrim_whitespace(self._pa_array)
    return type(self)(trimmed)

def _str_rstrip(self, to_strip=None) -> Self:
    """
    Trim the right end of every element.

    Parameters
    ----------
    to_strip : optional
        Forwarded as ``characters=`` to ``pc.utf8_rtrim``. When ``None``,
        ``pc.utf8_rtrim_whitespace`` is used instead.

    Returns
    -------
    A new array of the same type as ``self`` wrapping the trimmed result.
    """
    if to_strip is not None:
        trimmed = pc.utf8_rtrim(self._pa_array, characters=to_strip)
    else:
        trimmed = pc.utf8_rtrim_whitespace(self._pa_array)
    return type(self)(trimmed)

def _str_pad(
self,
width: int,
Expand Down Expand Up @@ -128,6 +160,33 @@ def _str_slice_replace(
stop = np.iinfo(np.int64).max
return type(self)(pc.utf8_replace_slice(self._pa_array, start, stop, repl))

def _str_replace(
self,
pat: str | re.Pattern,
repl: str | Callable,
n: int = -1,
case: bool = True,
flags: int = 0,
regex: bool = True,
) -> Self:
if isinstance(pat, re.Pattern) or callable(repl) or not case or flags:
raise NotImplementedError(
"replace is not supported with a re.Pattern, callable repl, "
"case=False, or flags!=0"
)

func = pc.replace_substring_regex if regex else pc.replace_substring
# https://github.com/apache/arrow/issues/39149
# GH 56404, unexpected behavior with negative max_replacements with pyarrow.
pa_max_replacements = None if n < 0 else n
result = func(
self._pa_array,
pattern=pat,
replacement=repl,
max_replacements=pa_max_replacements,
)
return type(self)(result)

def _str_capitalize(self) -> Self:
    """Return a new array of the same type with ``pc.utf8_capitalize`` applied."""
    capitalized = pc.utf8_capitalize(self._pa_array)
    return type(self)(capitalized)

Expand All @@ -137,6 +196,16 @@ def _str_title(self) -> Self:
def _str_swapcase(self) -> Self:
    """Return a new array of the same type with ``pc.utf8_swapcase`` applied."""
    swapped = pc.utf8_swapcase(self._pa_array)
    return type(self)(swapped)

def _str_removeprefix(self, prefix: str):
    """
    Drop ``prefix`` from the start of every element that begins with it.

    Elements that do not start with ``prefix`` are kept as they are.

    Parameters
    ----------
    prefix : str
        The leading text to remove.

    Returns
    -------
    A new array of the same type as ``self``.
    """
    if pa_version_under13p0:
        # When the ``pa_version_under13p0`` flag is set, fall back to calling
        # ``str.removeprefix`` on each value through ``_apply_elementwise``.
        stripped = self._apply_elementwise(lambda val: val.removeprefix(prefix))
        return type(self)(pa.chunked_array(stripped))

    # Slice off ``len(prefix)`` leading units everywhere, then keep the
    # sliced value only where the element really starts with ``prefix``.
    has_prefix = pc.starts_with(self._pa_array, pattern=prefix)
    without_prefix = pc.utf8_slice_codeunits(self._pa_array, len(prefix))
    return type(self)(pc.if_else(has_prefix, without_prefix, self._pa_array))

def _str_removesuffix(self, suffix: str):
ends_with = pc.ends_with(self._pa_array, pattern=suffix)
removed = pc.utf8_slice_codeunits(self._pa_array, 0, stop=-len(suffix))
Expand Down Expand Up @@ -228,6 +297,20 @@ def _str_contains(
result = result.fill_null(na)
return self._convert_bool_result(result)

def _str_match(
self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None
):
if not pat.startswith("^"):
pat = f"^{pat}"
return self._str_contains(pat, case, flags, na, regex=True)

def _str_fullmatch(
self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None
):
if not pat.endswith("$") or pat.endswith("\\$"):
pat = f"{pat}$"
return self._str_match(pat, case, flags, na)

def _str_find(self, sub: str, start: int = 0, end: int | None = None):
if (
pa_version_under13p0
Expand Down
86 changes: 1 addition & 85 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1999,7 +1999,7 @@ def _rank(
"""
See Series.rank.__doc__.
"""
return type(self)(
return self._convert_int_result(
self._rank_calc(
axis=axis,
method=method,
Expand Down Expand Up @@ -2323,57 +2323,13 @@ def _str_count(self, pat: str, flags: int = 0) -> Self:
raise NotImplementedError(f"count not implemented with {flags=}")
return type(self)(pc.count_substring_regex(self._pa_array, pat))

def _result_converter(self, result):
return type(self)(result)

def _str_replace(
self,
pat: str | re.Pattern,
repl: str | Callable,
n: int = -1,
case: bool = True,
flags: int = 0,
regex: bool = True,
) -> Self:
if isinstance(pat, re.Pattern) or callable(repl) or not case or flags:
raise NotImplementedError(
"replace is not supported with a re.Pattern, callable repl, "
"case=False, or flags!=0"
)

func = pc.replace_substring_regex if regex else pc.replace_substring
# https://github.com/apache/arrow/issues/39149
# GH 56404, unexpected behavior with negative max_replacements with pyarrow.
pa_max_replacements = None if n < 0 else n
result = func(
self._pa_array,
pattern=pat,
replacement=repl,
max_replacements=pa_max_replacements,
)
return type(self)(result)

def _str_repeat(self, repeats: int | Sequence[int]) -> Self:
if not isinstance(repeats, int):
raise NotImplementedError(
f"repeat is not implemented when repeats is {type(repeats).__name__}"
)
return type(self)(pc.binary_repeat(self._pa_array, repeats))

def _str_match(
self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None
) -> Self:
if not pat.startswith("^"):
pat = f"^{pat}"
return self._str_contains(pat, case, flags, na, regex=True)

def _str_fullmatch(
self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None
) -> Self:
if not pat.endswith("$") or pat.endswith("\\$"):
pat = f"{pat}$"
return self._str_match(pat, case, flags, na)

def _str_join(self, sep: str) -> Self:
if pa.types.is_string(self._pa_array.type) or pa.types.is_large_string(
self._pa_array.type
Expand All @@ -2394,46 +2350,6 @@ def _str_rpartition(self, sep: str, expand: bool) -> Self:
result = self._apply_elementwise(predicate)
return type(self)(pa.chunked_array(result))

def _str_len(self) -> Self:
return type(self)(pc.utf8_length(self._pa_array))

def _str_lower(self) -> Self:
return type(self)(pc.utf8_lower(self._pa_array))

def _str_upper(self) -> Self:
return type(self)(pc.utf8_upper(self._pa_array))

def _str_strip(self, to_strip=None) -> Self:
if to_strip is None:
result = pc.utf8_trim_whitespace(self._pa_array)
else:
result = pc.utf8_trim(self._pa_array, characters=to_strip)
return type(self)(result)

def _str_lstrip(self, to_strip=None) -> Self:
if to_strip is None:
result = pc.utf8_ltrim_whitespace(self._pa_array)
else:
result = pc.utf8_ltrim(self._pa_array, characters=to_strip)
return type(self)(result)

def _str_rstrip(self, to_strip=None) -> Self:
if to_strip is None:
result = pc.utf8_rtrim_whitespace(self._pa_array)
else:
result = pc.utf8_rtrim(self._pa_array, characters=to_strip)
return type(self)(result)

def _str_removeprefix(self, prefix: str):
if not pa_version_under13p0:
starts_with = pc.starts_with(self._pa_array, pattern=prefix)
removed = pc.utf8_slice_codeunits(self._pa_array, len(prefix))
result = pc.if_else(starts_with, removed, self._pa_array)
return type(self)(result)
predicate = lambda val: val.removeprefix(prefix)
result = self._apply_elementwise(predicate)
return type(self)(pa.chunked_array(result))

def _str_casefold(self) -> Self:
predicate = lambda val: val.casefold()
result = self._apply_elementwise(predicate)
Expand Down
106 changes: 19 additions & 87 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,8 @@

from pandas._typing import (
ArrayLike,
AxisInt,
Dtype,
NpDtype,
Scalar,
Self,
npt,
)
Expand Down Expand Up @@ -290,6 +288,20 @@ def astype(self, dtype, copy: bool = True):
_str_startswith = ArrowStringArrayMixin._str_startswith
_str_endswith = ArrowStringArrayMixin._str_endswith
_str_pad = ArrowStringArrayMixin._str_pad
_str_match = ArrowStringArrayMixin._str_match
_str_fullmatch = ArrowStringArrayMixin._str_fullmatch
_str_lower = ArrowStringArrayMixin._str_lower
_str_upper = ArrowStringArrayMixin._str_upper
_str_strip = ArrowStringArrayMixin._str_strip
_str_lstrip = ArrowStringArrayMixin._str_lstrip
_str_rstrip = ArrowStringArrayMixin._str_rstrip
_str_removesuffix = ArrowStringArrayMixin._str_removesuffix
_str_get = ArrowStringArrayMixin._str_get
_str_capitalize = ArrowStringArrayMixin._str_capitalize
_str_title = ArrowStringArrayMixin._str_title
_str_swapcase = ArrowStringArrayMixin._str_swapcase
_str_slice_replace = ArrowStringArrayMixin._str_slice_replace
_str_len = ArrowStringArrayMixin._str_len
_str_slice = ArrowStringArrayMixin._str_slice

def _str_contains(
Expand Down Expand Up @@ -323,73 +335,21 @@ def _str_replace(
if isinstance(pat, re.Pattern) or callable(repl) or not case or flags:
return super()._str_replace(pat, repl, n, case, flags, regex)

return ArrowExtensionArray._str_replace(self, pat, repl, n, case, flags, regex)
return ArrowStringArrayMixin._str_replace(
self, pat, repl, n, case, flags, regex
)

def _str_repeat(self, repeats: int | Sequence[int]):
if not isinstance(repeats, int):
return super()._str_repeat(repeats)
else:
return type(self)(pc.binary_repeat(self._pa_array, repeats))

def _str_match(
self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None
):
if not pat.startswith("^"):
pat = f"^{pat}"
return self._str_contains(pat, case, flags, na, regex=True)

def _str_fullmatch(
self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None
):
if not pat.endswith("$") or pat.endswith("\\$"):
pat = f"{pat}$"
return self._str_match(pat, case, flags, na)

def _str_len(self):
result = pc.utf8_length(self._pa_array)
return self._convert_int_result(result)

def _str_lower(self) -> Self:
return type(self)(pc.utf8_lower(self._pa_array))

def _str_upper(self) -> Self:
return type(self)(pc.utf8_upper(self._pa_array))

def _str_strip(self, to_strip=None) -> Self:
if to_strip is None:
result = pc.utf8_trim_whitespace(self._pa_array)
else:
result = pc.utf8_trim(self._pa_array, characters=to_strip)
return type(self)(result)

def _str_lstrip(self, to_strip=None) -> Self:
if to_strip is None:
result = pc.utf8_ltrim_whitespace(self._pa_array)
else:
result = pc.utf8_ltrim(self._pa_array, characters=to_strip)
return type(self)(result)

def _str_rstrip(self, to_strip=None) -> Self:
if to_strip is None:
result = pc.utf8_rtrim_whitespace(self._pa_array)
else:
result = pc.utf8_rtrim(self._pa_array, characters=to_strip)
return type(self)(result)
return ArrowExtensionArray._str_repeat(self, repeats=repeats)

def _str_removeprefix(self, prefix: str):
if not pa_version_under13p0:
starts_with = pc.starts_with(self._pa_array, pattern=prefix)
removed = pc.utf8_slice_codeunits(self._pa_array, len(prefix))
result = pc.if_else(starts_with, removed, self._pa_array)
return type(self)(result)
return ArrowStringArrayMixin._str_removeprefix(self, prefix)
return super()._str_removeprefix(prefix)

def _str_removesuffix(self, suffix: str):
ends_with = pc.ends_with(self._pa_array, pattern=suffix)
removed = pc.utf8_slice_codeunits(self._pa_array, 0, stop=-len(suffix))
result = pc.if_else(ends_with, removed, self._pa_array)
return type(self)(result)

def _str_count(self, pat: str, flags: int = 0):
if flags:
return super()._str_count(pat, flags)
Expand Down Expand Up @@ -456,28 +416,6 @@ def _reduce(
else:
return result

def _rank(
self,
*,
axis: AxisInt = 0,
method: str = "average",
na_option: str = "keep",
ascending: bool = True,
pct: bool = False,
):
"""
See Series.rank.__doc__.
"""
return self._convert_int_result(
self._rank_calc(
axis=axis,
method=method,
na_option=na_option,
ascending=ascending,
pct=pct,
)
)

def value_counts(self, dropna: bool = True) -> Series:
result = super().value_counts(dropna=dropna)
if self.dtype.na_value is np.nan:
Expand All @@ -499,9 +437,3 @@ def _cmp_method(self, other, op):

class ArrowStringArrayNumpySemantics(ArrowStringArray):
_na_value = np.nan
_str_get = ArrowStringArrayMixin._str_get
_str_removesuffix = ArrowStringArrayMixin._str_removesuffix
_str_capitalize = ArrowStringArrayMixin._str_capitalize
_str_title = ArrowStringArrayMixin._str_title
_str_swapcase = ArrowStringArrayMixin._str_swapcase
_str_slice_replace = ArrowStringArrayMixin._str_slice_replace

0 comments on commit 4444e52

Please sign in to comment.