Skip to content

Commit

Permalink
REF (string): de-duplicate str_map_nan_semantics (pandas-dev#59464)
Browse files Browse the repository at this point in the history
REF: de-duplicate str_map_nan_semantics
  • Loading branch information
jbrockmendel authored and WillAyd committed Aug 15, 2024
1 parent 489297a commit aa34a1a
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 46 deletions.
9 changes: 5 additions & 4 deletions pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def _str_map(
return constructor(result, mask)

else:
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)
return self._str_map_str_or_object(dtype, na_value, arr, f, mask)

def _str_map_str_or_object(
self,
Expand All @@ -400,7 +400,6 @@ def _str_map_str_or_object(
arr: np.ndarray,
f,
mask: npt.NDArray[np.bool_],
convert: bool,
):
# _str_map helper for case where dtype is either string dtype or object
if is_string_dtype(dtype) and not is_object_dtype(dtype):
Expand Down Expand Up @@ -434,7 +433,6 @@ def _str_map_nan_semantics(

mask = isna(self)
arr = np.asarray(self)
convert = convert and not np.all(mask)

if is_integer_dtype(dtype) or is_bool_dtype(dtype):
na_value_is_na = isna(na_value)
Expand All @@ -453,6 +451,9 @@ def _str_map_nan_semantics(
dtype=np.dtype(cast(type, dtype)),
)
if na_value_is_na and mask.any():
# TODO: we could alternatively do this check before map_infer_mask
# and adjust the dtype/na_value we pass there. Which is more
# performant?
if is_integer_dtype(dtype):
result = result.astype("float64")
else:
Expand All @@ -461,7 +462,7 @@ def _str_map_nan_semantics(
return result

else:
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)
return self._str_map_str_or_object(dtype, na_value, arr, f, mask)


# error: Definition of "_concat_same_type" in base class "NDArrayBacked" is
Expand Down
42 changes: 0 additions & 42 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
TYPE_CHECKING,
Callable,
Union,
cast,
)
import warnings

Expand All @@ -24,8 +23,6 @@
from pandas.util._exceptions import find_stack_level

from pandas.core.dtypes.common import (
is_bool_dtype,
is_integer_dtype,
is_scalar,
pandas_dtype,
)
Expand Down Expand Up @@ -285,45 +282,6 @@ def _data(self):

_str_map = BaseStringArray._str_map

def _str_map_nan_semantics(
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
):
if dtype is None:
dtype = self.dtype
if na_value is None:
na_value = self.dtype.na_value

mask = isna(self)
arr = np.asarray(self)

if is_integer_dtype(dtype) or is_bool_dtype(dtype):
if is_integer_dtype(dtype):
na_value = np.nan
else:
na_value = False

dtype = np.dtype(cast(type, dtype))
if mask.any():
# numpy int/bool dtypes cannot hold NaNs so we must convert to
# float64 for int (to match maybe_convert_objects) or
# object for bool (again to match maybe_convert_objects)
if is_integer_dtype(dtype):
dtype = np.dtype("float64")
else:
dtype = np.dtype(object)
result = lib.map_infer_mask(
arr,
f,
mask.view("uint8"),
convert=False,
na_value=na_value,
dtype=dtype,
)
return result

else:
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)

def _str_contains(
self, pat, case: bool = True, flags: int = 0, na=np.nan, regex: bool = True
):
Expand Down

0 comments on commit aa34a1a

Please sign in to comment.