From ed03c751e1eac1a00a7519abd6abed9be92031c8 Mon Sep 17 00:00:00 2001
From: jbrockmendel
Date: Wed, 19 Jun 2019 16:11:13 -0700
Subject: [PATCH 1/3] implement ReshapeWrapper

---
Review notes (text between the "---" line and the first "diff --git" line
is discarded by git am):

* Line breaks and indentation of this series were reconstructed from a
  whitespace-collapsed copy; verify context lines against the target tree.
* ReshapeWrapper.swapaxes: the error message is now formatted with
  axis1=axis1 (the positional argument raised KeyError instead of
  ValueError) and the bound check is ">=" so axis == ndim is rejected.
* ReshapeWrapper.size returns a plain int, matching its annotation.
* Docstring typo fixed; the second loop in test_reshape no longer reuses
  the name "collike" for row-like arrays.
* All changes keep line counts unchanged, so hunk headers are untouched;
  "index" blob ids are left as originally submitted.

 pandas/core/arrays/base.py                | 83 +++++++++++++++++++++++
 pandas/tests/extension/arrow/bool.py      |  9 ++-
 pandas/tests/extension/arrow/test_bool.py | 65 ++++++++++++++++++
 3 files changed, 156 insertions(+), 1 deletion(-)

diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py
index c709cd9e9f0b2..ee7732f0f276b 100644
--- a/pandas/core/arrays/base.py
+++ b/pandas/core/arrays/base.py
@@ -1119,3 +1119,86 @@ def _create_arithmetic_method(cls, op):
     @classmethod
     def _create_comparison_method(cls, op):
         return cls._create_method(op, coerce_to_dtype=False)
