Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TYP/REF: define comparison methods non-dynamically #36930

Merged
merged 9 commits into from
Oct 7, 2020
43 changes: 43 additions & 0 deletions pandas/core/arraylike.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""
Methods that can be shared by many array-like classes or subclasses:
Series
Index
ExtensionArray
"""
import operator

from pandas.errors import AbstractMethodError

from pandas.core.ops.common import unpack_zerodim_and_defer


class OpsMixin:
# -------------------------------------------------------------
# Comparisons

def _cmp_method(self, other, op):
raise AbstractMethodError(self)

@unpack_zerodim_and_defer("__eq__")
def __eq__(self, other):
return self._cmp_method(other, operator.eq)

@unpack_zerodim_and_defer("__ne__")
def __ne__(self, other):
return self._cmp_method(other, operator.ne)

@unpack_zerodim_and_defer("__lt__")
def __lt__(self, other):
return self._cmp_method(other, operator.lt)

@unpack_zerodim_and_defer("__le__")
def __le__(self, other):
return self._cmp_method(other, operator.le)

@unpack_zerodim_and_defer("__gt__")
def __gt__(self, other):
return self._cmp_method(other, operator.gt)

@unpack_zerodim_and_defer("__ge__")
def __ge__(self, other):
return self._cmp_method(other, operator.ge)
79 changes: 31 additions & 48 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
round_nsint64,
)
from pandas._typing import DatetimeLikeScalar, DtypeObj
from pandas.compat import set_function_name
from pandas.compat.numpy import function as nv
from pandas.errors import AbstractMethodError, NullFrequencyError, PerformanceWarning
from pandas.util._decorators import Appender, Substitution, cache_readonly
Expand All @@ -51,8 +50,8 @@

from pandas.core import nanops, ops
from pandas.core.algorithms import checked_add_with_arr, unique1d, value_counts
from pandas.core.arraylike import OpsMixin
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
from pandas.core.arrays.base import ExtensionOpsMixin
import pandas.core.common as com
from pandas.core.construction import array, extract_array
from pandas.core.indexers import check_array_indexer, check_setitem_lengths
Expand All @@ -73,46 +72,6 @@ class InvalidComparison(Exception):
pass


def _datetimelike_array_cmp(cls, op):
"""
Wrap comparison operations to convert Timestamp/Timedelta/Period-like to
boxed scalars/arrays.
"""
opname = f"__{op.__name__}__"
nat_result = opname == "__ne__"

@unpack_zerodim_and_defer(opname)
def wrapper(self, other):
if self.ndim > 1 and getattr(other, "shape", None) == self.shape:
# TODO: handle 2D-like listlikes
return op(self.ravel(), other.ravel()).reshape(self.shape)

try:
other = self._validate_comparison_value(other, opname)
except InvalidComparison:
return invalid_comparison(self, other, op)

dtype = getattr(other, "dtype", None)
if is_object_dtype(dtype):
# We have to use comp_method_OBJECT_ARRAY instead of numpy
# comparison otherwise it would fail to raise when
# comparing tz-aware and tz-naive
with np.errstate(all="ignore"):
result = ops.comp_method_OBJECT_ARRAY(op, self.astype(object), other)
return result

other_i8 = self._unbox(other)
result = op(self.asi8, other_i8)

o_mask = isna(other)
if self._hasnans | np.any(o_mask):
result[self._isnan | o_mask] = nat_result

return result

return set_function_name(wrapper, opname, cls)


class AttributesMixin:
_data: np.ndarray

Expand Down Expand Up @@ -426,9 +385,7 @@ def _with_freq(self, freq):
DatetimeLikeArrayT = TypeVar("DatetimeLikeArrayT", bound="DatetimeLikeArrayMixin")


class DatetimeLikeArrayMixin(
ExtensionOpsMixin, AttributesMixin, NDArrayBackedExtensionArray
):
class DatetimeLikeArrayMixin(OpsMixin, AttributesMixin, NDArrayBackedExtensionArray):
"""
Shared Base/Mixin class for DatetimeArray, TimedeltaArray, PeriodArray

Expand Down Expand Up @@ -1093,7 +1050,35 @@ def _is_unique(self):

# ------------------------------------------------------------------
# Arithmetic Methods
_create_comparison_method = classmethod(_datetimelike_array_cmp)

def _cmp_method(self, other, op):
if self.ndim > 1 and getattr(other, "shape", None) == self.shape:
# TODO: handle 2D-like listlikes
return op(self.ravel(), other.ravel()).reshape(self.shape)

