Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: allow storing ExtensionArrays in the Index #34159

Closed
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pandas/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ def _create_mi_with_dt64tz_level():
"mi-with-dt64tz-level": _create_mi_with_dt64tz_level(),
"multi": _create_multiindex(),
"repeats": Index([0, 0, 1, 1, 2, 2]),
"nullable_int": Index(np.arange(100), dtype="Int64"),
}


Expand Down
58 changes: 46 additions & 12 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,11 +346,13 @@ def __new__(
ea_cls = dtype.construct_array_type()
data = ea_cls._from_sequence(data, dtype=dtype, copy=False)
else:
data = np.asarray(data, dtype=object)
# TODO clean-up with extract_array ?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

if isinstance(data, Index):
data = data._data
elif isinstance(data, ABCSeries):
data = data.array

# coerce to the object dtype
data = data.astype(object)
return Index(data, dtype=object, copy=copy, name=name, **kwargs)
return cls._simple_new(data, name)

# index-like
elif isinstance(data, (np.ndarray, Index, ABCSeries)):
Expand Down Expand Up @@ -458,7 +460,7 @@ def _simple_new(cls, values, name: Label = None):

Must be careful not to recurse.
"""
assert isinstance(values, np.ndarray), type(values)
assert isinstance(values, (np.ndarray, ExtensionArray)), type(values)

result = object.__new__(cls)
result._data = values
Expand Down Expand Up @@ -2126,6 +2128,8 @@ def fillna(self, value=None, downcast=None):
Series.fillna : Fill NaN Values of a Series.
"""
self._assert_can_do_op(value)
if is_extension_array_dtype(self.dtype):
return self._shallow_copy(self._values.fillna(value))
if self.hasnans:
result = self.putmask(self._isnan, value)
if downcast is None:
Expand Down Expand Up @@ -2525,7 +2529,9 @@ def _union(self, other, sort):
# worth making this faster? a very unusual case
value_set = set(lvals)
result.extend([x for x in rvals if x not in value_set])
result = Index(result)._values # do type inference here
result = Index(
result, dtype=self.dtype
)._values # do type inference here
else:
# find indexes of things in "other" that are not in "self"
if self.is_unique:
Expand Down Expand Up @@ -3797,7 +3803,7 @@ def values(self) -> np.ndarray:
Index.array : Reference to the underlying data.
Index.to_numpy : A NumPy array representing the underlying data.
"""
return self._data.view(np.ndarray)
return self._data # .view(np.ndarray)

@cache_readonly
@doc(IndexOpsMixin.array)
Expand Down Expand Up @@ -3839,7 +3845,10 @@ def _get_engine_target(self) -> np.ndarray:
"""
Get the ndarray that we can pass to the IndexEngine constructor.
"""
return self._values
if isinstance(self._values, np.ndarray):
return self._values
else:
return np.asarray(self._values)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't `np.asarray(self._values)` work unconditionally?


@doc(IndexOpsMixin.memory_usage)
def memory_usage(self, deep: bool = False) -> int:
Expand Down Expand Up @@ -4232,10 +4241,18 @@ def equals(self, other: Any) -> bool:
# d-level MultiIndex can equal d-tuple Index
return other.equals(self)

if is_extension_array_dtype(other.dtype):
if is_extension_array_dtype(other.dtype) and type(other) != Index:
# All EA-backed Index subclasses override equals
return other.equals(self)

if is_extension_array_dtype(self.dtype):
if is_object_dtype(other.dtype):
try:
other = other.astype(self.dtype)
except Exception:
return False
return self._values.equals(other._values)

return array_equivalent(self._values, other._values)

def identical(self, other) -> bool:
Expand Down Expand Up @@ -4759,6 +4776,15 @@ def map(self, mapper, na_action=None):

attributes = self._get_attributes_dict()

if is_extension_array_dtype(self.dtype):
# try to coerce back to original dtype
# TODO this should use a strict version
try:
# TODO use existing helper method for this
new_values = self._values._from_sequence(new_values, dtype=self.dtype)
except Exception:
pass

# we can return a MultiIndex
if new_values.size and isinstance(new_values[0], tuple):
if isinstance(self, MultiIndex):
Expand Down Expand Up @@ -5193,7 +5219,10 @@ def delete(self, loc):
>>> idx.delete([0, 2])
Index(['b'], dtype='object')
"""
return self._shallow_copy(np.delete(self._data, loc))
# this is currently overridden by EA-based Index subclasses
keep = np.ones(len(self), dtype=bool)
keep[loc] = False
return self._shallow_copy(self._data[keep])

def insert(self, loc: int, item):
"""
Expand All @@ -5212,9 +5241,14 @@ def insert(self, loc: int, item):
"""
# Note: this method is overridden by all ExtensionIndex subclasses,
# so self is never backed by an EA.
arr = np.asarray(self)
item = self._coerce_scalar_to_index(item)._values
idx = np.concatenate((arr[:loc], item, arr[loc:]))

if is_extension_array_dtype(self.dtype):
arr = self._values
idx = arr._concat_same_type([arr[:loc], item, arr[loc:]])
else:
arr = np.asarray(self)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`arr = self._values` could go outside the if/else

idx = np.concatenate((arr[:loc], item, arr[loc:]))
return Index(idx, name=self.name)

def drop(self, labels, errors: str_t = "raise"):
Expand Down
1 change: 0 additions & 1 deletion pandas/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1766,7 +1766,6 @@ def isetter(loc, v):
):
self.obj[item_labels[indexer[info_axis]]] = value
return

