-
-
Notifications
You must be signed in to change notification settings - Fork 18.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
is_bool_dtype for ExtensionArrays (#22667)
- Loading branch information
1 parent
117d0b1
commit e568fb0
Showing
10 changed files
with
276 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
"""Rudimentary Apache Arrow-backed ExtensionArray. | ||
At the moment, just a boolean array / type is implemented. | ||
Eventually, we'll want to parametrize the type and support | ||
multiple dtypes. Not all methods are implemented yet, and the | ||
current implementation is not efficient. | ||
""" | ||
import copy | ||
import itertools | ||
|
||
import numpy as np | ||
import pyarrow as pa | ||
import pandas as pd | ||
from pandas.api.extensions import ( | ||
ExtensionDtype, ExtensionArray, take, register_extension_dtype | ||
) | ||
|
||
|
||
@register_extension_dtype | ||
class ArrowBoolDtype(ExtensionDtype): | ||
|
||
type = np.bool_ | ||
kind = 'b' | ||
name = 'arrow_bool' | ||
na_value = pa.NULL | ||
|
||
@classmethod | ||
def construct_from_string(cls, string): | ||
if string == cls.name: | ||
return cls() | ||
else: | ||
raise TypeError("Cannot construct a '{}' from " | ||
"'{}'".format(cls, string)) | ||
|
||
@classmethod | ||
def construct_array_type(cls): | ||
return ArrowBoolArray | ||
|
||
def _is_boolean(self): | ||
return True | ||
|
||
|
||
class ArrowBoolArray(ExtensionArray): | ||
def __init__(self, values): | ||
if not isinstance(values, pa.ChunkedArray): | ||
raise ValueError | ||
|
||
assert values.type == pa.bool_() | ||
self._data = values | ||
self._dtype = ArrowBoolDtype() | ||
|
||
def __repr__(self): | ||
return "ArrowBoolArray({})".format(repr(self._data)) | ||
|
||
@classmethod | ||
def from_scalars(cls, values): | ||
arr = pa.chunked_array([pa.array(np.asarray(values))]) | ||
return cls(arr) | ||
|
||
@classmethod | ||
def from_array(cls, arr): | ||
assert isinstance(arr, pa.Array) | ||
return cls(pa.chunked_array([arr])) | ||
|
||
@classmethod | ||
def _from_sequence(cls, scalars, dtype=None, copy=False): | ||
return cls.from_scalars(scalars) | ||
|
||
def __getitem__(self, item): | ||
return self._data.to_pandas()[item] | ||
|
||
def __len__(self): | ||
return len(self._data) | ||
|
||
@property | ||
def dtype(self): | ||
return self._dtype | ||
|
||
@property | ||
def nbytes(self): | ||
return sum(x.size for chunk in self._data.chunks | ||
for x in chunk.buffers() | ||
if x is not None) | ||
|
||
def isna(self): | ||
return pd.isna(self._data.to_pandas()) | ||
|
||
def take(self, indices, allow_fill=False, fill_value=None): | ||
data = self._data.to_pandas() | ||
|
||
if allow_fill and fill_value is None: | ||
fill_value = self.dtype.na_value | ||
|
||
result = take(data, indices, fill_value=fill_value, | ||
allow_fill=allow_fill) | ||
return self._from_sequence(result, dtype=self.dtype) | ||
|
||
def copy(self, deep=False): | ||
if deep: | ||
return copy.deepcopy(self._data) | ||
else: | ||
return copy.copy(self._data) | ||
|
||
def _concat_same_type(cls, to_concat): | ||
chunks = list(itertools.chain.from_iterable(x._data.chunks | ||
for x in to_concat)) | ||
arr = pa.chunked_array(chunks) | ||
return cls(arr) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import numpy as np | ||
import pytest | ||
import pandas as pd | ||
import pandas.util.testing as tm | ||
from pandas.tests.extension import base | ||
|
||
pytest.importorskip('pyarrow', minversion="0.10.0") | ||
|
||
from .bool import ArrowBoolDtype, ArrowBoolArray | ||
|
||
|
||
@pytest.fixture | ||
def dtype(): | ||
return ArrowBoolDtype() | ||
|
||
|
||
@pytest.fixture | ||
def data(): | ||
return ArrowBoolArray.from_scalars(np.random.randint(0, 2, size=100, | ||
dtype=bool)) | ||
|
||
|
||
class BaseArrowTests(object): | ||
pass | ||
|
||
|
||
class TestDtype(BaseArrowTests, base.BaseDtypeTests): | ||
def test_array_type_with_arg(self, data, dtype): | ||
pytest.skip("GH-22666") | ||
|
||
|
||
class TestInterface(BaseArrowTests, base.BaseInterfaceTests): | ||
def test_repr(self, data): | ||
raise pytest.skip("TODO") | ||
|
||
|
||
class TestConstructors(BaseArrowTests, base.BaseConstructorsTests): | ||
def test_from_dtype(self, data): | ||
pytest.skip("GH-22666") | ||
|
||
|
||
def test_is_bool_dtype(data): | ||
assert pd.api.types.is_bool_dtype(data) | ||
assert pd.core.common.is_bool_indexer(data) | ||
s = pd.Series(range(len(data))) | ||
result = s[data] | ||
expected = s[np.asarray(data)] | ||
tm.assert_series_equal(result, expected) |