try:
other = self._validate_comparison_value(other, f"__{op.__name__}__")
except InvalidComparison:
return invalid_comparison(self, other, op)

dtype = getattr(other, "dtype", None)
if is_object_dtype(dtype):
# We have to use comp_method_OBJECT_ARRAY instead of numpy
# comparison otherwise it would fail to raise when
# comparing tz-aware and tz-naive
with np.errstate(all="ignore"):
result = ops.comp_method_OBJECT_ARRAY(op, self.astype(object), other)
return result

other_i8 = self._unbox(other)
result = op(self.asi8, other_i8)

o_mask = isna(other)
if self._hasnans | np.any(o_mask):
nat_result = op is operator.ne
result[self._isnan | o_mask] = nat_result

return result

# pow is invalid for all three subclasses; TimedeltaArray will override
# the multiplication and division ops
Expand Down Expand Up @@ -1560,8 +1545,6 @@ def mean(self, skipna=True):
return self._box_func(result)


DatetimeLikeArrayMixin._add_comparison_ops()

# -------------------------------------------------------------------
# Shared Constructor Helpers

Expand Down
3 changes: 2 additions & 1 deletion pandas/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from pandas.core import algorithms, common as com
from pandas.core.accessor import DirNamesMixin
from pandas.core.algorithms import duplicated, unique1d, value_counts
from pandas.core.arraylike import OpsMixin
from pandas.core.arrays import ExtensionArray
from pandas.core.construction import create_series_with_explicit_dtype
import pandas.core.nanops as nanops
Expand Down Expand Up @@ -587,7 +588,7 @@ def _is_builtin_func(self, arg):
return self._builtin_table.get(arg, arg)


class IndexOpsMixin:
class IndexOpsMixin(OpsMixin):
"""
Common ops mixin to support a unified interface / docs for Series / Index
"""
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1764,7 +1764,7 @@ def _drop_labels_or_levels(self, keys, axis: int = 0):
# ----------------------------------------------------------------------
# Iteration

