Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH/POC: ExtensionIndex for arbitrary EAs #37869

Closed
wants to merge 16 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1135,6 +1135,9 @@ def view(self, dtype: Optional[Dtype] = None) -> ArrayLike:
# - The only case that *must* be implemented is with dtype=None,
# giving a view with the same dtype as self.
if dtype is not None:
if dtype is np.ndarray:
# passed in Index.values
return np.asarray(self)
raise NotImplementedError(dtype)
return self[:]

Expand Down
26 changes: 19 additions & 7 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,10 @@ def __new__(
stacklevel=2,
)

from pandas.core.arrays import PandasArray
from pandas.core.arrays import (
PandasArray,
StringArray,
)
from pandas.core.indexes.range import RangeIndex

name = maybe_extract_name(name, data, cls)
Expand All @@ -357,7 +360,7 @@ def __new__(
validate_tz_from_dtype(dtype, tz)
dtype = tz_to_dtype(tz)

if isinstance(data, PandasArray):
if isinstance(data, PandasArray) and not isinstance(data, StringArray):
# ensure users don't accidentally put a PandasArray in an index.
data = data.to_numpy()
if isinstance(dtype, PandasDtype):
Expand All @@ -383,11 +386,13 @@ def __new__(
if klass is not Index:
return klass(data, dtype=dtype, copy=copy, name=name, **kwargs)

from pandas.core.indexes.extension import ExtensionIndex

ea_cls = dtype.construct_array_type()
data = ea_cls._from_sequence(data, dtype=dtype, copy=copy)
data = np.asarray(data, dtype=object)
disallow_kwargs(kwargs)
return Index._simple_new(data, name=name)
data = extract_array(data, extract_numpy=True)
return ExtensionIndex._simple_new(data, name=name)

elif is_ea_or_datetimelike_dtype(data_dtype):
klass = cls._dtype_to_subclass(data_dtype)
Expand All @@ -397,9 +402,15 @@ def __new__(
return result.astype(dtype, copy=False)
return result

data = np.array(data, dtype=object, copy=copy)
disallow_kwargs(kwargs)
return Index._simple_new(data, name=name)
if data_dtype == object:
data = np.array(data, dtype=object, copy=copy)
return Index._simple_new(data, name=name)

from pandas.core.indexes.extension import ExtensionIndex

data = extract_array(data, extract_numpy=True)
return ExtensionIndex._simple_new(data, name=name) # TODO: copy?

# index-like
elif isinstance(data, (np.ndarray, Index, ABCSeries)):
Expand Down Expand Up @@ -568,7 +579,8 @@ def _simple_new(cls: Type[_IndexT], values, name: Hashable = None) -> _IndexT:

Must be careful not to recurse.
"""
assert isinstance(values, np.ndarray), type(values)
if cls.__name__ != "ExtensionIndex":
assert isinstance(values, np.ndarray), type(values)

result = object.__new__(cls)
result._data = values
Expand Down
20 changes: 20 additions & 0 deletions pandas/core/indexes/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import numpy as np

from pandas._libs import index as libindex
from pandas.compat.numpy import function as nv
from pandas.errors import AbstractMethodError
from pandas.util._decorators import (
Expand Down Expand Up @@ -257,6 +258,25 @@ def searchsorted(self, value, side="left", sorter=None) -> np.ndarray:

# ---------------------------------------------------------------------

@property
def _na_value(self):
return self.dtype.na_value

@property # TODO: cache_readonly?
def _engine_type(self):
# TODO: can we avoid re-calling if get_engine_target is expensive?
dtype = self._get_engine_target().dtype
return {
np.int8: libindex.Int8Engine,
np.int16: libindex.Int16Engine,
np.int32: libindex.Int32Engine,
np.int64: libindex.Int64Engine,
np.object_: libindex.ObjectEngine,
# TODO: missing floats, uints
}[dtype.type]

# ---------------------------------------------------------------------

def _get_engine_target(self) -> np.ndarray:
return np.asarray(self._data)

Expand Down
22 changes: 17 additions & 5 deletions pandas/tests/arithmetic/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -1371,15 +1371,27 @@ def test_integer_array_add_list_like(
left = container + box_1d_array(data)
right = box_1d_array(data) + container

if Series == box_pandas_1d_array:
if box_pandas_1d_array is Series:
assert_function = tm.assert_series_equal
expected = Series(expected_data, dtype="Int64")
elif Series == box_1d_array:

elif box_1d_array is Series:
assert_function = tm.assert_series_equal
expected = Series(expected_data, dtype="object")
elif Index in (box_pandas_1d_array, box_1d_array):

if box_pandas_1d_array is Index:
expected = Series(expected_data, dtype="Int64")
else:
expected = Series(expected_data, dtype="object")

elif box_pandas_1d_array is Index:
assert_function = tm.assert_index_equal
expected = Int64Index(expected_data)
expected = Index(array(expected_data))
assert expected.dtype == "Int64"

elif box_1d_array is Index:
assert_function = tm.assert_index_equal
expected = Index(expected_data)

else:
assert_function = tm.assert_numpy_array_equal
expected = np.array(expected_data, dtype="object")
Expand Down
4 changes: 3 additions & 1 deletion pandas/tests/arrays/categorical/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,9 @@ def test_categorical_extension_array_nullable(self, nulls_fixture):
# GH:
arr = pd.arrays.StringArray._from_sequence([nulls_fixture] * 2)
result = Categorical(arr)
expected = Categorical(Series([pd.NA, pd.NA], dtype="object"))
idx = Index(arr)
assert idx.dtype == arr.dtype
expected = Categorical([np.nan, np.nan], categories=idx[:0])
tm.assert_categorical_equal(result, expected)

def test_from_sequence_copy(self):
Expand Down
5 changes: 3 additions & 2 deletions pandas/tests/arrays/integer/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,9 @@ def test_construct_index(all_data, dropna):
else:
other = all_data

result = pd.Index(pd.array(other, dtype=all_data.dtype))
expected = pd.Index(other, dtype=object)
arr = pd.array(other, dtype=all_data.dtype)
result = pd.Index(arr)
expected = pd.core.indexes.extension.ExtensionIndex._simple_new(arr)

tm.assert_index_equal(result, expected)

Expand Down
17 changes: 16 additions & 1 deletion pandas/tests/extension/json/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,14 +312,29 @@ def test_groupby_extension_apply(self):
we'll be able to dispatch unique.
"""

@pytest.mark.parametrize("as_index", [True, False])
@pytest.mark.parametrize(
"as_index",
[
pytest.param(
True,
marks=pytest.mark.xfail(
reason="Best guess: lack of hashability breaks ExtensionIndex"
),
),
False,
],
)
def test_groupby_extension_agg(self, as_index, data_for_grouping):
super().test_groupby_extension_agg(as_index, data_for_grouping)

@pytest.mark.xfail(reason="GH#39098: Converts agg result to object")
def test_groupby_agg_extension(self, data_for_grouping):
super().test_groupby_agg_extension(data_for_grouping)

@pytest.mark.xfail(reason="Best guess: lack of hashability breaks ExtensionIndex")
def test_groupby_extension_no_sort(self, data_for_grouping):
super().test_groupby_extension_no_sort(data_for_grouping)


class TestArithmeticOps(BaseJSON, base.BaseArithmeticOpsTests):
def test_arith_frame_with_scalar(self, data, all_arithmetic_operators, request):
Expand Down
12 changes: 11 additions & 1 deletion pandas/tests/extension/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,4 +191,14 @@ class TestPrinting(base.BasePrintingTests):


class TestGroupBy(base.BaseGroupbyTests):
pass
@pytest.fixture(
params=[
StringDtype,
pytest.param(
ArrowStringDtype, marks=td.skip_if_no("pyarrow", min_version="2.0.0")
),
]
)
def dtype(self, request):
# GH#37869 we need pyarrow 2.0+ for some of these tests
return request.param
14 changes: 8 additions & 6 deletions pandas/tests/groupby/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
MultiIndex,
Series,
Timestamp,
array as pd_array,
date_range,
isna,
)
import pandas._testing as tm
from pandas.core.indexes.extension import ExtensionIndex
import pandas.core.nanops as nanops
from pandas.util import _test_decorators as td

Expand Down Expand Up @@ -122,10 +124,6 @@ def test_intercept_builtin_sum():
tm.assert_series_equal(result2, expected)


# @pytest.mark.parametrize("f", [max, min, sum])
# def test_builtins_apply(f):


@pytest.mark.parametrize("f", [max, min, sum])
@pytest.mark.parametrize("keys", ["jim", ["jim", "joe"]]) # Single key # Multi-key
def test_builtins_apply(keys, f):
Expand Down Expand Up @@ -1106,7 +1104,9 @@ def test_apply_to_nullable_integer_returns_float(values, function):
# https://github.com/pandas-dev/pandas/issues/32219
output = 0.5 if function == "var" else 1.5
arr = np.array([output] * 3, dtype=float)
idx = Index([1, 2, 3], dtype=object, name="a")
idx = Index(pd_array([1, 2, 3]), name="a")
assert isinstance(idx, ExtensionIndex)
assert idx.dtype == "Int64"
expected = DataFrame({"b": arr}, index=idx).astype("Float64")

groups = DataFrame(values, dtype="Int64").groupby("a")
Expand All @@ -1126,7 +1126,9 @@ def test_groupby_sum_below_mincount_nullable_integer():
# https://github.com/pandas-dev/pandas/issues/32861
df = DataFrame({"a": [0, 1, 2], "b": [0, 1, 2], "c": [0, 1, 2]}, dtype="Int64")
grouped = df.groupby("a")
idx = Index([0, 1, 2], dtype=object, name="a")
idx = Index(pd_array([0, 1, 2]), name="a")
assert isinstance(idx, ExtensionIndex)
assert idx.dtype == "Int64"

result = grouped["b"].sum(min_count=2)
expected = Series([pd.NA] * 3, dtype="Int64", index=idx, name="b")
Expand Down