From fef01e1438b37dd454ef710decd31c536da4c1f0 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Wed, 14 Nov 2018 07:23:02 -0800 Subject: [PATCH] Implement _most_ of the EA interface for DTA/TDA (#23643) --- pandas/core/arrays/datetimelike.py | 63 ++++++++++- pandas/core/arrays/datetimes.py | 28 ++++- pandas/core/arrays/period.py | 56 +++------ pandas/core/arrays/timedeltas.py | 14 ++- pandas/core/dtypes/concat.py | 8 +- pandas/core/indexes/datetimelike.py | 23 ++-- pandas/core/indexes/datetimes.py | 17 ++- pandas/tests/arrays/test_datetimelike.py | 138 ++++++++++++++++++++++- 8 files changed, 270 insertions(+), 77 deletions(-) diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index daf2dcccd284ba..094c9c3df0bed1 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -39,7 +39,7 @@ from pandas.core.dtypes.missing import isna import pandas.core.common as com -from pandas.core.algorithms import checked_add_with_arr +from pandas.core.algorithms import checked_add_with_arr, take, unique1d from .base import ExtensionOpsMixin from pandas.util._decorators import deprecate_kwarg @@ -196,6 +196,67 @@ def astype(self, dtype, copy=True): return self._box_values(self.asi8) return super(DatetimeLikeArrayMixin, self).astype(dtype, copy) + # ------------------------------------------------------------------ + # ExtensionArray Interface + # TODO: + # * _from_sequence + # * argsort / _values_for_argsort + # * _reduce + + def unique(self): + result = unique1d(self.asi8) + return type(self)(result, dtype=self.dtype) + + def _validate_fill_value(self, fill_value): + """ + If a fill_value is passed to `take` convert it to an i8 representation, + raising ValueError if this is not possible. + + Parameters + ---------- + fill_value : object + + Returns + ------- + fill_value : np.int64 + + Raises + ------ + ValueError + """ + raise AbstractMethodError(self) + + def take(self, indices, allow_fill=False, fill_value=None): + if allow_fill: + fill_value = self._validate_fill_value(fill_value) + + new_values = take(self.asi8, + indices, + allow_fill=allow_fill, + fill_value=fill_value) + + return type(self)(new_values, dtype=self.dtype) + + @classmethod + def _concat_same_type(cls, to_concat): + dtypes = {x.dtype for x in to_concat} + assert len(dtypes) == 1 + dtype = list(dtypes)[0] + + values = np.concatenate([x.asi8 for x in to_concat]) + return cls(values, dtype=dtype) + + def copy(self, deep=False): + values = self.asi8.copy() + return type(self)(values, dtype=self.dtype, freq=self.freq) + + def _values_for_factorize(self): + return self.asi8, iNaT + + @classmethod + def _from_factorized(cls, values, original): + return cls(values, dtype=original.dtype) + # ------------------------------------------------------------------ # Null Handling diff --git a/pandas/core/arrays/datetimes.py b/pandas/core/arrays/datetimes.py index 926228f267049a..7b4e362ac9fa0e 100644 --- a/pandas/core/arrays/datetimes.py +++ b/pandas/core/arrays/datetimes.py @@ -12,7 +12,7 @@ conversion, fields, timezones, resolution as libresolution) -from pandas.util._decorators import cache_readonly +from pandas.util._decorators import cache_readonly, Appender from pandas.errors import PerformanceWarning from pandas import compat @@ -21,8 +21,7 @@ is_object_dtype, is_int64_dtype, is_datetime64tz_dtype, - is_datetime64_dtype, - ensure_int64) + is_datetime64_dtype) from pandas.core.dtypes.dtypes import DatetimeTZDtype from pandas.core.dtypes.missing import isna from pandas.core.dtypes.generic import ABCIndexClass, ABCSeries @@ -294,7 +293,7 @@ def _generate_range(cls, start, end, periods, freq, tz=None, if tz is not None and index.tz is None: arr = conversion.tz_localize_to_utc( - ensure_int64(index.values), + index.asi8, tz, ambiguous=ambiguous) index = cls(arr) @@ -317,7 +316,7 @@ def _generate_range(cls, start, end, periods, freq, tz=None, if not right_closed and len(index) and index[-1] == end: index = index[:-1] - return cls._simple_new(index.values, freq=freq, tz=tz) + return cls._simple_new(index.asi8, freq=freq, tz=tz) # ----------------------------------------------------------------- # Descriptive Properties @@ -419,6 +418,25 @@ def __iter__(self): for v in converted: yield v + # ---------------------------------------------------------------- + # ExtensionArray Interface + + @property + def _ndarray_values(self): + return self._data + + @Appender(dtl.DatetimeLikeArrayMixin._validate_fill_value.__doc__) + def _validate_fill_value(self, fill_value): + if isna(fill_value): + fill_value = iNaT + elif isinstance(fill_value, (datetime, np.datetime64)): + self._assert_tzawareness_compat(fill_value) + fill_value = Timestamp(fill_value).value + else: + raise ValueError("'fill_value' should be a Timestamp. " + "Got '{got}'.".format(got=fill_value)) + return fill_value + # ----------------------------------------------------------------- # Comparison Methods diff --git a/pandas/core/arrays/period.py b/pandas/core/arrays/period.py index faba404faeb239..e46b00da6161ee 100644 --- a/pandas/core/arrays/period.py +++ b/pandas/core/arrays/period.py @@ -216,14 +216,6 @@ def _from_sequence(cls, scalars, dtype=None, copy=False): ordinals = libperiod.extract_ordinals(periods, freq) return cls(ordinals, freq=freq) - def _values_for_factorize(self): - return self.asi8, iNaT - - @classmethod - def _from_factorized(cls, values, original): - # type: (Sequence[Optional[Period]], PeriodArray) -> PeriodArray - return cls(values, freq=original.freq) - @classmethod def _from_datetime64(cls, data, freq, tz=None): """Construct a PeriodArray from a datetime64 array @@ -262,14 +254,6 @@ def _generate_range(cls, start, end, periods, freq, fields): return subarr, freq - @classmethod - def _concat_same_type(cls, to_concat): - freq = {x.freq for x in to_concat} - assert len(freq) == 1 - freq = list(freq)[0] - values = np.concatenate([x._data for x in to_concat]) - return cls(values, freq=freq) - # -------------------------------------------------------------------- # Data / Attributes @@ -415,29 +399,20 @@ def __setitem__( raise TypeError(msg) self._data[key] = value - def take(self, indices, allow_fill=False, fill_value=None): - if allow_fill: - if isna(fill_value): - fill_value = iNaT - elif isinstance(fill_value, Period): - if self.freq != fill_value.freq: - msg = DIFFERENT_FREQ_INDEX.format( - self.freq.freqstr, - fill_value.freqstr - ) - raise IncompatibleFrequency(msg) - - fill_value = fill_value.ordinal - else: - msg = "'fill_value' should be a Period. Got '{}'." - raise ValueError(msg.format(fill_value)) - - new_values = algos.take(self._data, - indices, - allow_fill=allow_fill, - fill_value=fill_value) - - return type(self)(new_values, self.freq) + @Appender(dtl.DatetimeLikeArrayMixin._validate_fill_value.__doc__) + def _validate_fill_value(self, fill_value): + if isna(fill_value): + fill_value = iNaT + elif isinstance(fill_value, Period): + if fill_value.freq != self.freq: + msg = DIFFERENT_FREQ_INDEX.format(self.freq.freqstr, + fill_value.freqstr) + raise IncompatibleFrequency(msg) + fill_value = fill_value.ordinal + else: + raise ValueError("'fill_value' should be a Period. " + "Got '{got}'.".format(got=fill_value)) + return fill_value def fillna(self, value=None, method=None, limit=None): # TODO(#20300) @@ -474,9 +449,6 @@ def fillna(self, value=None, method=None, limit=None): new_values = self.copy() return new_values - def copy(self, deep=False): - return type(self)(self._data.copy(), freq=self.freq) - def value_counts(self, dropna=False): from pandas import Series, PeriodIndex diff --git a/pandas/core/arrays/timedeltas.py b/pandas/core/arrays/timedeltas.py index 9dbdd6ff8b5620..ad564ca34930fd 100644 --- a/pandas/core/arrays/timedeltas.py +++ b/pandas/core/arrays/timedeltas.py @@ -9,6 +9,7 @@ from pandas._libs.tslibs.fields import get_timedelta_field from pandas._libs.tslibs.timedeltas import ( array_to_timedelta64, parse_timedelta_unit) +from pandas.util._decorators import Appender from pandas import compat @@ -139,7 +140,7 @@ def _simple_new(cls, values, freq=None, dtype=_TD_DTYPE): result._freq = freq return result - def __new__(cls, values, freq=None): + def __new__(cls, values, freq=None, dtype=_TD_DTYPE): freq, freq_infer = dtl.maybe_infer_freq(freq) @@ -193,6 +194,17 @@ def _generate_range(cls, start, end, periods, freq, closed=None): # ---------------------------------------------------------------- # Array-Like / EA-Interface Methods + @Appender(dtl.DatetimeLikeArrayMixin._validate_fill_value.__doc__) + def _validate_fill_value(self, fill_value): + if isna(fill_value): + fill_value = iNaT + elif isinstance(fill_value, (timedelta, np.timedelta64, Tick)): + fill_value = Timedelta(fill_value).value + else: + raise ValueError("'fill_value' should be a Timedelta. " + "Got '{got}'.".format(got=fill_value)) + return fill_value + # ---------------------------------------------------------------- # Arithmetic Methods diff --git a/pandas/core/dtypes/concat.py b/pandas/core/dtypes/concat.py index bb4ab823069eeb..ebfb41825ae0a7 100644 --- a/pandas/core/dtypes/concat.py +++ b/pandas/core/dtypes/concat.py @@ -476,13 +476,7 @@ def _concat_datetimetz(to_concat, name=None): all inputs must be DatetimeIndex it is used in DatetimeIndex.append also """ - # do not pass tz to set because tzlocal cannot be hashed - if len({str(x.dtype) for x in to_concat}) != 1: - raise ValueError('to_concat must have the same tz') - tz = to_concat[0].tz - # no need to localize because internal repr will not be changed - new_values = np.concatenate([x.asi8 for x in to_concat]) - return to_concat[0]._simple_new(new_values, tz=tz, name=name) + return to_concat[0]._concat_same_dtype(to_concat, name=name) def _concat_index_same_dtype(indexes, klass=None): diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py index 3f9a60f6d5c51c..39bc7f4b85de27 100644 --- a/pandas/core/indexes/datetimelike.py +++ b/pandas/core/indexes/datetimelike.py @@ -18,7 +18,6 @@ is_datetime_or_timedelta_dtype, is_dtype_equal, is_float, is_float_dtype, is_integer, is_integer_dtype, is_list_like, is_object_dtype, is_period_dtype, is_scalar, is_string_dtype) -import pandas.core.dtypes.concat as _concat from pandas.core.dtypes.generic import ABCIndex, ABCIndexClass, ABCSeries from pandas.core.dtypes.missing import isna @@ -215,6 +214,11 @@ def ceil(self, freq, ambiguous='raise', nonexistent='raise'): class DatetimeIndexOpsMixin(DatetimeLikeArrayMixin): """ common ops mixin to support a unified interface datetimelike Index """ + # override DatetimeLikeArrayMixin method + copy = Index.copy + unique = Index.unique + take = Index.take + # DatetimeLikeArrayMixin assumes subclasses are mutable, so these are # properties there. They can be made into cache_readonly for Index # subclasses bc they are immutable @@ -685,17 +689,21 @@ def _concat_same_dtype(self, to_concat, name): """ attribs = self._get_attributes_dict() attribs['name'] = name + # do not pass tz to set because tzlocal cannot be hashed + if len({str(x.dtype) for x in to_concat}) != 1: + raise ValueError('to_concat must have the same tz') if not is_period_dtype(self): # reset freq attribs['freq'] = None - - if getattr(self, 'tz', None) is not None: - return _concat._concat_datetimetz(to_concat, name) + # TODO(DatetimeArray) + # - remove the .asi8 here + # - remove the _maybe_box_as_values + # - combine with the `else` block + new_data = self._concat_same_type(to_concat).asi8 else: - new_data = np.concatenate([c.asi8 for c in to_concat]) + new_data = type(self._values)._concat_same_type(to_concat) - new_data = self._maybe_box_as_values(new_data, **attribs) return self._simple_new(new_data, **attribs) def _maybe_box_as_values(self, values, **attribs): @@ -704,7 +712,6 @@ def _maybe_box_as_values(self, values, **attribs): # but others are not. When everyone is an ExtensionArray, this can # be removed. Currently used in # - sort_values - # - _concat_same_dtype return values def astype(self, dtype, copy=True): @@ -761,7 +768,7 @@ def _ensure_datetimelike_to_i8(other, to_utc=False): try: return np.array(other, copy=False).view('i8') except TypeError: - # period array cannot be coerces to int + # period array cannot be coerced to int other = Index(other) return other.asi8 diff --git a/pandas/core/indexes/datetimes.py b/pandas/core/indexes/datetimes.py index b754b2705d034f..23446a57e7789f 100644 --- a/pandas/core/indexes/datetimes.py +++ b/pandas/core/indexes/datetimes.py @@ -551,16 +551,13 @@ def snap(self, freq='S'): # TODO: what about self.name? if so, use shallow_copy? def unique(self, level=None): - # Override here since IndexOpsMixin.unique uses self._values.unique - # For DatetimeIndex with TZ, that's a DatetimeIndex -> recursion error - # So we extract the tz-naive DatetimeIndex, unique that, and wrap the - # result with out TZ. - if self.tz is not None: - naive = type(self)(self._ndarray_values, copy=False) - else: - naive = self - result = super(DatetimeIndex, naive).unique(level=level) - return self._shallow_copy(result.values) + if level is not None: + self._validate_index_level(level) + + # TODO(DatetimeArray): change dispatch once inheritance is removed + # call DatetimeArray method + result = DatetimeArray.unique(self) + return self._shallow_copy(result._data) def union(self, other): """ diff --git a/pandas/tests/arrays/test_datetimelike.py b/pandas/tests/arrays/test_datetimelike.py index bb4022c9cac9af..a1242e2481fed5 100644 --- a/pandas/tests/arrays/test_datetimelike.py +++ b/pandas/tests/arrays/test_datetimelike.py @@ -56,7 +56,68 @@ def timedelta_index(request): return pd.TimedeltaIndex(['1 Day', '3 Hours', 'NaT']) -class TestDatetimeArray(object): +class SharedTests(object): + index_cls = None + + def test_take(self): + data = np.arange(100, dtype='i8') + np.random.shuffle(data) + + idx = self.index_cls._simple_new(data, freq='D') + arr = self.array_cls(idx) + + takers = [1, 4, 94] + result = arr.take(takers) + expected = idx.take(takers) + + tm.assert_index_equal(self.index_cls(result), expected) + + takers = np.array([1, 4, 94]) + result = arr.take(takers) + expected = idx.take(takers) + + tm.assert_index_equal(self.index_cls(result), expected) + + def test_take_fill(self): + data = np.arange(10, dtype='i8') + + idx = self.index_cls._simple_new(data, freq='D') + arr = self.array_cls(idx) + + result = arr.take([-1, 1], allow_fill=True, fill_value=None) + assert result[0] is pd.NaT + + result = arr.take([-1, 1], allow_fill=True, fill_value=np.nan) + assert result[0] is pd.NaT + + result = arr.take([-1, 1], allow_fill=True, fill_value=pd.NaT) + assert result[0] is pd.NaT + + with pytest.raises(ValueError): + arr.take([0, 1], allow_fill=True, fill_value=2) + + with pytest.raises(ValueError): + arr.take([0, 1], allow_fill=True, fill_value=2.0) + + with pytest.raises(ValueError): + arr.take([0, 1], allow_fill=True, + fill_value=pd.Timestamp.now().time) + + def test_concat_same_type(self): + data = np.arange(10, dtype='i8') + + idx = self.index_cls._simple_new(data, freq='D').insert(0, pd.NaT) + arr = self.array_cls(idx) + + result = arr._concat_same_type([arr[:-1], arr[1:], arr]) + expected = idx._concat_same_dtype([idx[:-1], idx[1:], idx], None) + + tm.assert_index_equal(self.index_cls(result), expected) + + +class TestDatetimeArray(SharedTests): + index_cls = pd.DatetimeIndex + array_cls = DatetimeArray def test_array_object_dtype(self, tz_naive_fixture): # GH#23524 @@ -175,8 +236,60 @@ def test_int_properties(self, datetime_index, propname): tm.assert_numpy_array_equal(result, expected) + def test_take_fill_valid(self, datetime_index, tz_naive_fixture): + dti = datetime_index.tz_localize(tz_naive_fixture) + arr = DatetimeArray(dti) + + now = pd.Timestamp.now().tz_localize(dti.tz) + result = arr.take([-1, 1], allow_fill=True, fill_value=now) + assert result[0] == now + + with pytest.raises(ValueError): + # fill_value Timedelta invalid + arr.take([-1, 1], allow_fill=True, fill_value=now - now) + + with pytest.raises(ValueError): + # fill_value Period invalid + arr.take([-1, 1], allow_fill=True, fill_value=pd.Period('2014Q1')) + + tz = None if dti.tz is not None else 'US/Eastern' + now = pd.Timestamp.now().tz_localize(tz) + with pytest.raises(TypeError): + # Timestamp with mismatched tz-awareness + arr.take([-1, 1], allow_fill=True, fill_value=now) + + def test_concat_same_type_invalid(self, datetime_index): + # different timezones + dti = datetime_index + arr = DatetimeArray(dti) + + if arr.tz is None: + other = arr.tz_localize('UTC') + else: + other = arr.tz_localize(None) + + with pytest.raises(AssertionError): + arr._concat_same_type([arr, other]) + + def test_concat_same_type_different_freq(self): + # we *can* concatentate DTI with different freqs. + a = DatetimeArray(pd.date_range('2000', periods=2, freq='D', + tz='US/Central')) + b = DatetimeArray(pd.date_range('2000', periods=2, freq='H', + tz='US/Central')) + result = DatetimeArray._concat_same_type([a, b]) + expected = DatetimeArray(pd.to_datetime([ + '2000-01-01 00:00:00', '2000-01-02 00:00:00', + '2000-01-01 00:00:00', '2000-01-01 01:00:00', + ]).tz_localize("US/Central")) + + tm.assert_datetime_array_equal(result, expected) + + +class TestTimedeltaArray(SharedTests): + index_cls = pd.TimedeltaIndex + array_cls = TimedeltaArray -class TestTimedeltaArray(object): def test_from_tdi(self): tdi = pd.TimedeltaIndex(['1 Day', '3 Hours']) arr = TimedeltaArray(tdi) @@ -223,8 +336,27 @@ def test_int_properties(self, timedelta_index, propname): tm.assert_numpy_array_equal(result, expected) + def test_take_fill_valid(self, timedelta_index): + tdi = timedelta_index + arr = TimedeltaArray(tdi) + + td1 = pd.Timedelta(days=1) + result = arr.take([-1, 1], allow_fill=True, fill_value=td1) + assert result[0] == td1 + + now = pd.Timestamp.now() + with pytest.raises(ValueError): + # fill_value Timestamp invalid + arr.take([0, 1], allow_fill=True, fill_value=now) + + with pytest.raises(ValueError): + # fill_value Period invalid + arr.take([0, 1], allow_fill=True, fill_value=now.to_period('D')) + -class TestPeriodArray(object): +class TestPeriodArray(SharedTests): + index_cls = pd.PeriodIndex + array_cls = PeriodArray def test_from_pi(self, period_index): pi = period_index