Skip to content

Commit

Permalink
TYP/REF: define comparison methods non-dynamically (pandas-dev#36930)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored and Kevin D Smith committed Nov 2, 2020
1 parent ffa1e2d commit 3077c3a
Show file tree
Hide file tree
Showing 8 changed files with 138 additions and 135 deletions.
43 changes: 43 additions & 0 deletions pandas/core/arraylike.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""
Methods that can be shared by many array-like classes or subclasses:
Series
Index
ExtensionArray
"""
import operator

from pandas.errors import AbstractMethodError

from pandas.core.ops.common import unpack_zerodim_and_defer


class OpsMixin:
    """
    Mixin that defines the six rich-comparison dunder methods in one place.

    Every comparison operator forwards to ``self._cmp_method(other, op)``,
    passing the matching function from the ``operator`` module.  Subclasses
    only have to override ``_cmp_method``; the base implementation raises
    ``AbstractMethodError``.

    Each dunder is wrapped with ``unpack_zerodim_and_defer`` (imported from
    ``pandas.core.ops.common``), which runs before the method body.

    Notes
    -----
    Because this class defines ``__eq__`` without defining ``__hash__``,
    Python sets ``OpsMixin.__hash__`` to ``None``.  A subclass that must be
    hashable has to define ``__hash__`` itself.
    """

    # -------------------------------------------------------------
    # Comparisons

    def _cmp_method(self, other, op):
        """
        Perform the comparison ``op(self, other)``; must be overridden.

        Parameters
        ----------
        other : object
            Right-hand operand of the comparison.
        op : callable
            One of ``operator.eq``, ``ne``, ``lt``, ``le``, ``gt``, ``ge``.

        Raises
        ------
        AbstractMethodError
            Always, in this base implementation.
        """
        raise AbstractMethodError(self)

    @unpack_zerodim_and_defer("__eq__")
    def __eq__(self, other):
        return self._cmp_method(other, operator.eq)

    @unpack_zerodim_and_defer("__ne__")
    def __ne__(self, other):
        return self._cmp_method(other, operator.ne)

    @unpack_zerodim_and_defer("__lt__")
    def __lt__(self, other):
        return self._cmp_method(other, operator.lt)

    @unpack_zerodim_and_defer("__le__")
    def __le__(self, other):
        return self._cmp_method(other, operator.le)

    @unpack_zerodim_and_defer("__gt__")
    def __gt__(self, other):
        return self._cmp_method(other, operator.gt)

    @unpack_zerodim_and_defer("__ge__")
    def __ge__(self, other):
        return self._cmp_method(other, operator.ge)
79 changes: 31 additions & 48 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
round_nsint64,
)
from pandas._typing import DatetimeLikeScalar, DtypeObj
from pandas.compat import set_function_name
from pandas.compat.numpy import function as nv
from pandas.errors import AbstractMethodError, NullFrequencyError, PerformanceWarning
from pandas.util._decorators import Appender, Substitution, cache_readonly
Expand All @@ -51,8 +50,8 @@

from pandas.core import nanops, ops
from pandas.core.algorithms import checked_add_with_arr, unique1d, value_counts
from pandas.core.arraylike import OpsMixin
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
from pandas.core.arrays.base import ExtensionOpsMixin
import pandas.core.common as com
from pandas.core.construction import array, extract_array
from pandas.core.indexers import check_array_indexer, check_setitem_lengths
Expand All @@ -73,46 +72,6 @@ class InvalidComparison(Exception):
pass


def _datetimelike_array_cmp(cls, op):
    """
    Wrap comparison operations to convert Timestamp/Timedelta/Period-like to
    boxed scalars/arrays.

    Parameters
    ----------
    cls : type
        Only forwarded to ``set_function_name`` together with the method name.
    op : callable
        Comparison function.  Its ``__name__`` is used to build the dunder
        name (``f"__{op.__name__}__"``).

    Returns
    -------
    object
        Whatever ``set_function_name(wrapper, opname, cls)`` returns, where
        ``wrapper(self, other)`` is the comparison method built here.
    """
    opname = f"__{op.__name__}__"
    # Value written into masked positions below: True only for "__ne__".
    nat_result = opname == "__ne__"

    @unpack_zerodim_and_defer(opname)
    def wrapper(self, other):
        if self.ndim > 1 and getattr(other, "shape", None) == self.shape:
            # TODO: handle 2D-like listlikes
            # Same-shape operands: compare the raveled values, then reshape
            # the result back to self.shape.
            return op(self.ravel(), other.ravel()).reshape(self.shape)

        try:
            other = self._validate_comparison_value(other, opname)
        except InvalidComparison:
            # Validation rejected ``other``; delegate to invalid_comparison.
            return invalid_comparison(self, other, op)

        dtype = getattr(other, "dtype", None)
        if is_object_dtype(dtype):
            # We have to use comp_method_OBJECT_ARRAY instead of numpy
            #  comparison otherwise it would fail to raise when
            #  comparing tz-aware and tz-naive
            with np.errstate(all="ignore"):
                result = ops.comp_method_OBJECT_ARRAY(op, self.astype(object), other)
            return result

        # Non-object path: compare self.asi8 against the unboxed ``other``.
        other_i8 = self._unbox(other)
        result = op(self.asi8, other_i8)

        # Overwrite positions that are missing on either side.
        o_mask = isna(other)
        if self._hasnans | np.any(o_mask):
            result[self._isnan | o_mask] = nat_result

        return result

    return set_function_name(wrapper, opname, cls)


class AttributesMixin:
_data: np.ndarray

Expand Down Expand Up @@ -426,9 +385,7 @@ def _with_freq(self, freq):
DatetimeLikeArrayT = TypeVar("DatetimeLikeArrayT", bound="DatetimeLikeArrayMixin")


class DatetimeLikeArrayMixin(
ExtensionOpsMixin, AttributesMixin, NDArrayBackedExtensionArray
):
class DatetimeLikeArrayMixin(OpsMixin, AttributesMixin, NDArrayBackedExtensionArray):
"""
Shared Base/Mixin class for DatetimeArray, TimedeltaArray, PeriodArray
Expand Down Expand Up @@ -1093,7 +1050,35 @@ def _is_unique(self):

# ------------------------------------------------------------------
# Arithmetic Methods
_create_comparison_method = classmethod(_datetimelike_array_cmp)

def _cmp_method(self, other, op):
    """
    Compare this array with ``other`` using ``op``.

    Parameters
    ----------
    other : object
        Right-hand operand.  It is passed through
        ``self._validate_comparison_value`` before being compared.
    op : callable
        Comparison function from the ``operator`` module; its ``__name__``
        is used to build the dunder name handed to the validator.

    Returns
    -------
    object
        The result of ``op`` on the underlying values, or the result of
        ``invalid_comparison`` when validation raises ``InvalidComparison``.
        On the non-object path, positions missing on either side are set to
        ``True`` for ``operator.ne`` and ``False`` otherwise.
    """
    if self.ndim > 1 and getattr(other, "shape", None) == self.shape:
        # TODO: handle 2D-like listlikes
        # Same-shape operands: compare the raveled values, then reshape.
        return op(self.ravel(), other.ravel()).reshape(self.shape)

    try:
        other = self._validate_comparison_value(other, f"__{op.__name__}__")
    except InvalidComparison:
        return invalid_comparison(self, other, op)

    dtype = getattr(other, "dtype", None)
    if is_object_dtype(dtype):
        # We have to use comp_method_OBJECT_ARRAY instead of numpy
        #  comparison otherwise it would fail to raise when
        #  comparing tz-aware and tz-naive
        with np.errstate(all="ignore"):
            result = ops.comp_method_OBJECT_ARRAY(op, self.astype(object), other)
        return result

    other_i8 = self._unbox(other)
    result = op(self.asi8, other_i8)

    o_mask = isna(other)
    # ``or`` short-circuits: when self already has missing values the
    # reduction over ``o_mask`` is skipped.  The mask itself must stay a
    # bitwise ``|`` because it is used for indexing.
    if self._hasnans or np.any(o_mask):
        nat_result = op is operator.ne
        result[self._isnan | o_mask] = nat_result

    return result

# pow is invalid for all three subclasses; TimedeltaArray will override
# the multiplication and division ops
Expand Down Expand Up @@ -1582,8 +1567,6 @@ def median(self, axis: Optional[int] = None, skipna: bool = True, *args, **kwarg
return self._from_backing_data(result.astype("i8"))


DatetimeLikeArrayMixin._add_comparison_ops()

# -------------------------------------------------------------------
# Shared Constructor Helpers

Expand Down
3 changes: 2 additions & 1 deletion pandas/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from pandas.core import algorithms, common as com
from pandas.core.accessor import DirNamesMixin
from pandas.core.algorithms import duplicated, unique1d, value_counts
from pandas.core.arraylike import OpsMixin
from pandas.core.arrays import ExtensionArray
from pandas.core.construction import create_series_with_explicit_dtype
import pandas.core.nanops as nanops
Expand Down Expand Up @@ -587,7 +588,7 @@ def _is_builtin_func(self, arg):
return self._builtin_table.get(arg, arg)


class IndexOpsMixin:
class IndexOpsMixin(OpsMixin):
"""
Common ops mixin to support a unified interface / docs for Series / Index
"""
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1764,7 +1764,7 @@ def _drop_labels_or_levels(self, keys, axis: int = 0):
# ----------------------------------------------------------------------
# Iteration

def __hash__(self):
def __hash__(self) -> int:
raise TypeError(
f"{repr(type(self).__name__)} objects are mutable, "
f"thus they cannot be hashed"
Expand Down
75 changes: 30 additions & 45 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,41 +121,6 @@
str_t = str


def _make_comparison_op(op, cls):
    """
    Build a comparison method named ``__<op.__name__>__``.

    Parameters
    ----------
    op : callable
        Comparison function; its ``__name__`` supplies the method name.
    cls : type
        Only forwarded to ``set_function_name``.

    Returns
    -------
    object
        Whatever ``set_function_name(cmp_method, name, cls)`` returns.
    """

    def cmp_method(self, other):
        # Array-like operands with at least one dimension must match in length.
        if isinstance(other, (np.ndarray, Index, ABCSeries, ExtensionArray)):
            if other.ndim > 0 and len(self) != len(other):
                raise ValueError("Lengths must match to compare")

        if is_object_dtype(self.dtype) and isinstance(other, ABCCategorical):
            # Rebuild the left side as type(other) with other's dtype, then
            # let that type perform the comparison.
            left = type(other)(self._values, dtype=other.dtype)
            return op(left, other)
        elif is_object_dtype(self.dtype) and isinstance(other, ExtensionArray):
            # e.g. PeriodArray
            with np.errstate(all="ignore"):
                result = op(self._values, other)

        elif is_object_dtype(self.dtype) and not isinstance(self, ABCMultiIndex):
            # don't pass MultiIndex
            with np.errstate(all="ignore"):
                result = ops.comp_method_OBJECT_ARRAY(op, self._values, other)

        elif is_interval_dtype(self.dtype):
            with np.errstate(all="ignore"):
                result = op(self._values, np.asarray(other))

        else:
            with np.errstate(all="ignore"):
                result = ops.comparison_op(self._values, np.asarray(other), op)

        # Only a boolean result is returned as-is; anything else is routed
        # through ops.invalid_comparison.
        if is_bool_dtype(result):
            return result
        return ops.invalid_comparison(self, other, op)

    name = f"__{op.__name__}__"
    return set_function_name(cmp_method, name, cls)


def _make_arithmetic_op(op, cls):
def index_arithmetic_method(self, other):
if isinstance(other, (ABCSeries, ABCDataFrame, ABCTimedeltaIndex)):
Expand Down Expand Up @@ -5400,17 +5365,38 @@ def drop(self, labels, errors: str_t = "raise"):
# --------------------------------------------------------------------
# Generated Arithmetic, Comparison, and Unary Methods

@classmethod
def _add_comparison_methods(cls):
def _cmp_method(self, other, op):
"""
Add in comparison methods.
Wrapper used to dispatch comparison operations.
"""
cls.__eq__ = _make_comparison_op(operator.eq, cls)
cls.__ne__ = _make_comparison_op(operator.ne, cls)
cls.__lt__ = _make_comparison_op(operator.lt, cls)
cls.__gt__ = _make_comparison_op(operator.gt, cls)
cls.__le__ = _make_comparison_op(operator.le, cls)
cls.__ge__ = _make_comparison_op(operator.ge, cls)
if isinstance(other, (np.ndarray, Index, ABCSeries, ExtensionArray)):
if other.ndim > 0 and len(self) != len(other):
raise ValueError("Lengths must match to compare")

if is_object_dtype(self.dtype) and isinstance(other, ABCCategorical):
left = type(other)(self._values, dtype=other.dtype)
return op(left, other)
elif is_object_dtype(self.dtype) and isinstance(other, ExtensionArray):
# e.g. PeriodArray
with np.errstate(all="ignore"):
result = op(self._values, other)

elif is_object_dtype(self.dtype) and not isinstance(self, ABCMultiIndex):
# don't pass MultiIndex
with np.errstate(all="ignore"):
result = ops.comp_method_OBJECT_ARRAY(op, self._values, other)

elif is_interval_dtype(self.dtype):
with np.errstate(all="ignore"):
result = op(self._values, np.asarray(other))

else:
with np.errstate(all="ignore"):
result = ops.comparison_op(self._values, np.asarray(other), op)

if is_bool_dtype(result):
return result
return ops.invalid_comparison(self, other, op)

@classmethod
def _add_numeric_methods_binary(cls):
Expand Down Expand Up @@ -5594,7 +5580,6 @@ def shape(self):

Index._add_numeric_methods()
Index._add_logical_methods()
Index._add_comparison_methods()


def ensure_index_from_sequences(sequences, names=None):
Expand Down
31 changes: 2 additions & 29 deletions pandas/core/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@

from pandas.core import algorithms
from pandas.core.construction import extract_array
from pandas.core.ops.array_ops import (
from pandas.core.ops.array_ops import ( # noqa:F401
arithmetic_op,
comp_method_OBJECT_ARRAY,
comparison_op,
get_array_op,
logical_op,
)
from pandas.core.ops.array_ops import comp_method_OBJECT_ARRAY # noqa:F401
from pandas.core.ops.common import unpack_zerodim_and_defer
from pandas.core.ops.docstrings import (
_arith_doc_FRAME,
Expand Down Expand Up @@ -323,33 +323,6 @@ def wrapper(left, right):
return wrapper


def comp_method_SERIES(cls, op, special):
    """
    Wrapper function for Series comparison operations, to avoid
    code duplication.

    Parameters
    ----------
    cls : type
        Not used in the body of this function.
    op : callable
        Comparison function handed to ``comparison_op``.
    special : bool
        Must be truthy; it is asserted and passed to ``_get_op_name``.

    Returns
    -------
    callable
        ``wrapper(self, other)``, with ``__name__`` set to the op name.
    """
    assert special  # non-special uses flex_method_SERIES
    op_name = _get_op_name(op, special)

    @unpack_zerodim_and_defer(op_name)
    def wrapper(self, other):

        res_name = get_op_result_name(self, other)

        # A Series operand must be indexed the same as self.
        if isinstance(other, ABCSeries) and not self._indexed_same(other):
            raise ValueError("Can only compare identically-labeled Series objects")

        # Compare the extracted arrays, then wrap the result via
        # self._construct_result.
        lvalues = extract_array(self, extract_numpy=True)
        rvalues = extract_array(other, extract_numpy=True)

        res_values = comparison_op(lvalues, rvalues, op)

        return self._construct_result(res_values, name=res_name)

    wrapper.__name__ = op_name
    return wrapper


def bool_method_SERIES(cls, op, special):
"""
Wrapper function for Series arithmetic operations, to avoid
Expand Down
23 changes: 12 additions & 11 deletions pandas/core/ops/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def _get_method_wrappers(cls):
arith_method_SERIES,
bool_method_SERIES,
comp_method_FRAME,
comp_method_SERIES,
flex_comp_method_FRAME,
flex_method_SERIES,
)
Expand All @@ -58,7 +57,7 @@ def _get_method_wrappers(cls):
arith_flex = flex_method_SERIES
comp_flex = flex_method_SERIES
arith_special = arith_method_SERIES
comp_special = comp_method_SERIES
comp_special = None
bool_special = bool_method_SERIES
elif issubclass(cls, ABCDataFrame):
arith_flex = arith_method_FRAME
Expand Down Expand Up @@ -189,16 +188,18 @@ def _create_methods(cls, arith_method, comp_method, bool_method, special):
new_methods["divmod"] = arith_method(cls, divmod, special)
new_methods["rdivmod"] = arith_method(cls, rdivmod, special)

new_methods.update(
dict(
eq=comp_method(cls, operator.eq, special),
ne=comp_method(cls, operator.ne, special),
lt=comp_method(cls, operator.lt, special),
gt=comp_method(cls, operator.gt, special),
le=comp_method(cls, operator.le, special),
ge=comp_method(cls, operator.ge, special),
if comp_method is not None:
# Series already has this pinned
new_methods.update(
dict(
eq=comp_method(cls, operator.eq, special),
ne=comp_method(cls, operator.ne, special),
lt=comp_method(cls, operator.lt, special),
gt=comp_method(cls, operator.gt, special),
le=comp_method(cls, operator.le, special),
ge=comp_method(cls, operator.ge, special),
)
)
)

if bool_method:
new_methods.update(
Expand Down
Loading

0 comments on commit 3077c3a

Please sign in to comment.