Skip to content

Commit

Permalink
REF: IntervalArray comparisons (#37124)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored Nov 3, 2020
1 parent 28a0f66 commit 337bf20
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 34 deletions.
79 changes: 62 additions & 17 deletions pandas/core/arrays/interval.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import operator
from operator import le, lt
import textwrap
from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
Expand All @@ -12,6 +13,7 @@
IntervalMixin,
intervals_to_interval_bounds,
)
from pandas._libs.missing import NA
from pandas._typing import ArrayLike, Dtype
from pandas.compat.numpy import function as nv
from pandas.util._decorators import Appender
Expand Down Expand Up @@ -48,7 +50,7 @@
from pandas.core.construction import array, extract_array
from pandas.core.indexers import check_array_indexer
from pandas.core.indexes.base import ensure_index
from pandas.core.ops import unpack_zerodim_and_defer
from pandas.core.ops import invalid_comparison, unpack_zerodim_and_defer

if TYPE_CHECKING:
from pandas import Index
Expand Down Expand Up @@ -520,16 +522,15 @@ def __setitem__(self, key, value):
self._left[key] = value_left
self._right[key] = value_right

@unpack_zerodim_and_defer("__eq__")
def __eq__(self, other):
def _cmp_method(self, other, op):
# ensure pandas array for list-like and eliminate non-interval scalars
if is_list_like(other):
if len(self) != len(other):
raise ValueError("Lengths must match to compare")
other = array(other)
elif not isinstance(other, Interval):
# non-interval scalar -> no matches
return np.zeros(len(self), dtype=bool)
return invalid_comparison(self, other, op)

# determine the dtype of the elements we want to compare
if isinstance(other, Interval):
Expand All @@ -543,35 +544,79 @@ def __eq__(self, other):
# extract intervals if we have interval categories with matching closed
if is_interval_dtype(other_dtype):
if self.closed != other.categories.closed:
return np.zeros(len(self), dtype=bool)
return invalid_comparison(self, other, op)

other = other.categories.take(
other.codes, allow_fill=True, fill_value=other.categories._na_value
)

# interval-like -> need same closed and matching endpoints
if is_interval_dtype(other_dtype):
if self.closed != other.closed:
return np.zeros(len(self), dtype=bool)
return (self._left == other.left) & (self._right == other.right)
return invalid_comparison(self, other, op)
elif not isinstance(other, Interval):
other = type(self)(other)

if op is operator.eq:
return (self._left == other.left) & (self._right == other.right)
elif op is operator.ne:
return (self._left != other.left) | (self._right != other.right)
elif op is operator.gt:
return (self._left > other.left) | (
(self._left == other.left) & (self._right > other.right)
)
elif op is operator.ge:
return (self == other) | (self > other)
elif op is operator.lt:
return (self._left < other.left) | (
(self._left == other.left) & (self._right < other.right)
)
else:
# operator.le
return (self == other) | (self < other)

# non-interval/non-object dtype -> no matches
if not is_object_dtype(other_dtype):
return np.zeros(len(self), dtype=bool)
return invalid_comparison(self, other, op)

# object dtype -> iteratively check for intervals
result = np.zeros(len(self), dtype=bool)
for i, obj in enumerate(other):
# need object to be an Interval with same closed and endpoints
if (
isinstance(obj, Interval)
and self.closed == obj.closed
and self._left[i] == obj.left
and self._right[i] == obj.right
):
result[i] = True

try:
result[i] = op(self[i], obj)
except TypeError:
if obj is NA:
# obj is pd.NA (comparison with np.nan returns NA): record the
# element as a mismatch, so only the != operator yields True
# github.com/pandas-dev/pandas/pull/37124#discussion_r509095092
result[i] = op is operator.ne
else:
raise
return result

# Elementwise ``==``: forwards ``other`` to ``_cmp_method`` with ``operator.eq``.
@unpack_zerodim_and_defer("__eq__")
def __eq__(self, other):
return self._cmp_method(other, operator.eq)

# Elementwise ``!=``: forwards ``other`` to ``_cmp_method`` with ``operator.ne``.
@unpack_zerodim_and_defer("__ne__")
def __ne__(self, other):
return self._cmp_method(other, operator.ne)

# Elementwise ``>``: forwards ``other`` to ``_cmp_method`` with ``operator.gt``.
@unpack_zerodim_and_defer("__gt__")
def __gt__(self, other):
return self._cmp_method(other, operator.gt)

# Elementwise ``>=``: forwards ``other`` to ``_cmp_method`` with ``operator.ge``.
@unpack_zerodim_and_defer("__ge__")
def __ge__(self, other):
return self._cmp_method(other, operator.ge)

# Elementwise ``<``: forwards ``other`` to ``_cmp_method`` with ``operator.lt``.
@unpack_zerodim_and_defer("__lt__")
def __lt__(self, other):
return self._cmp_method(other, operator.lt)

# Elementwise ``<=``: forwards ``other`` to ``_cmp_method`` with ``operator.le``.
@unpack_zerodim_and_defer("__le__")
def __le__(self, other):
return self._cmp_method(other, operator.le)

def fillna(self, value=None, method=None, limit=None):
"""
Fill NA/NaN values using the specified method.
Expand Down
13 changes: 0 additions & 13 deletions pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,19 +1074,6 @@ def _is_all_dates(self) -> bool:

# TODO: arithmetic operations

# GH#30817 until IntervalArray implements inequalities, get them from Index
def __lt__(self, other):
return Index.__lt__(self, other)

def __le__(self, other):
return Index.__le__(self, other)

def __gt__(self, other):
return Index.__gt__(self, other)

def __ge__(self, other):
return Index.__ge__(self, other)


def _is_valid_endpoint(endpoint) -> bool:
"""
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/extension/base/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ def test_repeat(self, data, repeats, as_series, use_numpy):
@pytest.mark.parametrize(
"repeats, kwargs, error, msg",
[
(2, dict(axis=1), ValueError, "'axis"),
(2, dict(axis=1), ValueError, "axis"),
(-1, dict(), ValueError, "negative"),
([1, 2], dict(), ValueError, "shape"),
(2, dict(foo="bar"), TypeError, "'foo'"),
Expand Down
8 changes: 5 additions & 3 deletions pandas/tests/indexes/interval/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,9 +579,11 @@ def test_comparison(self):
actual = self.index == self.index.left
tm.assert_numpy_array_equal(actual, np.array([False, False]))

msg = (
"not supported between instances of 'int' and "
"'pandas._libs.interval.Interval'"
msg = "|".join(
[
"not supported between instances of 'int' and '.*.Interval'",
r"Invalid comparison between dtype=interval\[int64\] and ",
]
)
with pytest.raises(TypeError, match=msg):
self.index > 0
Expand Down

0 comments on commit 337bf20

Please sign in to comment.