Skip to content

Commit

Permalink
Implement _most_ of the EA interface for DTA/TDA (pandas-dev#23643)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored and Pingviinituutti committed Feb 28, 2019
1 parent f6fa535 commit fef01e1
Show file tree
Hide file tree
Showing 8 changed files with 270 additions and 77 deletions.
63 changes: 62 additions & 1 deletion pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from pandas.core.dtypes.missing import isna

import pandas.core.common as com
from pandas.core.algorithms import checked_add_with_arr
from pandas.core.algorithms import checked_add_with_arr, take, unique1d

from .base import ExtensionOpsMixin
from pandas.util._decorators import deprecate_kwarg
Expand Down Expand Up @@ -196,6 +196,67 @@ def astype(self, dtype, copy=True):
return self._box_values(self.asi8)
return super(DatetimeLikeArrayMixin, self).astype(dtype, copy)

# ------------------------------------------------------------------
# ExtensionArray Interface
# TODO:
# * _from_sequence
# * argsort / _values_for_argsort
# * _reduce

def unique(self):
result = unique1d(self.asi8)
return type(self)(result, dtype=self.dtype)

def _validate_fill_value(self, fill_value):
"""
If a fill_value is passed to `take` convert it to an i8 representation,
raising ValueError if this is not possible.
Parameters
----------
fill_value : object
Returns
-------
fill_value : np.int64
Raises
------
ValueError
"""
raise AbstractMethodError(self)

def take(self, indices, allow_fill=False, fill_value=None):
if allow_fill:
fill_value = self._validate_fill_value(fill_value)

new_values = take(self.asi8,
indices,
allow_fill=allow_fill,
fill_value=fill_value)

return type(self)(new_values, dtype=self.dtype)

@classmethod
def _concat_same_type(cls, to_concat):
dtypes = {x.dtype for x in to_concat}
assert len(dtypes) == 1
dtype = list(dtypes)[0]

values = np.concatenate([x.asi8 for x in to_concat])
return cls(values, dtype=dtype)

def copy(self, deep=False):
values = self.asi8.copy()
return type(self)(values, dtype=self.dtype, freq=self.freq)

def _values_for_factorize(self):
return self.asi8, iNaT

@classmethod
def _from_factorized(cls, values, original):
return cls(values, dtype=original.dtype)

# ------------------------------------------------------------------
# Null Handling

Expand Down
28 changes: 23 additions & 5 deletions pandas/core/arrays/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
conversion, fields, timezones,
resolution as libresolution)

from pandas.util._decorators import cache_readonly
from pandas.util._decorators import cache_readonly, Appender
from pandas.errors import PerformanceWarning
from pandas import compat

