Skip to content

Commit

Permalink
[Bug] Fix various DatetimeIndex comparison bugs (#22074)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored and jreback committed Aug 1, 2018
1 parent 57c7daa commit 8d5c51b
Show file tree
Hide file tree
Showing 7 changed files with 208 additions and 24 deletions.
5 changes: 5 additions & 0 deletions doc/source/whatsnew/v0.24.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,9 @@ Datetimelike
- Fixed bug where two :class:`DateOffset` objects with different ``normalize`` attributes could evaluate as equal (:issue:`21404`)
- Fixed bug where :meth:`Timestamp.resolution` incorrectly returned 1-microsecond ``timedelta`` instead of 1-nanosecond :class:`Timedelta` (:issue:`21336`,:issue:`21365`)
- Bug in :func:`to_datetime` that did not consistently return an :class:`Index` when ``box=True`` was specified (:issue:`21864`)
- Bug in :class:`DatetimeIndex` comparisons where string comparisons incorrectly raises ``TypeError`` (:issue:`22074`)
- Bug in :class:`DatetimeIndex` comparisons when comparing against ``timedelta64[ns]`` dtyped arrays; in some cases ``TypeError`` was incorrectly raised, in others it incorrectly failed to raise (:issue:`22074`)
- Bug in :class:`DatetimeIndex` comparisons when comparing against object-dtyped arrays (:issue:`22074`)

Timedelta
^^^^^^^^^
Expand All @@ -555,6 +558,7 @@ Timezones
- Bug in :class:`Index` with ``datetime64[ns, tz]`` dtype that did not localize integer data correctly (:issue:`20964`)
- Bug in :class:`DatetimeIndex` where constructing with an integer and tz would not localize correctly (:issue:`12619`)
- Fixed bug where :meth:`DataFrame.describe` and :meth:`Series.describe` on tz-aware datetimes did not show `first` and `last` result (:issue:`21328`)
- Bug in :class:`DatetimeIndex` comparisons failing to raise ``TypeError`` when comparing timezone-aware ``DatetimeIndex`` against ``np.datetime64`` (:issue:`22074`)

Offsets
^^^^^^^
Expand All @@ -572,6 +576,7 @@ Numeric
- Bug in :meth:`DataFrame.agg`, :meth:`DataFrame.transform` and :meth:`DataFrame.apply` where,
when supplied with a list of functions and ``axis=1`` (e.g. ``df.apply(['sum', 'mean'], axis=1)``),
a ``TypeError`` was wrongly raised. For all three methods such calculation are now done correctly. (:issue:`16679`).
- Bug in :class:`Series` comparison against datetime-like scalars and arrays (:issue:`22074`)
-

Strings
Expand Down
40 changes: 27 additions & 13 deletions pandas/core/arrays/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
from pytz import utc

from pandas._libs import tslib
from pandas._libs import lib, tslib
from pandas._libs.tslib import Timestamp, NaT, iNaT
from pandas._libs.tslibs import (
normalize_date,
Expand All @@ -18,7 +18,7 @@

from pandas.core.dtypes.common import (
_NS_DTYPE,
is_datetimelike,
is_object_dtype,
is_datetime64tz_dtype,
is_datetime64_dtype,
is_timedelta64_dtype,
Expand All @@ -29,6 +29,7 @@

import pandas.core.common as com
from pandas.core.algorithms import checked_add_with_arr
from pandas.core import ops

from pandas.tseries.frequencies import to_offset
from pandas.tseries.offsets import Tick, Day, generate_range
Expand Down Expand Up @@ -99,31 +100,40 @@ def wrapper(self, other):
meth = getattr(dtl.DatetimeLikeArrayMixin, opname)

if isinstance(other, (datetime, np.datetime64, compat.string_types)):
if isinstance(other, datetime):
if isinstance(other, (datetime, np.datetime64)):
# GH#18435 strings get a pass from tzawareness compat
self._assert_tzawareness_compat(other)

other = _to_m8(other, tz=self.tz)
try:
other = _to_m8(other, tz=self.tz)
except ValueError:
# string that cannot be parsed to Timestamp
return ops.invalid_comparison(self, other, op)

result = meth(self, other)
if isna(other):
result.fill(nat_result)
elif lib.is_scalar(other):
return ops.invalid_comparison(self, other, op)
else:
if isinstance(other, list):
# FIXME: This can break for object-dtype with mixed types
other = type(self)(other)
elif not isinstance(other, (np.ndarray, ABCIndexClass, ABCSeries)):
# Following Timestamp convention, __eq__ is all-False
# and __ne__ is all True, others raise TypeError.
if opname == '__eq__':
return np.zeros(shape=self.shape, dtype=bool)
elif opname == '__ne__':
return np.ones(shape=self.shape, dtype=bool)
raise TypeError('%s type object %s' %
(type(other), str(other)))

if is_datetimelike(other):
return ops.invalid_comparison(self, other, op)

if is_object_dtype(other):
result = op(self.astype('O'), np.array(other))
elif not (is_datetime64_dtype(other) or
is_datetime64tz_dtype(other)):
# e.g. is_timedelta64_dtype(other)
return ops.invalid_comparison(self, other, op)
else:
self._assert_tzawareness_compat(other)
result = meth(self, np.asarray(other))

result = meth(self, np.asarray(other))
result = com.values_from_object(result)

# Make sure to pass an array to result[...]; indexing with
Expand Down Expand Up @@ -152,6 +162,10 @@ class DatetimeArrayMixin(dtl.DatetimeLikeArrayMixin):
'is_year_end', 'is_leap_year']
_object_ops = ['weekday_name', 'freq', 'tz']

# dummy attribute so that datetime.__eq__(DatetimeArray) defers
# by returning NotImplemented
timetuple = None

# -----------------------------------------------------------------
# Constructors

Expand Down
31 changes: 30 additions & 1 deletion pandas/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,35 @@ def mask_cmp_op(x, y, op, allowed_types):
return result


def invalid_comparison(left, right, op):
"""
If a comparison has mismatched types and is not necessarily meaningful,
follow python3 conventions by:
- returning all-False for equality
- returning all-True for inequality
- raising TypeError otherwise
Parameters
----------
left : array-like
right : scalar, array-like
op : operator.{eq, ne, lt, le, gt}
Raises
------
TypeError : on inequality comparisons
"""
if op is operator.eq:
res_values = np.zeros(left.shape, dtype=bool)
elif op is operator.ne:
res_values = np.ones(left.shape, dtype=bool)
else:
raise TypeError("Invalid comparison between dtype={dtype} and {typ}"
.format(dtype=left.dtype, typ=type(right).__name__))
return res_values


# -----------------------------------------------------------------------------
# Functions that add arithmetic methods to objects, given arithmetic factory
# methods
Expand Down Expand Up @@ -1259,7 +1288,7 @@ def na_op(x, y):
result = _comp_method_OBJECT_ARRAY(op, x, y)

elif is_datetimelike_v_numeric(x, y):
raise TypeError("invalid type comparison")
return invalid_comparison(x, y, op)

else:

Expand Down
16 changes: 14 additions & 2 deletions pandas/tests/frame/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,20 @@ def test_comparison_invalid(self):
def check(df, df2):

for (x, y) in [(df, df2), (df2, df)]:
pytest.raises(TypeError, lambda: x == y)
pytest.raises(TypeError, lambda: x != y)
# we expect the result to match Series comparisons for
# == and !=, inequalities should raise
result = x == y
expected = DataFrame({col: x[col] == y[col]
for col in x.columns},
index=x.index, columns=x.columns)
assert_frame_equal(result, expected)

result = x != y
expected = DataFrame({col: x[col] != y[col]
for col in x.columns},
index=x.index, columns=x.columns)
assert_frame_equal(result, expected)

pytest.raises(TypeError, lambda: x >= y)
pytest.raises(TypeError, lambda: x > y)
pytest.raises(TypeError, lambda: x < y)
Expand Down
8 changes: 6 additions & 2 deletions pandas/tests/frame/test_query_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,9 +463,13 @@ def test_date_query_with_non_date(self):
df = DataFrame({'dates': date_range('1/1/2012', periods=n),
'nondate': np.arange(n)})

ops = '==', '!=', '<', '>', '<=', '>='
result = df.query('dates == nondate', parser=parser, engine=engine)
assert len(result) == 0

for op in ops:
result = df.query('dates != nondate', parser=parser, engine=engine)
assert_frame_equal(result, df)

for op in ['<', '>', '<=', '>=']:
with pytest.raises(TypeError):
df.query('dates %s nondate' % op, parser=parser, engine=engine)

Expand Down
121 changes: 117 additions & 4 deletions pandas/tests/indexes/datetimes/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,20 @@ def test_comparison_tzawareness_compat(self, op):
with pytest.raises(TypeError):
op(ts, dz)

@pytest.mark.parametrize('op', [operator.eq, operator.ne,
operator.gt, operator.ge,
operator.lt, operator.le])
@pytest.mark.parametrize('other', [datetime(2016, 1, 1),
Timestamp('2016-01-01'),
np.datetime64('2016-01-01')])
def test_scalar_comparison_tzawareness(self, op, other, tz_aware_fixture):
tz = tz_aware_fixture
dti = pd.date_range('2016-01-01', periods=2, tz=tz)
with pytest.raises(TypeError):
op(dti, other)
with pytest.raises(TypeError):
op(other, dti)

@pytest.mark.parametrize('op', [operator.eq, operator.ne,
operator.gt, operator.ge,
operator.lt, operator.le])
Expand All @@ -290,12 +304,60 @@ def test_nat_comparison_tzawareness(self, op):
result = op(dti.tz_localize('US/Pacific'), pd.NaT)
tm.assert_numpy_array_equal(result, expected)

def test_dti_cmp_int_raises(self):
rng = date_range('1/1/2000', periods=10)
def test_dti_cmp_str(self, tz_naive_fixture):
# GH#22074
# regardless of tz, we expect these comparisons are valid
tz = tz_naive_fixture
rng = date_range('1/1/2000', periods=10, tz=tz)
other = '1/1/2000'

result = rng == other
expected = np.array([True] + [False] * 9)
tm.assert_numpy_array_equal(result, expected)

result = rng != other
expected = np.array([False] + [True] * 9)
tm.assert_numpy_array_equal(result, expected)

result = rng < other
expected = np.array([False] * 10)
tm.assert_numpy_array_equal(result, expected)

result = rng <= other
expected = np.array([True] + [False] * 9)
tm.assert_numpy_array_equal(result, expected)

result = rng > other
expected = np.array([False] + [True] * 9)
tm.assert_numpy_array_equal(result, expected)

result = rng >= other
expected = np.array([True] * 10)
tm.assert_numpy_array_equal(result, expected)

@pytest.mark.parametrize('other', ['foo', 99, 4.0,
object(), timedelta(days=2)])
def test_dti_cmp_scalar_invalid(self, other, tz_naive_fixture):
# GH#22074
tz = tz_naive_fixture
rng = date_range('1/1/2000', periods=10, tz=tz)

result = rng == other
expected = np.array([False] * 10)
tm.assert_numpy_array_equal(result, expected)

result = rng != other
expected = np.array([True] * 10)
tm.assert_numpy_array_equal(result, expected)

# raise TypeError for now
with pytest.raises(TypeError):
rng < rng[3].value
rng < other
with pytest.raises(TypeError):
rng <= other
with pytest.raises(TypeError):
rng > other
with pytest.raises(TypeError):
rng >= other

def test_dti_cmp_list(self):
rng = date_range('1/1/2000', periods=10)
Expand All @@ -304,6 +366,57 @@ def test_dti_cmp_list(self):
expected = rng == rng
tm.assert_numpy_array_equal(result, expected)

@pytest.mark.parametrize('other', [
pd.timedelta_range('1D', periods=10),
pd.timedelta_range('1D', periods=10).to_series(),
pd.timedelta_range('1D', periods=10).asi8.view('m8[ns]')
], ids=lambda x: type(x).__name__)
def test_dti_cmp_tdi_tzawareness(self, other):
# GH#22074
# reversion test that we _don't_ call _assert_tzawareness_compat
# when comparing against TimedeltaIndex
dti = date_range('2000-01-01', periods=10, tz='Asia/Tokyo')

result = dti == other
expected = np.array([False] * 10)
tm.assert_numpy_array_equal(result, expected)

result = dti != other
expected = np.array([True] * 10)
tm.assert_numpy_array_equal(result, expected)

with pytest.raises(TypeError):
dti < other
with pytest.raises(TypeError):
dti <= other
with pytest.raises(TypeError):
dti > other
with pytest.raises(TypeError):
dti >= other

def test_dti_cmp_object_dtype(self):
# GH#22074
dti = date_range('2000-01-01', periods=10, tz='Asia/Tokyo')

other = dti.astype('O')

result = dti == other
expected = np.array([True] * 10)
tm.assert_numpy_array_equal(result, expected)

other = dti.tz_localize(None)
with pytest.raises(TypeError):
# tzawareness failure
dti != other

other = np.array(list(dti[:5]) + [Timedelta(days=1)] * 5)
result = dti == other
expected = np.array([True] * 5 + [False] * 5)
tm.assert_numpy_array_equal(result, expected)

with pytest.raises(TypeError):
dti >= other


class TestDatetimeIndexArithmetic(object):

Expand Down
11 changes: 9 additions & 2 deletions pandas/tests/series/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,15 @@ def test_comparison_invalid(self):
s2 = Series(date_range('20010101', periods=5))

for (x, y) in [(s, s2), (s2, s)]:
pytest.raises(TypeError, lambda: x == y)
pytest.raises(TypeError, lambda: x != y)

result = x == y
expected = Series([False] * 5)
assert_series_equal(result, expected)

result = x != y
expected = Series([True] * 5)
assert_series_equal(result, expected)

pytest.raises(TypeError, lambda: x >= y)
pytest.raises(TypeError, lambda: x > y)
pytest.raises(TypeError, lambda: x < y)
Expand Down

0 comments on commit 8d5c51b

Please sign in to comment.