Skip to content

Commit

Permalink
REF (string dtype): de-duplicate _str_map (2) (pandas-dev#59451)
Browse files Browse the repository at this point in the history
* REF (string): de-duplicate _str_map (2)

* mypy fixup
  • Loading branch information
jbrockmendel authored and WillAyd committed Sep 20, 2024
1 parent 22d9b39 commit 1dc46e5
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 143 deletions.
179 changes: 90 additions & 89 deletions pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,57 @@ def _from_scalars(cls, scalars, dtype: DtypeObj) -> Self:
raise ValueError
return cls._from_sequence(scalars, dtype=dtype)

def _str_map(
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
):
if self.dtype.na_value is np.nan:
return self._str_map_nan_semantics(
f, na_value=na_value, dtype=dtype, convert=convert
)

from pandas.arrays import BooleanArray

if dtype is None:
dtype = self.dtype
if na_value is None:
na_value = self.dtype.na_value

mask = isna(self)
arr = np.asarray(self)

if is_integer_dtype(dtype) or is_bool_dtype(dtype):
constructor: type[IntegerArray | BooleanArray]
if is_integer_dtype(dtype):
constructor = IntegerArray
else:
constructor = BooleanArray

na_value_is_na = isna(na_value)
if na_value_is_na:
na_value = 1
elif dtype == np.dtype("bool"):
# GH#55736
na_value = bool(na_value)
result = lib.map_infer_mask(
arr,
f,
mask.view("uint8"),
convert=False,
na_value=na_value,
# error: Argument 1 to "dtype" has incompatible type
# "Union[ExtensionDtype, str, dtype[Any], Type[object]]"; expected
# "Type[object]"
dtype=np.dtype(cast(type, dtype)),
)

if not na_value_is_na:
mask[:] = False

return constructor(result, mask)

else:
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)

def _str_map_str_or_object(
self,
dtype,
Expand Down Expand Up @@ -373,6 +424,45 @@ def _str_map_str_or_object(
# -> We don't know the result type. E.g. `.get` can return anything.
return lib.map_infer_mask(arr, f, mask.view("uint8"))

def _str_map_nan_semantics(
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
):
if dtype is None:
dtype = self.dtype
if na_value is None:
na_value = self.dtype.na_value

mask = isna(self)
arr = np.asarray(self)
convert = convert and not np.all(mask)

if is_integer_dtype(dtype) or is_bool_dtype(dtype):
na_value_is_na = isna(na_value)
if na_value_is_na:
if is_integer_dtype(dtype):
na_value = 0
else:
na_value = True

result = lib.map_infer_mask(
arr,
f,
mask.view("uint8"),
convert=False,
na_value=na_value,
dtype=np.dtype(cast(type, dtype)),
)
if na_value_is_na and mask.any():
if is_integer_dtype(dtype):
result = result.astype("float64")
else:
result = result.astype("object")
result[mask] = np.nan
return result

else:
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)


# error: Definition of "_concat_same_type" in base class "NDArrayBacked" is
# incompatible with definition in base class "ExtensionArray"
Expand Down Expand Up @@ -727,95 +817,6 @@ def _cmp_method(self, other, op):
# base class "NumpyExtensionArray" defined the type as "float")
_str_na_value = libmissing.NA # type: ignore[assignment]

def _str_map_nan_semantics(
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
):
if dtype is None:
dtype = self.dtype
if na_value is None:
na_value = self.dtype.na_value

mask = isna(self)
arr = np.asarray(self)
convert = convert and not np.all(mask)

if is_integer_dtype(dtype) or is_bool_dtype(dtype):
na_value_is_na = isna(na_value)
if na_value_is_na:
if is_integer_dtype(dtype):
na_value = 0
else:
na_value = True

result = lib.map_infer_mask(
arr,
f,
mask.view("uint8"),
convert=False,
na_value=na_value,
dtype=np.dtype(cast(type, dtype)),
)
if na_value_is_na and mask.any():
if is_integer_dtype(dtype):
result = result.astype("float64")
else:
result = result.astype("object")
result[mask] = np.nan
return result

else:
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)

def _str_map(
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
):
if self.dtype.na_value is np.nan:
return self._str_map_nan_semantics(
f, na_value=na_value, dtype=dtype, convert=convert
)

from pandas.arrays import BooleanArray

if dtype is None:
dtype = StringDtype(storage="python")
if na_value is None:
na_value = self.dtype.na_value

mask = isna(self)
arr = np.asarray(self)

if is_integer_dtype(dtype) or is_bool_dtype(dtype):
constructor: type[IntegerArray | BooleanArray]
if is_integer_dtype(dtype):
constructor = IntegerArray
else:
constructor = BooleanArray

na_value_is_na = isna(na_value)
if na_value_is_na:
na_value = 1
elif dtype == np.dtype("bool"):
na_value = bool(na_value)
result = lib.map_infer_mask(
arr,
f,
mask.view("uint8"),
convert=False,
na_value=na_value,
# error: Argument 1 to "dtype" has incompatible type
# "Union[ExtensionDtype, str, dtype[Any], Type[object]]"; expected
# "Type[object]"
dtype=np.dtype(dtype), # type: ignore[arg-type]
)

if not na_value_is_na:
mask[:] = False

return constructor(result, mask)

else:
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)


class StringArrayNumpySemantics(StringArray):
_storage = "python"
Expand Down
56 changes: 2 additions & 54 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,8 @@ def _data(self):
# base class "ObjectStringArrayMixin" defined the type as "float")
_str_na_value = libmissing.NA # type: ignore[assignment]

_str_map = BaseStringArray._str_map

def _str_map_nan_semantics(
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
):
Expand Down Expand Up @@ -322,60 +324,6 @@ def _str_map_nan_semantics(
else:
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)

def _str_map(
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
):
if self.dtype.na_value is np.nan:
return self._str_map_nan_semantics(
f, na_value=na_value, dtype=dtype, convert=convert
)

# TODO: de-duplicate with StringArray method. This method is moreless copy and
# paste.

from pandas.arrays import (
BooleanArray,
IntegerArray,
)

if dtype is None:
dtype = self.dtype
if na_value is None:
na_value = self.dtype.na_value

mask = isna(self)
arr = np.asarray(self)

if is_integer_dtype(dtype) or is_bool_dtype(dtype):
constructor: type[IntegerArray | BooleanArray]
if is_integer_dtype(dtype):
constructor = IntegerArray
else:
constructor = BooleanArray

na_value_is_na = isna(na_value)
if na_value_is_na:
na_value = 1
result = lib.map_infer_mask(
arr,
f,
mask.view("uint8"),
convert=False,
na_value=na_value,
# error: Argument 1 to "dtype" has incompatible type
# "Union[ExtensionDtype, str, dtype[Any], Type[object]]"; expected
# "Type[object]"
dtype=np.dtype(dtype), # type: ignore[arg-type]
)

if not na_value_is_na:
mask[:] = False

return constructor(result, mask)

else:
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)

def _str_contains(
self, pat, case: bool = True, flags: int = 0, na=np.nan, regex: bool = True
):
Expand Down

0 comments on commit 1dc46e5

Please sign in to comment.