diff --git a/pandas/conftest.py b/pandas/conftest.py index 35affa62ccf68..213b29384e2db 100644 --- a/pandas/conftest.py +++ b/pandas/conftest.py @@ -479,6 +479,7 @@ def _create_mi_with_dt64tz_level(): "mi-with-dt64tz-level": _create_mi_with_dt64tz_level(), "multi": _create_multiindex(), "repeats": Index([0, 0, 1, 1, 2, 2]), + "nullable_int": Index(np.arange(100), dtype="Int64"), } diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py index 6f906cf8879ff..0be0a65726c82 100644 --- a/pandas/core/algorithms.py +++ b/pandas/core/algorithms.py @@ -1512,7 +1512,10 @@ def take( ) else: # NumPy style - result = arr.take(indices, axis=axis) + if arr.ndim == 1: + result = arr.take(indices) + else: + result = arr.take(indices, axis=axis) return result diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 58f5ca3de5dce..0051cc1ab764b 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -424,7 +424,6 @@ def __new__( ea_cls = dtype.construct_array_type() data = ea_cls._from_sequence(data, dtype=dtype, copy=copy) - data = np.asarray(data, dtype=object) disallow_kwargs(kwargs) return Index._simple_new(data, name=name) @@ -436,7 +435,7 @@ def __new__( return result.astype(dtype, copy=False) return result - data = np.array(data, dtype=object, copy=copy) + data = extract_array(data) disallow_kwargs(kwargs) return Index._simple_new(data, name=name) @@ -612,7 +611,7 @@ def _simple_new(cls: type[_IndexT], values, name: Hashable = None) -> _IndexT: Must be careful not to recurse. """ - assert isinstance(values, np.ndarray), type(values) + assert isinstance(values, (np.ndarray, ExtensionArray)), type(values) result = object.__new__(cls) result._data = values @@ -2572,6 +2571,9 @@ def fillna(self, value=None, downcast=None): Series.fillna : Fill NaN Values of a Series. """ value = self._require_scalar(value) + if is_extension_array_dtype(self.dtype) and type(self) is Index: + return self._shallow_copy(self._values.fillna(value)) + if self.hasnans: result = self.putmask(self._isnan, value) if downcast is None: @@ -4348,7 +4350,7 @@ def values(self) -> ArrayLike: Index.array : Reference to the underlying data. Index.to_numpy : A NumPy array representing the underlying data. """ - return self._data + return self._data # .view(np.ndarray) @cache_readonly @doc(IndexOpsMixin.array) @@ -4390,9 +4392,7 @@ def _get_engine_target(self) -> np.ndarray: """ Get the ndarray that we can pass to the IndexEngine constructor. """ - # error: Incompatible return value type (got "Union[ExtensionArray, - # ndarray]", expected "ndarray") - return self._values # type: ignore[return-value] + return np.asarray(self._values) def _get_join_target(self) -> np.ndarray: """ @@ -4405,6 +4405,9 @@ def _from_join_target(self, result: np.ndarray) -> ArrayLike: Cast the ndarray returned from one of the libjoin.foo_indexer functions back to type(self)._data. """ + if is_extension_array_dtype(self.dtype): + # TODO use helper method / strict version + return self._values._from_sequence(result, dtype=self.dtype) return result @doc(IndexOpsMixin._memory_usage) @@ -4782,10 +4785,18 @@ def equals(self, other: Any) -> bool: # d-level MultiIndex can equal d-tuple Index return other.equals(self) - if is_extension_array_dtype(other.dtype): + if is_extension_array_dtype(other.dtype) and type(other) != Index: # All EA-backed Index subclasses override equals return other.equals(self) + if is_extension_array_dtype(self.dtype): + if is_object_dtype(other.dtype): + try: + other = other.astype(self.dtype) + except Exception: + return False + return self._values.equals(other._values) + return array_equivalent(self._values, other._values) @final @@ -5470,6 +5481,15 @@ def map(self, mapper, na_action=None): attributes = self._get_attributes_dict() + if is_extension_array_dtype(self.dtype): + # try to coerce back to original dtype + # TODO this should use a strict version + try: + # TODO use existing helper method for this + new_values = self._values._from_sequence(new_values, dtype=self.dtype) + except Exception: + pass + # we can return a MultiIndex if new_values.size and isinstance(new_values[0], tuple): if isinstance(self, MultiIndex): @@ -5906,8 +5926,10 @@ def delete(self, loc) -> Index: >>> idx.delete([0, 2]) Index(['b'], dtype='object') """ - res_values = np.delete(self._data, loc) - return type(self)._simple_new(res_values, name=self.name) + # this is currently overridden by EA-based Index subclasses + keep = np.ones(len(self), dtype=bool) + keep[loc] = False + return type(self)._simple_new(self._data[keep], name=self.name) def insert(self, loc: int, item) -> Index: """ @@ -5937,11 +5959,15 @@ def insert(self, loc: int, item) -> Index: dtype = find_common_type([self.dtype, inferred]) return self.astype(dtype).insert(loc, item) - arr = np.asarray(self) - # Use Index constructor to ensure we get tuples cast correctly. item = Index([item], dtype=self.dtype)._values - idx = np.concatenate((arr[:loc], item, arr[loc:])) + + arr = self._values + if is_extension_array_dtype(self.dtype): + idx = arr._concat_same_type([arr[:loc], item, arr[loc:]]) + else: + idx = np.concatenate((arr[:loc], item, arr[loc:])) + return Index(idx, name=self.name) def drop(self, labels, errors: str_t = "raise") -> Index: diff --git a/pandas/tests/arrays/integer/test_dtypes.py b/pandas/tests/arrays/integer/test_dtypes.py index e3f59205aa07c..9b226846424f5 100644 --- a/pandas/tests/arrays/integer/test_dtypes.py +++ b/pandas/tests/arrays/integer/test_dtypes.py @@ -72,7 +72,7 @@ def test_construct_index(all_data, dropna): other = all_data result = pd.Index(pd.array(other, dtype=all_data.dtype)) - expected = pd.Index(other, dtype=object) + expected = pd.Index(other, dtype=all_data.dtype) tm.assert_index_equal(result, expected) diff --git a/pandas/tests/base/test_value_counts.py b/pandas/tests/base/test_value_counts.py index 4151781f0dbf5..9511eb4bc74c4 100644 --- a/pandas/tests/base/test_value_counts.py +++ b/pandas/tests/base/test_value_counts.py @@ -34,6 +34,8 @@ def test_value_counts(index_or_series_obj): expected.index = expected.index.astype(obj.dtype) if isinstance(obj, pd.MultiIndex): expected.index = Index(expected.index) + if isinstance(obj.dtype, pd.Int64Dtype): + expected = expected.astype("Int64") # TODO: Order of entries with the same count is inconsistent on CI (gh-32449) if obj.duplicated().any(): @@ -69,6 +71,8 @@ def test_value_counts_null(null_obj, index_or_series_obj): counter = collections.Counter(obj.dropna()) expected = Series(dict(counter.most_common()), dtype=np.int64) expected.index = expected.index.astype(obj.dtype) + if isinstance(obj.dtype, pd.Int64Dtype): + expected = expected.astype("Int64") result = obj.value_counts() if obj.duplicated().any(): diff --git a/pandas/tests/groupby/test_function.py b/pandas/tests/groupby/test_function.py index f47fc1f4e4a4f..701b0302fe275 100644 --- a/pandas/tests/groupby/test_function.py +++ b/pandas/tests/groupby/test_function.py @@ -1119,7 +1119,7 @@ def test_apply_to_nullable_integer_returns_float(values, function): # https://github.com/pandas-dev/pandas/issues/32219 output = 0.5 if function == "var" else 1.5 arr = np.array([output] * 3, dtype=float) - idx = Index([1, 2, 3], dtype=object, name="a") + idx = Index([1, 2, 3], dtype="Int64", name="a") expected = DataFrame({"b": arr}, index=idx).astype("Float64") groups = DataFrame(values, dtype="Int64").groupby("a") diff --git a/pandas/tests/indexes/common.py b/pandas/tests/indexes/common.py index ab2b2db7eec53..4fcf41127c61f 100644 --- a/pandas/tests/indexes/common.py +++ b/pandas/tests/indexes/common.py @@ -6,7 +6,10 @@ from pandas._libs import iNaT -from pandas.core.dtypes.common import is_datetime64tz_dtype +from pandas.core.dtypes.common import ( + is_datetime64tz_dtype, + is_extension_array_dtype, +) from pandas.core.dtypes.dtypes import CategoricalDtype import pandas as pd @@ -271,6 +274,9 @@ def test_ensure_copied_data(self, index): elif isinstance(index, IntervalIndex): # checked in test_interval.py pass + elif is_extension_array_dtype(index.dtype): + # TODO can we check this generally? + pass else: result = index_type(index.values, copy=False, **init_kwargs) tm.assert_numpy_array_equal(index.values, result.values, check_same="same") diff --git a/pandas/tests/indexes/test_base.py b/pandas/tests/indexes/test_base.py index 1e9348dc410d7..2b56b4b0e7878 100644 --- a/pandas/tests/indexes/test_base.py +++ b/pandas/tests/indexes/test_base.py @@ -717,6 +717,10 @@ def test_map_dictlike(self, index, mapper): else: expected = Index(np.arange(len(index), 0, -1)) + if isinstance(index.dtype, pd.Int64Dtype): + # map tries to preserve the nullable dtype + expected = expected.astype("Int64") + result = index.map(mapper(expected, index)) tm.assert_index_equal(result, expected) diff --git a/pandas/tests/indexing/test_indexing.py b/pandas/tests/indexing/test_indexing.py index df688d6745096..6eca46f4b9733 100644 --- a/pandas/tests/indexing/test_indexing.py +++ b/pandas/tests/indexing/test_indexing.py @@ -96,6 +96,7 @@ def test_getitem_ndarray_3d( msgs.append("Index data must be 1-dimensional") if len(index) == 0 or isinstance(index, pd.MultiIndex): msgs.append("positional indexers are out-of-bounds") + msgs.append("values must be a 1D array") msg = "|".join(msgs) potential_errors = (IndexError, ValueError, NotImplementedError)