Expand All @@ -21,8 +21,7 @@
is_object_dtype,
is_int64_dtype,
is_datetime64tz_dtype,
is_datetime64_dtype,
ensure_int64)
is_datetime64_dtype)
from pandas.core.dtypes.dtypes import DatetimeTZDtype
from pandas.core.dtypes.missing import isna
from pandas.core.dtypes.generic import ABCIndexClass, ABCSeries
Expand Down Expand Up @@ -294,7 +293,7 @@ def _generate_range(cls, start, end, periods, freq, tz=None,

if tz is not None and index.tz is None:
arr = conversion.tz_localize_to_utc(
ensure_int64(index.values),
index.asi8,
tz, ambiguous=ambiguous)

index = cls(arr)
Expand All @@ -317,7 +316,7 @@ def _generate_range(cls, start, end, periods, freq, tz=None,
if not right_closed and len(index) and index[-1] == end:
index = index[:-1]

return cls._simple_new(index.values, freq=freq, tz=tz)
return cls._simple_new(index.asi8, freq=freq, tz=tz)

# -----------------------------------------------------------------
# Descriptive Properties
Expand Down Expand Up @@ -419,6 +418,25 @@ def __iter__(self):
for v in converted:
yield v

# ----------------------------------------------------------------
# ExtensionArray Interface

@property
def _ndarray_values(self):
return self._data

@Appender(dtl.DatetimeLikeArrayMixin._validate_fill_value.__doc__)
def _validate_fill_value(self, fill_value):
if isna(fill_value):
fill_value = iNaT
elif isinstance(fill_value, (datetime, np.datetime64)):
self._assert_tzawareness_compat(fill_value)
fill_value = Timestamp(fill_value).value
else:
raise ValueError("'fill_value' should be a Timestamp. "
"Got '{got}'.".format(got=fill_value))
return fill_value

# -----------------------------------------------------------------
# Comparison Methods

Expand Down
56 changes: 14 additions & 42 deletions pandas/core/arrays/period.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,14 +216,6 @@ def _from_sequence(cls, scalars, dtype=None, copy=False):
ordinals = libperiod.extract_ordinals(periods, freq)
return cls(ordinals, freq=freq)

def _values_for_factorize(self):
return self.asi8, iNaT

@classmethod
def _from_factorized(cls, values, original):
# type: (Sequence[Optional[Period]], PeriodArray) -> PeriodArray
return cls(values, freq=original.freq)

@classmethod
def _from_datetime64(cls, data, freq, tz=None):
"""Construct a PeriodArray from a datetime64 array
Expand Down Expand Up @@ -262,14 +254,6 @@ def _generate_range(cls, start, end, periods, freq, fields):

return subarr, freq

@classmethod
def _concat_same_type(cls, to_concat):
freq = {x.freq for x in to_concat}
assert len(freq) == 1
freq = list(freq)[0]
values = np.concatenate([x._data for x in to_concat])
return cls(values, freq=freq)

# --------------------------------------------------------------------
# Data / Attributes

Expand Down Expand Up @@ -415,29 +399,20 @@ def __setitem__(
raise TypeError(msg)
self._data[key] = value

def take(self, indices, allow_fill=False, fill_value=None):
if allow_fill:
if isna(fill_value):
fill_value = iNaT
elif isinstance(fill_value, Period):
if self.freq != fill_value.freq:
msg = DIFFERENT_FREQ_INDEX.format(
self.freq.freqstr,
fill_value.freqstr
)
raise IncompatibleFrequency(msg)

fill_value = fill_value.ordinal
else:
msg = "'fill_value' should be a Period. Got '{}'."
raise ValueError(msg.format(fill_value))

new_values = algos.take(self._data,
indices,
allow_fill=allow_fill,
fill_value=fill_value)

return type(self)(new_values, self.freq)
@Appender(dtl.DatetimeLikeArrayMixin._validate_fill_value.__doc__)
def _validate_fill_value(self, fill_value):
if isna(fill_value):
fill_value = iNaT
elif isinstance(fill_value, Period):
if fill_value.freq != self.freq:
msg = DIFFERENT_FREQ_INDEX.format(self.freq.freqstr,
fill_value.freqstr)
raise IncompatibleFrequency(msg)
fill_value = fill_value.ordinal
else:
raise ValueError("'fill_value' should be a Period. "
"Got '{got}'.".format(got=fill_value))
return fill_value

def fillna(self, value=None, method=None, limit=None):
# TODO(#20300)
Expand Down Expand Up @@ -474,9 +449,6 @@ def fillna(self, value=None, method=None, limit=None):
new_values = self.copy()
return new_values

def copy(self, deep=False):
return type(self)(self._data.copy(), freq=self.freq)

def value_counts(self, dropna=False):
from pandas import Series, PeriodIndex

Expand Down
14 changes: 13 additions & 1 deletion pandas/core/arrays/timedeltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pandas._libs.tslibs.fields import get_timedelta_field
from pandas._libs.tslibs.timedeltas import (
array_to_timedelta64, parse_timedelta_unit)
from pandas.util._decorators import Appender

from pandas import compat

Expand Down Expand Up @@ -139,7 +140,7 @@ def _simple_new(cls, values, freq=None, dtype=_TD_DTYPE):
result._freq = freq
return result

def __new__(cls, values, freq=None):
def __new__(cls, values, freq=None, dtype=_TD_DTYPE):

freq, freq_infer = dtl.maybe_infer_freq(freq)

Expand Down Expand Up @@ -193,6 +194,17 @@ def _generate_range(cls, start, end, periods, freq, closed=None):
# ----------------------------------------------------------------
# Array-Like / EA-Interface Methods

@Appender(dtl.DatetimeLikeArrayMixin._validate_fill_value.__doc__)
def _validate_fill_value(self, fill_value):
if isna(fill_value):
fill_value = iNaT
elif isinstance(fill_value, (timedelta, np.timedelta64, Tick)):
fill_value = Timedelta(fill_value).value
else:
raise ValueError("'fill_value' should be a Timedelta. "
"Got '{got}'.".format(got=fill_value))
return fill_value

# ----------------------------------------------------------------
# Arithmetic Methods

Expand Down
8 changes: 1 addition & 7 deletions pandas/core/dtypes/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,13 +476,7 @@ def _concat_datetimetz(to_concat, name=None):
all inputs must be DatetimeIndex
it is used in DatetimeIndex.append also
"""
# do not pass tz to set because tzlocal cannot be hashed
if len({str(x.dtype) for x in to_concat}) != 1:
raise ValueError('to_concat must have the same tz')
tz = to_concat[0].tz
# no need to localize because internal repr will not be changed
new_values = np.concatenate([x.asi8 for x in to_concat])
return to_concat[0]._simple_new(new_values, tz=tz, name=name)
return to_concat[0]._concat_same_dtype(to_concat, name=name)


def _concat_index_same_dtype(indexes, klass=None):
Expand Down
23 changes: 15 additions & 8 deletions pandas/core/indexes/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
is_datetime_or_timedelta_dtype, is_dtype_equal, is_float, is_float_dtype,
is_integer, is_integer_dtype, is_list_like, is_object_dtype,
is_period_dtype, is_scalar, is_string_dtype)
import pandas.core.dtypes.concat as _concat
from pandas.core.dtypes.generic import ABCIndex, ABCIndexClass, ABCSeries
from pandas.core.dtypes.missing import isna

Expand Down Expand Up @@ -215,6 +214,11 @@ def ceil(self, freq, ambiguous='raise', nonexistent='raise'):
class DatetimeIndexOpsMixin(DatetimeLikeArrayMixin):
""" common ops mixin to support a unified interface datetimelike Index """

# override DatetimeLikeArrayMixin method
copy = Index.copy
unique = Index.unique
take = Index.take

# DatetimeLikeArrayMixin assumes subclasses are mutable, so these are
# properties there. They can be made into cache_readonly for Index
# subclasses bc they are immutable
Expand Down Expand Up @@ -685,17 +689,21 @@ def _concat_same_dtype(self, to_concat, name):
"""
attribs = self._get_attributes_dict()
attribs['name'] = name
# do not pass tz to set because tzlocal cannot be hashed
if len({str(x.dtype) for x in to_concat}) != 1:
raise ValueError('to_concat must have the same tz')

if not is_period_dtype(self):
# reset freq
attribs['freq'] = None

if getattr(self, 'tz', None) is not None:
return _concat._concat_datetimetz(to_concat, name)
# TODO(DatetimeArray)
# - remove the .asi8 here
# - remove the _maybe_box_as_values
# - combine with the `else` block
new_data = self._concat_same_type(to_concat).asi8
else:
new_data = np.concatenate([c.asi8 for c in to_concat])
new_data = type(self._values)._concat_same_type(to_concat)

new_data = self._maybe_box_as_values(new_data, **attribs)
return self._simple_new(new_data, **attribs)

def _maybe_box_as_values(self, values, **attribs):
Expand All @@ -704,7 +712,6 @@ def _maybe_box_as_values(self, values, **attribs):
# but others are not. When everyone is an ExtensionArray, this can
# be removed. Currently used in
# - sort_values
# - _concat_same_dtype
return values

def astype(self, dtype, copy=True):
Expand Down Expand Up @@ -761,7 +768,7 @@ def _ensure_datetimelike_to_i8(other, to_utc=False):
try:
return np.array(other, copy=False).view('i8')
except TypeError:
# period array cannot be coerces to int
# period array cannot be coerced to int
other = Index(other)
return other.asi8

Expand Down
17 changes: 7 additions & 10 deletions pandas/core/indexes/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,16 +551,13 @@ def snap(self, freq='S'):
# TODO: what about self.name? if so, use shallow_copy?

def unique(self, level=None):
# Override here since IndexOpsMixin.unique uses self._values.unique
# For DatetimeIndex with TZ, that's a DatetimeIndex -> recursion error
# So we extract the tz-naive DatetimeIndex, unique that, and wrap the
# result with out TZ.
if self.tz is not None:
naive = type(self)(self._ndarray_values, copy=False)
else:
naive = self
result = super(DatetimeIndex, naive).unique(level=level)
return self._shallow_copy(result.values)
if level is not None:
self._validate_index_level(level)

# TODO(DatetimeArray): change dispatch once inheritance is removed
# call DatetimeArray method
result = DatetimeArray.unique(self)
return self._shallow_copy(result._data)

def union(self, other):
"""
Expand Down
Loading

0 comments on commit fef01e1

Please sign in to comment.