Skip to content

Commit

Permalink
Refactor to avoid call to StringArray.__init__ & validation
Browse files Browse the repository at this point in the history
  • Loading branch information
topper-123 committed Sep 16, 2020
1 parent ee01e02 commit 39ea860
Showing 1 changed file with 21 additions and 29 deletions.
50 changes: 21 additions & 29 deletions pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import operator
from typing import TYPE_CHECKING, Optional, Type, Union
from typing import TYPE_CHECKING, Type, Union

import numpy as np

Expand Down Expand Up @@ -122,9 +122,6 @@ class StringArray(PandasArray):
copy : bool, default False
Whether to copy the array of data.
convert : bool, default False
If true, force conversion of non-na scalars to strings.
If False, raises a ValueError, if a scalar is neither a string nor na.
Attributes
----------
Expand Down Expand Up @@ -165,15 +162,7 @@ class StringArray(PandasArray):
['1', '1']
Length: 2, dtype: string
Instantiating StringArrays directly with non-string arrays will raise an error
unless ``convert=True``.
>>> pd.arrays.StringArray(np.array(['1', 1]))
ValueError: StringArray requires a sequence of strings or pandas.NA
>>> pd.arrays.StringArray(['1', 1], convert=True)
<StringArray>
['1', '1']
Length: 2, dtype: string
However, instantiating StringArrays directly with non-strings will raise an error.
For comparison methods, `StringArray` returns a :class:`pandas.BooleanArray`:
Expand All @@ -186,29 +175,22 @@ class StringArray(PandasArray):
# undo the PandasArray hack
_typ = "extension"

def __init__(self, values, copy=False, convert: bool = False):
def __init__(self, values, copy=False):
values = extract_array(values)
if not isinstance(values, type(self)):
if convert:
na_val = StringDtype.na_value
values = lib.ensure_string_array(values, na_value=na_val, copy=copy)
else:
self._validate(values)

super().__init__(values, copy=copy)
self._dtype = StringDtype()
if not isinstance(values, type(self)):
self._validate()

def _validate(self, values: Optional[np.ndarray] = None) -> None:
def _validate(self):
"""Validate that we only store NA or strings."""
if values is None:
values = self._ndarray

if len(values) and not lib.is_string_array(values, skipna=True):
if len(self._ndarray) and not lib.is_string_array(self._ndarray, skipna=True):
raise ValueError("StringArray requires a sequence of strings or pandas.NA")
if values.dtype != "object":
if self._ndarray.dtype != "object":
raise ValueError(
"StringArray requires a sequence of strings or pandas.NA. Got "
f"'{values.dtype}' dtype instead."
f"'{self._ndarray.dtype}' dtype instead."
)

@classmethod
Expand All @@ -217,8 +199,18 @@ def _from_sequence(cls, scalars, dtype=None, copy=False):
assert dtype == "string"

result = np.asarray(scalars, dtype="object")

return cls(result, copy=copy, convert=True)
# convert non-na-likes to str, and nan-likes to StringDtype.na_value
result = lib.ensure_string_array(
result, na_value=StringDtype.na_value, copy=copy
)

# Manually creating the new array avoids the validation step in __init__, so it is
# faster. TODO: refactor so that validation is not needed here?
new_string_array = object.__new__(cls)
new_string_array._dtype = StringDtype()
new_string_array._ndarray = result

return new_string_array

@classmethod
def _from_sequence_of_strings(cls, strings, dtype=None, copy=False):
Expand Down

0 comments on commit 39ea860

Please sign in to comment.