+
+
+class ReshapeWrapper:
+    """
+    Mixin to a natively-1D ExtensionArray to support limited 2D operations.
+
+    Notes
+    -----
+    Assumes that `type(self)(self)` will construct an array equivalent to self.
+    """
+    # --------------------------------------------------------------
+    # Override ExtensionArray, assumes ReshapeWrapper is mixed in such that
+    # it is before EA in the __mro__
+
+    shape = None  # Need to override property in base class
+
+    def __len__(self) -> int:
+        return self.shape[0]
+
+    @property
+    def ndim(self) -> int:
+        return len(self.shape)
+
+    @property
+    def size(self) -> int:
+        return int(np.prod(self.shape))
+
+    # --------------------------------------------------------------
+
+    def __init__(self, *args, shape=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        if shape is None:
+            shape = (super().__len__(),)
+
+        # If we ever allow greater dimensions, other methods below will need
+        # to be updated.
+        assert len(shape) in [1, 2]
+        self.shape = shape  # No other validation on shape
+
+    def reshape(self, *args):
+        # *args because numpy accepts either
+        # `arr.reshape(1, 2)` or `arr.reshape((1, 2))`
+        shape = args
+        if len(args) == 1 and isinstance(args[0], tuple):
+            shape = args[0]
+        if -1 in shape:
+            if shape.count(-1) != 1:
+                raise ValueError("Invalid shape {shape}".format(shape=shape))
+            idx = shape.index(-1)
+            others = [n for n in shape if n != -1]
+            prod = np.prod(others)
+            dim, rem = divmod(self.size, prod)
+            shape = shape[:idx] + (dim,) + shape[idx+1:]
+
+        if np.prod(shape) != self.size:
+            raise ValueError("Product of shape ({shape}) must match "
+                             "size ({size})".format(shape=shape,
+                                                    size=self.size))
+        return type(self)(self, shape=shape)
+
+    def transpose(self, *args, **kwargs):
+        # Note: no validation on args/kwargs
+        shape = self.shape[::-1]
+        return type(self)(self, shape=shape)
+
+    @property
+    def T(self):
+        return self.transpose()
+
+    def ravel(self, order=None):
+        # Note: we ignore `order`, keep the argument for compat with
+        # numpy signature
+        slen = super().__len__()
+        shape = (slen,)
+        return type(self)(self, shape=shape)
+
+    def swapaxes(self, axis1, axis2):
+        if axis1 >= self.ndim or axis2 >= self.ndim:
+            raise ValueError("invalid axes=({axis1}, {axis2}) for ndim={ndim}"
+                             .format(axis1=axis1, axis2=axis2, ndim=self.ndim))
+        if self.ndim == 1 or axis1 == axis2:
+            return type(self)(self, shape=self.shape)
+        return type(self)(self, shape=self.shape[::-1])
diff --git a/pandas/tests/extension/arrow/bool.py b/pandas/tests/extension/arrow/bool.py
index 435ea4e3ec2b5..83c31fd90ae16 100644
--- a/pandas/tests/extension/arrow/bool.py
+++ b/pandas/tests/extension/arrow/bool.py
@@ -14,6 +14,7 @@
 import pandas as pd
 from pandas.api.extensions import (
     ExtensionArray, ExtensionDtype, register_extension_dtype, take)
+from pandas.core.arrays.base import ReshapeWrapper
 
 
 @register_extension_dtype
@@ -40,8 +41,10 @@ def _is_boolean(self):
         return True
 
 
-class ArrowBoolArray(ExtensionArray):
+class _ArrowBoolArray(ExtensionArray):
     def __init__(self, values):
+        if isinstance(values, _ArrowBoolArray):
+            values = values._data
         if not isinstance(values, pa.ChunkedArray):
             raise ValueError
 
@@ -142,3 +145,7 @@ def any(self, axis=0, out=None):
 
     def all(self, axis=0, out=None):
         return self._data.to_pandas().all()
+
+
+class ArrowBoolArray(ReshapeWrapper, _ArrowBoolArray):
+    pass
diff --git a/pandas/tests/extension/arrow/test_bool.py b/pandas/tests/extension/arrow/test_bool.py
index 01163064b0918..749dac4dcaa30 100644
--- a/pandas/tests/extension/arrow/test_bool.py
+++ b/pandas/tests/extension/arrow/test_bool.py
@@ -66,3 +66,68 @@ def test_is_bool_dtype(data):
     result = s[data]
     expected = s[np.asarray(data)]
     tm.assert_series_equal(result, expected)
+
+
+class TestReshape:
+
+    @pytest.mark.parametrize('shape', [(1, -1), (-1, 1)])
+    def test_ravel(self, data, shape):
+        dim2 = data.reshape(shape)
+        assert dim2.ndim == 2
+
+        assert dim2.ravel().shape == data.shape
+
+    def test_transpose(self, data):
+        assert data.T.shape == data.shape
+
+        rowlike = data.reshape(-1, 1)
+        collike = data.reshape(1, -1)
+
+        rt1 = rowlike.T
+        rt2 = rowlike.transpose()
+
+        ct1 = collike.T
+        ct2 = collike.transpose()
+
+        assert ct1.shape == ct2.shape == rowlike.shape
+        assert rt1.shape == rt2.shape == collike.shape
+
+    def test_swapaxes(self, data):
+        rowlike = data.reshape(-1, 1)
+        collike = data.reshape(1, -1)
+
+        for arr in [data, rowlike, collike]:
+            assert arr.swapaxes(0, 0).shape == arr.shape
+
+        assert rowlike.swapaxes(0, 1).shape == collike.shape
+        assert rowlike.swapaxes(1, 0).shape == collike.shape
+        assert collike.swapaxes(0, 1).shape == rowlike.shape
+        assert collike.swapaxes(1, 0).shape == rowlike.shape
+
+    def test_reshape(self, data):
+        size = len(data)
+
+        collike1 = data.reshape(1, -1)
+        collike2 = data.reshape((1, -1))
+        collike3 = data.reshape(1, size)
+        collike4 = data.reshape((1, size))
+        collike5 = collike1.reshape(collike1.shape)
+        for collike in [collike1, collike2, collike3, collike4, collike5]:
+            assert collike.shape == (1, size)
+            assert collike.ndim == 2
+
+        rowlike1 = data.reshape(-1, 1)
+        rowlike2 = data.reshape((-1, 1))
+        rowlike3 = data.reshape(size, 1)
+        rowlike4 = data.reshape((size, 1))
+        rowlike5 = rowlike1.reshape(rowlike1.shape)
+        rowlike6 = collike1.reshape(-1, 1)
+        for rowlike in [rowlike1, rowlike2, rowlike3, rowlike4,
+                        rowlike5, rowlike6]:
+            assert rowlike.shape == (size, 1)
+            assert rowlike.ndim == 2
+
+        with pytest.raises(ValueError, match="Invalid shape"):
+            data.reshape(-1, -1)
+        with pytest.raises(ValueError, match="Product of shape"):
+            data.reshape(len(data), 2)

From c48be14d7f6a02c0a719156e33e491121f7e1764 Mon Sep 17 00:00:00 2001
From: jbrockmendel
Date: Wed, 19 Jun 2019 17:48:06 -0700
Subject: [PATCH 2/3] wrap getitem

---
 pandas/core/arrays/base.py                | 39 ++++++++++++++++++++++-
 pandas/tests/extension/arrow/test_bool.py | 32 +++++++++++++++++++
 2 files changed, 70 insertions(+), 1 deletion(-)

diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py
index ee7732f0f276b..f0fd1b1b4a474 100644
--- a/pandas/core/arrays/base.py
+++ b/pandas/core/arrays/base.py
@@ -1171,7 +1171,7 @@ def reshape(self, *args):
             others = [n for n in shape if n != -1]
             prod = np.prod(others)
             dim, rem = divmod(self.size, prod)
-            shape = shape[:idx] + (dim,) + shape[idx+1:]
+            shape = shape[:idx] + (dim,) + shape[idx + 1:]
 
         if np.prod(shape) != self.size:
             raise ValueError("Product of shape ({shape}) must match "
@@ -1202,3 +1202,40 @@ def swapaxes(self, axis1, axis2):
         if self.ndim == 1 or axis1 == axis2:
             return type(self)(self, shape=self.shape)
         return type(self)(self, shape=self.shape[::-1])
+
+    # --------------------------------------------------------------
+
+    def __getitem__(self, key):
+        if self.ndim == 1:
+            return super().__getitem__(key)
+
+        if not isinstance(key, tuple) or len(key) != 2:
+            raise ValueError("Only limited selection is available for 2D "
+                             "{cls}. Indexer must be a 2-tuple."
+                             .format(cls=type(self).__name__))
+
+        dummy_dim = 0 if self.shape[0] == 1 else 1
+        oth = 1 - dummy_dim
+
+        key1 = key[dummy_dim]
+        key2 = key[oth]
+        squeezed = self.ravel()
+
+        if isinstance(key1, int) and key1 == 0:
+            # dummy-dim selection reduces dimension
+            return squeezed[key2]
+        if isinstance(key1, slice):
+            res1d = squeezed[key2]
+            result = res1d.reshape(1, -1)
+            if dummy_dim == 1:
+                result = result.T
+            return result
+
+        raise NotImplementedError
+
+    def take(self, indices: Sequence[int], allow_fill: bool = False,
+             fill_value: Any = None) -> ABCExtensionArray:
+        if self.ndim == 1:
+            return super().take(indices=indices, allow_fill=allow_fill,
+                                fill_value=fill_value)
+        raise NotImplementedError
diff --git a/pandas/tests/extension/arrow/test_bool.py b/pandas/tests/extension/arrow/test_bool.py
index 749dac4dcaa30..c9799b0bb1291 100644
--- a/pandas/tests/extension/arrow/test_bool.py
+++ b/pandas/tests/extension/arrow/test_bool.py
@@ -131,3 +131,35 @@ def test_reshape(self, data):
             data.reshape(-1, -1)
         with pytest.raises(ValueError, match="Product of shape"):
             data.reshape(len(data), 2)
+
+
+class Test2D:
+    def test_getitem_collike(self, data):
+        collike = data.reshape(1, -1)
+
+        result = collike[:, :]
+        assert result.shape == (1, 100)
+
+        result = collike[:, ::2]
+        assert result.shape == (1, 50)
+
+        result = collike[0, :]
+        assert result.shape == (100,)
+
+        result = collike[0, ::2]
+        assert result.shape == (50,)
+
+    def test_getitem_rowlike(self, data):
+        rowlike = data.reshape(-1, 1)
+
+        result = rowlike[:, :]
+        assert result.shape == (100, 1)
+
+        result = rowlike[::2, :]
+        assert result.shape == (50, 1)
+
+        result = rowlike[:, 0]
+        assert result.shape == (100,)
+
+        result = rowlike[::2, 0]
+        assert result.shape == (50,)

From b22f86650dd503fd0de00762cb2ab171cc2b5c2a Mon Sep 17 00:00:00 2001
From: jbrockmendel
Date: Wed, 19 Jun 2019 18:50:01 -0700
Subject: [PATCH 3/3] implement take

---
Review notes (text between the "---" line and the first "diff --git" line
is discarded by git am):

* ReshapeWrapper.take: -1 is only treated as a fill marker when
  allow_fill is True; with allow_fill=False it is passed to __getitem__
  as an ordinary positional index.
* ReshapeWrapper.take: when filling along axis=1 the fill block now has
  self.shape[0] elements (the length of a column) instead of
  self.shape[1].
* test_take_rowlike / test_take_collike were swapped relative to the
  naming used elsewhere in this file (collike is shape (1, N), rowlike is
  shape (N, 1)); the names now follow that convention.
* All changes keep line counts unchanged, so hunk headers are untouched;
  "index" blob ids are left as originally submitted.

 pandas/core/arrays/base.py                | 33 ++++++++++++++++++++---
 pandas/tests/extension/arrow/bool.py      |  5 +++-
 pandas/tests/extension/arrow/test_bool.py | 28 +++++++++++++++++++
 3 files changed, 61 insertions(+), 5 deletions(-)

diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py
index f0fd1b1b4a474..7a300cd5c7992 100644
--- a/pandas/core/arrays/base.py
+++ b/pandas/core/arrays/base.py
@@ -1221,11 +1221,16 @@ def __getitem__(self, key):
         key2 = key[oth]
         squeezed = self.ravel()
 
+        # Note: we use type(self) explicitly below because
+        # Arrow bool example otherwise returns numpy
         if isinstance(key1, int) and key1 == 0:
             # dummy-dim selection reduces dimension
-            return squeezed[key2]
+            res = squeezed[key2]
+            if np.ndim(res) == 0:
+                return res
+            return type(self)(res)
         if isinstance(key1, slice):
-            res1d = squeezed[key2]
+            res1d = type(self)(squeezed[[key2]])
             result = res1d.reshape(1, -1)
             if dummy_dim == 1:
                 result = result.T
@@ -1234,8 +1239,28 @@ def __getitem__(self, key):
         raise NotImplementedError
 
     def take(self, indices: Sequence[int], allow_fill: bool = False,
-             fill_value: Any = None) -> ABCExtensionArray:
+             fill_value: Any = None, axis: int = 0) -> ABCExtensionArray:
         if self.ndim == 1:
             return super().take(indices=indices, allow_fill=allow_fill,
                                 fill_value=fill_value)
-        raise NotImplementedError
+
+        def take_item(n):
+            if axis == 0:
+                if allow_fill and n == -1:
+                    seq = [fill_value] * self.shape[1]
+                    return type(self)._from_sequence(seq)
+                else:
+                    return self[n, :]
+            if allow_fill and n == -1:
+                seq = [fill_value] * self.shape[0]
+                return type(self)._from_sequence(seq)
+            else:
+                return self[:, n]
+
+        arrs = [take_item(n) for n in indices]
+        res1d = type(self)._concat_same_type(arrs)
+        if axis == 0:
+            result = res1d.reshape(-1, 1)
+        else:
+            result = res1d.reshape(1, -1)
+        return result
diff --git a/pandas/tests/extension/arrow/bool.py b/pandas/tests/extension/arrow/bool.py
index 83c31fd90ae16..bef60042f8e6e 100644
--- a/pandas/tests/extension/arrow/bool.py
+++ b/pandas/tests/extension/arrow/bool.py
@@ -45,7 +45,9 @@ class _ArrowBoolArray(ExtensionArray):
     def __init__(self, values):
         if isinstance(values, _ArrowBoolArray):
             values = values._data
-        if not isinstance(values, pa.ChunkedArray):
+        elif isinstance(values, np.ndarray):
+            values = pa.chunked_array([pa.array(values)])
+        elif not isinstance(values, pa.ChunkedArray):
             raise ValueError
 
         assert values.type == pa.bool_()
@@ -117,6 +119,7 @@ def copy(self, deep=False):
         else:
             return type(self)(copy.copy(self._data))
 
+    @classmethod
     def _concat_same_type(cls, to_concat):
         chunks = list(itertools.chain.from_iterable(x._data.chunks
                                                     for x in to_concat))
diff --git a/pandas/tests/extension/arrow/test_bool.py b/pandas/tests/extension/arrow/test_bool.py
index c9799b0bb1291..563ef000a3552 100644
--- a/pandas/tests/extension/arrow/test_bool.py
+++ b/pandas/tests/extension/arrow/test_bool.py
@@ -137,6 +137,8 @@ class Test2D:
     def test_getitem_collike(self, data):
         collike = data.reshape(1, -1)
 
+        assert collike[0, 0] is data[0]
+
         result = collike[:, :]
         assert result.shape == (1, 100)
 
@@ -152,6 +154,8 @@ def test_getitem_collike(self, data):
     def test_getitem_rowlike(self, data):
         rowlike = data.reshape(-1, 1)
 
+        assert rowlike[0, 0] is data[0]
+
         result = rowlike[:, :]
         assert result.shape == (100, 1)
 
@@ -163,3 +167,27 @@ def test_getitem_rowlike(self, data):
 
         result = rowlike[::2, 0]
         assert result.shape == (50,)
+
+    def test_take_collike(self, data):
+        collike = data.reshape(1, -1)
+
+        result = collike.take([0], axis=0)
+        assert result.shape == (100, 1)
+
+        result2 = collike.take([0, 5, 10], axis=1)
+        assert result2.shape == (1, 3)
+        assert result2[0, 0] == data[0]
+        assert result2[0, 1] == data[5]
+        assert result2[0, 2] == data[10]
+
+    def test_take_rowlike(self, data):
+        rowlike = data.reshape(-1, 1)
+
+        result = rowlike.take([0], axis=1)
+        assert result.shape == (1, 100)
+
+        result2 = rowlike.take([0, 5, 10], axis=0)
+        assert result2.shape == (3, 1)
+        assert result2[0, 0] == data[0]
+        assert result2[1, 0] == data[5]
+        assert result2[2, 0] == data[10]