From 4d183b57b106ed37b0fdf985eb9c30cf6e4643ec Mon Sep 17 00:00:00 2001 From: Brock Date: Sun, 15 Nov 2020 11:34:52 -0800 Subject: [PATCH 1/3] POC: ExtensionIndex --- pandas/core/arrays/base.py | 3 +++ pandas/core/indexes/base.py | 18 ++++++++++---- pandas/core/indexes/extension.py | 18 ++++++++++++++ pandas/tests/arithmetic/test_numeric.py | 24 ++++++++++++++----- .../arrays/categorical/test_constructors.py | 4 +++- pandas/tests/arrays/integer/test_dtypes.py | 5 ++-- pandas/tests/groupby/test_function.py | 20 +++++++++++++--- 7 files changed, 75 insertions(+), 17 deletions(-) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 0968545a6b8a4..aefdc77310b13 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -1058,6 +1058,9 @@ def view(self, dtype=None) -> ArrayLike: # - The only case that *must* be implemented is with dtype=None, # giving a view with the same dtype as self. if dtype is not None: + if dtype is np.ndarray: + # passed in Index.values + return np.asarray(self) raise NotImplementedError(dtype) return self[:] diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index cb5641a74e60b..245ca3d858eb5 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -313,11 +313,18 @@ def __new__( ea_cls = dtype.construct_array_type() data = ea_cls._from_sequence(data, dtype=dtype, copy=False) else: - data = np.asarray(data, dtype=object) + data = extract_array(data, extract_numpy=True) + if type(data).__name__ == "PandasArray": + # We're doing the test that patches PandasArray to not be + # recognized as EA + data = data._ndarray + return Index(data, dtype=object, copy=copy, name=name, **kwargs) - # coerce to the object dtype - data = data.astype(object) - return Index(data, dtype=object, copy=copy, name=name, **kwargs) + from pandas.core.indexes.extension import ExtensionIndex + + obj = ExtensionIndex._simple_new(data, name=name) + # TODO: need to handle maybe_asobject + return obj # index-like elif isinstance(data, (np.ndarray, Index, ABCSeries)): @@ -425,7 +432,8 @@ def _simple_new(cls, values, name: Label = None): Must be careful not to recurse. """ - assert isinstance(values, np.ndarray), type(values) + if cls.__name__ != "ExtensionIndex": + assert isinstance(values, np.ndarray), type(values) result = object.__new__(cls) result._data = values diff --git a/pandas/core/indexes/extension.py b/pandas/core/indexes/extension.py index 4d09a97b18eed..1c2875481da22 100644 --- a/pandas/core/indexes/extension.py +++ b/pandas/core/indexes/extension.py @@ -5,6 +5,7 @@ import numpy as np +from pandas._libs import index as libindex from pandas.compat.numpy import function as nv from pandas.errors import AbstractMethodError from pandas.util._decorators import cache_readonly, doc @@ -230,6 +231,23 @@ def __getitem__(self, key): # --------------------------------------------------------------------- + @property + def _na_value(self): + return self.dtype.na_value + + @property # TODO: cache_readonly? + def _engine_type(self): + # TODO: can we avoid re-calling if get_engine_target is expensive? + dtype = self._get_engine_target().dtype + return { + np.int8: libindex.Int8Engine, + np.int16: libindex.Int16Engine, + np.int32: libindex.Int32Engine, + np.int64: libindex.Int64Engine, + np.object_: libindex.ObjectEngine, + # TODO: missing floats, uints + }[dtype.type] + def _get_engine_target(self) -> np.ndarray: return np.asarray(self._data) diff --git a/pandas/tests/arithmetic/test_numeric.py b/pandas/tests/arithmetic/test_numeric.py index 836b1dcddf0dd..9a39c311fa337 100644 --- a/pandas/tests/arithmetic/test_numeric.py +++ b/pandas/tests/arithmetic/test_numeric.py @@ -11,7 +11,7 @@ import pytest import pandas as pd -from pandas import Index, Int64Index, Series, Timedelta, TimedeltaIndex, array +from pandas import Index, Series, Timedelta, TimedeltaIndex, array import pandas._testing as tm from pandas.core import ops @@ -1358,15 +1358,27 @@ def test_integer_array_add_list_like( left = container + box_1d_array(data) right = box_1d_array(data) + container - if Series == box_pandas_1d_array: + if box_pandas_1d_array is Series: assert_function = tm.assert_series_equal expected = Series(expected_data, dtype="Int64") - elif Series == box_1d_array: + + elif box_1d_array is Series: assert_function = tm.assert_series_equal - expected = Series(expected_data, dtype="object") - elif Index in (box_pandas_1d_array, box_1d_array): + + if box_pandas_1d_array is Index: + expected = Series(expected_data, dtype="Int64") + else: + expected = Series(expected_data, dtype="object") + + elif box_pandas_1d_array is Index: assert_function = tm.assert_index_equal - expected = Int64Index(expected_data) + expected = Index(array(expected_data)) + assert expected.dtype == "Int64" + + elif box_1d_array is Index: + assert_function = tm.assert_index_equal + expected = Index(expected_data) + else: assert_function = tm.assert_numpy_array_equal expected = np.array(expected_data, dtype="object") diff --git a/pandas/tests/arrays/categorical/test_constructors.py b/pandas/tests/arrays/categorical/test_constructors.py index 23921356a2c5d..b0909cb5e0471 100644 --- a/pandas/tests/arrays/categorical/test_constructors.py +++ b/pandas/tests/arrays/categorical/test_constructors.py @@ -687,5 +687,7 @@ def test_categorical_extension_array_nullable(self, nulls_fixture): # GH: arr = pd.arrays.StringArray._from_sequence([nulls_fixture] * 2) result = Categorical(arr) - expected = Categorical(Series([pd.NA, pd.NA], dtype="object")) + idx = Index(arr) + assert idx.dtype == arr.dtype + expected = Categorical([np.nan, np.nan], categories=idx[:0]) tm.assert_categorical_equal(result, expected) diff --git a/pandas/tests/arrays/integer/test_dtypes.py b/pandas/tests/arrays/integer/test_dtypes.py index d71037f9151e0..3d30f0bd1e571 100644 --- a/pandas/tests/arrays/integer/test_dtypes.py +++ b/pandas/tests/arrays/integer/test_dtypes.py @@ -69,8 +69,9 @@ def test_construct_index(all_data, dropna): else: other = all_data - result = pd.Index(integer_array(other, dtype=all_data.dtype)) - expected = pd.Index(other, dtype=object) + arr = integer_array(other, dtype=all_data.dtype) + result = pd.Index(arr) + expected = pd.core.indexes.extension.ExtensionIndex._simple_new(arr) tm.assert_index_equal(result, expected) diff --git a/pandas/tests/groupby/test_function.py b/pandas/tests/groupby/test_function.py index e49e69a39b315..3ea3a85650f3f 100644 --- a/pandas/tests/groupby/test_function.py +++ b/pandas/tests/groupby/test_function.py @@ -7,8 +7,18 @@ from pandas.errors import UnsupportedFunctionCall import pandas as pd -from pandas import DataFrame, Index, MultiIndex, Series, Timestamp, date_range, isna +from pandas import ( + DataFrame, + Index, + MultiIndex, + Series, + Timestamp, + array as pd_array, + date_range, + isna, +) import pandas._testing as tm +from pandas.core.indexes.extension import ExtensionIndex import pandas.core.nanops as nanops from pandas.util import _test_decorators as td @@ -1027,7 +1037,9 @@ def test_apply_to_nullable_integer_returns_float(values, function): # https://github.com/pandas-dev/pandas/issues/32219 output = 0.5 if function == "var" else 1.5 arr = np.array([output] * 3, dtype=float) - idx = Index([1, 2, 3], dtype=object, name="a") + idx = Index(pd_array([1, 2, 3]), name="a") + assert isinstance(idx, ExtensionIndex) + assert idx.dtype == "Int64" expected = DataFrame({"b": arr}, index=idx) groups = DataFrame(values, dtype="Int64").groupby("a") @@ -1047,7 +1059,9 @@ def test_groupby_sum_below_mincount_nullable_integer(): # https://github.com/pandas-dev/pandas/issues/32861 df = DataFrame({"a": [0, 1, 2], "b": [0, 1, 2], "c": [0, 1, 2]}, dtype="Int64") grouped = df.groupby("a") - idx = Index([0, 1, 2], dtype=object, name="a") + idx = Index(pd_array([0, 1, 2]), name="a") + assert isinstance(idx, ExtensionIndex) + assert idx.dtype == "Int64" result = grouped["b"].sum(min_count=2) expected = Series([pd.NA] * 3, dtype="Int64", index=idx, name="b") From 3950682c086b8e879828d1d523b5ffc081e0e8ed Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 17 Feb 2021 09:40:51 -0800 Subject: [PATCH 2/3] typo fixup --- pandas/tests/extension/test_string.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas/tests/extension/test_string.py b/pandas/tests/extension/test_string.py index 952ec9ea0e758..4f824354d8970 100644 --- a/pandas/tests/extension/test_string.py +++ b/pandas/tests/extension/test_string.py @@ -35,7 +35,7 @@ ] ) def dtype(request): - return request.param() + return request.param @pytest.fixture @@ -173,4 +173,4 @@ class TestGroupBy(base.BaseGroupbyTests): ) def dtype(request): # GH#37869 we need pyarrow 2.0+ for some of these tests - return request.param() + return request.param From 5da38cc2d3a836e6f68e43c1cc5ac61bfa870ec0 Mon Sep 17 00:00:00 2001 From: Brock Date: Thu, 4 Mar 2021 20:18:26 -0800 Subject: [PATCH 3/3] fix StringArray tests --- pandas/tests/extension/test_string.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas/tests/extension/test_string.py b/pandas/tests/extension/test_string.py index 454cee46b09ab..9d739ffb39c10 100644 --- a/pandas/tests/extension/test_string.py +++ b/pandas/tests/extension/test_string.py @@ -58,7 +58,7 @@ def chunked(request): ] ) def dtype(request): - return request.param + return request.param() @pytest.fixture @@ -199,6 +199,6 @@ class TestGroupBy(base.BaseGroupbyTests): ), ] ) - def dtype(request): + def dtype(self, request): # GH#37869 we need pyarrow 2.0+ for some of these tests return request.param