indexer = maybe_convert_ix(*indexer)

if isinstance(value, (ABCSeries, dict)):
Expand Down
4 changes: 4 additions & 0 deletions pandas/tests/base/test_value_counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def test_value_counts(index_or_series_obj):
expected.index = expected.index.astype(obj.dtype)
if isinstance(obj, pd.MultiIndex):
expected.index = pd.Index(expected.index)
if isinstance(obj.dtype, pd.Int64Dtype):
expected = expected.astype("Int64")

# TODO: Order of entries with the same count is inconsistent on CI (gh-32449)
if obj.duplicated().any():
Expand Down Expand Up @@ -69,6 +71,8 @@ def test_value_counts_null(null_obj, index_or_series_obj):
counter = collections.Counter(obj.dropna())
expected = pd.Series(dict(counter.most_common()), dtype=np.int64)
expected.index = expected.index.astype(obj.dtype)
if isinstance(obj.dtype, pd.Int64Dtype):
expected = expected.astype("Int64")

result = obj.value_counts()
if obj.duplicated().any():
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/groupby/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,7 +1000,7 @@ def test_apply_to_nullable_integer_returns_float(values, function):
# https://github.com/pandas-dev/pandas/issues/32219
output = 0.5 if function == "var" else 1.5
arr = np.array([output] * 3, dtype=float)
idx = pd.Index([1, 2, 3], dtype=object, name="a")
idx = pd.Index([1, 2, 3], dtype="Int64", name="a")
expected = pd.DataFrame({"b": arr}, index=idx)

groups = pd.DataFrame(values, dtype="Int64").groupby("a")
Expand Down
5 changes: 4 additions & 1 deletion pandas/tests/indexes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from pandas._libs import iNaT

from pandas.core.dtypes.common import is_datetime64tz_dtype
from pandas.core.dtypes.common import is_datetime64tz_dtype, is_extension_array_dtype
from pandas.core.dtypes.dtypes import CategoricalDtype

import pandas as pd
Expand Down Expand Up @@ -278,6 +278,9 @@ def test_ensure_copied_data(self, indices):
elif isinstance(indices, IntervalIndex):
# checked in test_interval.py
pass
elif is_extension_array_dtype(indices.dtype):
# TODO can we check this generally?
pass
else:
result = index_type(indices.values, copy=False, **init_kwargs)
tm.assert_numpy_array_equal(
Expand Down
4 changes: 4 additions & 0 deletions pandas/tests/indexes/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,6 +916,10 @@ def test_map_dictlike(self, indices, mapper):
else:
expected = Index(np.arange(len(indices), 0, -1))

if isinstance(indices.dtype, pd.Int64Dtype):
# map tries to preserve the nullable dtype
expected = expected.astype("Int64")

result = indices.map(mapper(expected, indices))
tm.assert_index_equal(result, expected)

Expand Down
1 change: 1 addition & 0 deletions pandas/tests/indexing/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def test_getitem_ndarray_3d(self, indices, obj, idxr, idxr_id):
"Index data must be 1-dimensional",
"positional indexers are out-of-bounds",
"Indexing a MultiIndex with a multidimensional key is not implemented",
"values must be a 1D array",
]
)

Expand Down