def __hash__(self):
def __hash__(self) -> int:
raise TypeError(
f"{repr(type(self).__name__)} objects are mutable, "
f"thus they cannot be hashed"
Expand Down
75 changes: 30 additions & 45 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,41 +119,6 @@
str_t = str


def _make_comparison_op(op, cls):
def cmp_method(self, other):
if isinstance(other, (np.ndarray, Index, ABCSeries, ExtensionArray)):
if other.ndim > 0 and len(self) != len(other):
raise ValueError("Lengths must match to compare")

if is_object_dtype(self.dtype) and isinstance(other, ABCCategorical):
left = type(other)(self._values, dtype=other.dtype)
return op(left, other)
elif is_object_dtype(self.dtype) and isinstance(other, ExtensionArray):
# e.g. PeriodArray
with np.errstate(all="ignore"):
result = op(self._values, other)

elif is_object_dtype(self.dtype) and not isinstance(self, ABCMultiIndex):
# don't pass MultiIndex
with np.errstate(all="ignore"):
result = ops.comp_method_OBJECT_ARRAY(op, self._values, other)

elif is_interval_dtype(self.dtype):
with np.errstate(all="ignore"):
result = op(self._values, np.asarray(other))

else:
with np.errstate(all="ignore"):
result = ops.comparison_op(self._values, np.asarray(other), op)

if is_bool_dtype(result):
return result
return ops.invalid_comparison(self, other, op)

name = f"__{op.__name__}__"
return set_function_name(cmp_method, name, cls)


def _make_arithmetic_op(op, cls):
def index_arithmetic_method(self, other):
if isinstance(other, (ABCSeries, ABCDataFrame, ABCTimedeltaIndex)):
Expand Down Expand Up @@ -5395,17 +5360,38 @@ def drop(self, labels, errors: str_t = "raise"):
# --------------------------------------------------------------------
# Generated Arithmetic, Comparison, and Unary Methods

@classmethod
def _add_comparison_methods(cls):
def _cmp_method(self, other, op):
"""
Add in comparison methods.
Wrapper used to dispatch comparison operations.
"""
cls.__eq__ = _make_comparison_op(operator.eq, cls)
cls.__ne__ = _make_comparison_op(operator.ne, cls)
cls.__lt__ = _make_comparison_op(operator.lt, cls)
cls.__gt__ = _make_comparison_op(operator.gt, cls)
cls.__le__ = _make_comparison_op(operator.le, cls)
cls.__ge__ = _make_comparison_op(operator.ge, cls)
if isinstance(other, (np.ndarray, Index, ABCSeries, ExtensionArray)):
if other.ndim > 0 and len(self) != len(other):
raise ValueError("Lengths must match to compare")

if is_object_dtype(self.dtype) and isinstance(other, ABCCategorical):
left = type(other)(self._values, dtype=other.dtype)
return op(left, other)
elif is_object_dtype(self.dtype) and isinstance(other, ExtensionArray):
# e.g. PeriodArray
with np.errstate(all="ignore"):
result = op(self._values, other)

elif is_object_dtype(self.dtype) and not isinstance(self, ABCMultiIndex):
# don't pass MultiIndex
with np.errstate(all="ignore"):
result = ops.comp_method_OBJECT_ARRAY(op, self._values, other)

elif is_interval_dtype(self.dtype):
with np.errstate(all="ignore"):
result = op(self._values, np.asarray(other))

else:
with np.errstate(all="ignore"):
result = ops.comparison_op(self._values, np.asarray(other), op)

if is_bool_dtype(result):
return result
return ops.invalid_comparison(self, other, op)

@classmethod
def _add_numeric_methods_binary(cls):
Expand Down Expand Up @@ -5589,7 +5575,6 @@ def shape(self):

Index._add_numeric_methods()
Index._add_logical_methods()
Index._add_comparison_methods()


def ensure_index_from_sequences(sequences, names=None):
Expand Down
31 changes: 2 additions & 29 deletions pandas/core/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@

from pandas.core import algorithms
from pandas.core.construction import extract_array
from pandas.core.ops.array_ops import (
from pandas.core.ops.array_ops import ( # noqa:F401
arithmetic_op,
comp_method_OBJECT_ARRAY,
comparison_op,
get_array_op,
logical_op,
)
from pandas.core.ops.array_ops import comp_method_OBJECT_ARRAY # noqa:F401
from pandas.core.ops.common import unpack_zerodim_and_defer
from pandas.core.ops.docstrings import (
_arith_doc_FRAME,
Expand Down Expand Up @@ -324,33 +324,6 @@ def wrapper(left, right):
return wrapper


def comp_method_SERIES(cls, op, special):
"""
Wrapper function for Series arithmetic operations, to avoid
code duplication.
"""
assert special # non-special uses flex_method_SERIES
op_name = _get_op_name(op, special)

@unpack_zerodim_and_defer(op_name)
def wrapper(self, other):

res_name = get_op_result_name(self, other)

if isinstance(other, ABCSeries) and not self._indexed_same(other):
raise ValueError("Can only compare identically-labeled Series objects")

lvalues = extract_array(self, extract_numpy=True)
rvalues = extract_array(other, extract_numpy=True)

res_values = comparison_op(lvalues, rvalues, op)

return self._construct_result(res_values, name=res_name)

wrapper.__name__ = op_name
return wrapper


def bool_method_SERIES(cls, op, special):
"""
Wrapper function for Series arithmetic operations, to avoid
Expand Down
23 changes: 12 additions & 11 deletions pandas/core/ops/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def _get_method_wrappers(cls):
arith_method_SERIES,
bool_method_SERIES,
comp_method_FRAME,
comp_method_SERIES,
flex_comp_method_FRAME,
flex_method_SERIES,
)
Expand All @@ -58,7 +57,7 @@ def _get_method_wrappers(cls):
arith_flex = flex_method_SERIES
comp_flex = flex_method_SERIES
arith_special = arith_method_SERIES
comp_special = comp_method_SERIES
comp_special = None
bool_special = bool_method_SERIES
elif issubclass(cls, ABCDataFrame):
arith_flex = arith_method_FRAME
Expand Down Expand Up @@ -189,16 +188,18 @@ def _create_methods(cls, arith_method, comp_method, bool_method, special):
new_methods["divmod"] = arith_method(cls, divmod, special)
new_methods["rdivmod"] = arith_method(cls, rdivmod, special)

new_methods.update(
dict(
eq=comp_method(cls, operator.eq, special),
ne=comp_method(cls, operator.ne, special),
lt=comp_method(cls, operator.lt, special),
gt=comp_method(cls, operator.gt, special),
le=comp_method(cls, operator.le, special),
ge=comp_method(cls, operator.ge, special),
if comp_method is not None:
# Series already has this pinned
new_methods.update(
dict(
eq=comp_method(cls, operator.eq, special),
ne=comp_method(cls, operator.ne, special),
lt=comp_method(cls, operator.lt, special),
gt=comp_method(cls, operator.gt, special),
le=comp_method(cls, operator.le, special),
ge=comp_method(cls, operator.ge, special),
)
)
)

if bool_method:
new_methods.update(
Expand Down
Loading