From 3077c3ae2addd5ea8a587ca8b76ba284225d0f21 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Wed, 7 Oct 2020 08:10:10 -0700 Subject: [PATCH] TYP/REF: define comparison methods non-dynamically (#36930) --- pandas/core/arraylike.py | 43 ++++++++++++++++ pandas/core/arrays/datetimelike.py | 79 ++++++++++++------------------ pandas/core/base.py | 3 +- pandas/core/generic.py | 2 +- pandas/core/indexes/base.py | 75 ++++++++++++---------------- pandas/core/ops/__init__.py | 31 +----------- pandas/core/ops/methods.py | 23 ++++----- pandas/core/series.py | 17 +++++++ 8 files changed, 138 insertions(+), 135 deletions(-) create mode 100644 pandas/core/arraylike.py diff --git a/pandas/core/arraylike.py b/pandas/core/arraylike.py new file mode 100644 index 0000000000000..1fba022f2a1de --- /dev/null +++ b/pandas/core/arraylike.py @@ -0,0 +1,43 @@ +""" +Methods that can be shared by many array-like classes or subclasses: + Series + Index + ExtensionArray +""" +import operator + +from pandas.errors import AbstractMethodError + +from pandas.core.ops.common import unpack_zerodim_and_defer + + +class OpsMixin: + # ------------------------------------------------------------- + # Comparisons + + def _cmp_method(self, other, op): + raise AbstractMethodError(self) + + @unpack_zerodim_and_defer("__eq__") + def __eq__(self, other): + return self._cmp_method(other, operator.eq) + + @unpack_zerodim_and_defer("__ne__") + def __ne__(self, other): + return self._cmp_method(other, operator.ne) + + @unpack_zerodim_and_defer("__lt__") + def __lt__(self, other): + return self._cmp_method(other, operator.lt) + + @unpack_zerodim_and_defer("__le__") + def __le__(self, other): + return self._cmp_method(other, operator.le) + + @unpack_zerodim_and_defer("__gt__") + def __gt__(self, other): + return self._cmp_method(other, operator.gt) + + @unpack_zerodim_and_defer("__ge__") + def __ge__(self, other): + return self._cmp_method(other, operator.ge) diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index a1d6a2c8f4672..6285f142b2391 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -24,7 +24,6 @@ round_nsint64, ) from pandas._typing import DatetimeLikeScalar, DtypeObj -from pandas.compat import set_function_name from pandas.compat.numpy import function as nv from pandas.errors import AbstractMethodError, NullFrequencyError, PerformanceWarning from pandas.util._decorators import Appender, Substitution, cache_readonly @@ -51,8 +50,8 @@ from pandas.core import nanops, ops from pandas.core.algorithms import checked_add_with_arr, unique1d, value_counts +from pandas.core.arraylike import OpsMixin from pandas.core.arrays._mixins import NDArrayBackedExtensionArray -from pandas.core.arrays.base import ExtensionOpsMixin import pandas.core.common as com from pandas.core.construction import array, extract_array from pandas.core.indexers import check_array_indexer, check_setitem_lengths @@ -73,46 +72,6 @@ class InvalidComparison(Exception): pass -def _datetimelike_array_cmp(cls, op): - """ - Wrap comparison operations to convert Timestamp/Timedelta/Period-like to - boxed scalars/arrays. - """ - opname = f"__{op.__name__}__" - nat_result = opname == "__ne__" - - @unpack_zerodim_and_defer(opname) - def wrapper(self, other): - if self.ndim > 1 and getattr(other, "shape", None) == self.shape: - # TODO: handle 2D-like listlikes - return op(self.ravel(), other.ravel()).reshape(self.shape) - - try: - other = self._validate_comparison_value(other, opname) - except InvalidComparison: - return invalid_comparison(self, other, op) - - dtype = getattr(other, "dtype", None) - if is_object_dtype(dtype): - # We have to use comp_method_OBJECT_ARRAY instead of numpy - # comparison otherwise it would fail to raise when - # comparing tz-aware and tz-naive - with np.errstate(all="ignore"): - result = ops.comp_method_OBJECT_ARRAY(op, self.astype(object), other) - return result - - other_i8 = self._unbox(other) - result = op(self.asi8, other_i8) - - o_mask = isna(other) - if self._hasnans | np.any(o_mask): - result[self._isnan | o_mask] = nat_result - - return result - - return set_function_name(wrapper, opname, cls) - - class AttributesMixin: _data: np.ndarray @@ -426,9 +385,7 @@ def _with_freq(self, freq): DatetimeLikeArrayT = TypeVar("DatetimeLikeArrayT", bound="DatetimeLikeArrayMixin") -class DatetimeLikeArrayMixin( - ExtensionOpsMixin, AttributesMixin, NDArrayBackedExtensionArray -): +class DatetimeLikeArrayMixin(OpsMixin, AttributesMixin, NDArrayBackedExtensionArray): """ Shared Base/Mixin class for DatetimeArray, TimedeltaArray, PeriodArray @@ -1093,7 +1050,35 @@ def _is_unique(self): # ------------------------------------------------------------------ # Arithmetic Methods - _create_comparison_method = classmethod(_datetimelike_array_cmp) + + def _cmp_method(self, other, op): + if self.ndim > 1 and getattr(other, "shape", None) == self.shape: + # TODO: handle 2D-like listlikes + return op(self.ravel(), other.ravel()).reshape(self.shape) + + try: + other = self._validate_comparison_value(other, f"__{op.__name__}__") + except InvalidComparison: + return invalid_comparison(self, other, op) + + dtype = getattr(other, "dtype", None) + if is_object_dtype(dtype): + # We have to use comp_method_OBJECT_ARRAY instead of numpy + # comparison otherwise it would fail to raise when + # comparing tz-aware and tz-naive + with np.errstate(all="ignore"): + result = ops.comp_method_OBJECT_ARRAY(op, self.astype(object), other) + return result + + other_i8 = self._unbox(other) + result = op(self.asi8, other_i8) + + o_mask = isna(other) + if self._hasnans | np.any(o_mask): + nat_result = op is operator.ne + result[self._isnan | o_mask] = nat_result + + return result # pow is invalid for all three subclasses; TimedeltaArray will override # the multiplication and division ops @@ -1582,8 +1567,6 @@ def median(self, axis: Optional[int] = None, skipna: bool = True, *args, **kwarg return self._from_backing_data(result.astype("i8")) -DatetimeLikeArrayMixin._add_comparison_ops() - # ------------------------------------------------------------------- # Shared Constructor Helpers diff --git a/pandas/core/base.py b/pandas/core/base.py index 564a0af3527c5..24bbd28ddc421 100644 --- a/pandas/core/base.py +++ b/pandas/core/base.py @@ -30,6 +30,7 @@ from pandas.core import algorithms, common as com from pandas.core.accessor import DirNamesMixin from pandas.core.algorithms import duplicated, unique1d, value_counts +from pandas.core.arraylike import OpsMixin from pandas.core.arrays import ExtensionArray from pandas.core.construction import create_series_with_explicit_dtype import pandas.core.nanops as nanops @@ -587,7 +588,7 @@ def _is_builtin_func(self, arg): return self._builtin_table.get(arg, arg) -class IndexOpsMixin: +class IndexOpsMixin(OpsMixin): """ Common ops mixin to support a unified interface / docs for Series / Index """ diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 084a812a60d00..19801025b7672 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -1764,7 +1764,7 @@ def _drop_labels_or_levels(self, keys, axis: int = 0): # ---------------------------------------------------------------------- # Iteration - def __hash__(self): + def __hash__(self) -> int: raise TypeError( f"{repr(type(self).__name__)} objects are mutable, " f"thus they cannot be hashed" diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 1e42ebf25c26d..99d9568926df4 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -121,41 +121,6 @@ str_t = str -def _make_comparison_op(op, cls): - def cmp_method(self, other): - if isinstance(other, (np.ndarray, Index, ABCSeries, ExtensionArray)): - if other.ndim > 0 and len(self) != len(other): - raise ValueError("Lengths must match to compare") - - if is_object_dtype(self.dtype) and isinstance(other, ABCCategorical): - left = type(other)(self._values, dtype=other.dtype) - return op(left, other) - elif is_object_dtype(self.dtype) and isinstance(other, ExtensionArray): - # e.g. PeriodArray - with np.errstate(all="ignore"): - result = op(self._values, other) - - elif is_object_dtype(self.dtype) and not isinstance(self, ABCMultiIndex): - # don't pass MultiIndex - with np.errstate(all="ignore"): - result = ops.comp_method_OBJECT_ARRAY(op, self._values, other) - - elif is_interval_dtype(self.dtype): - with np.errstate(all="ignore"): - result = op(self._values, np.asarray(other)) - - else: - with np.errstate(all="ignore"): - result = ops.comparison_op(self._values, np.asarray(other), op) - - if is_bool_dtype(result): - return result - return ops.invalid_comparison(self, other, op) - - name = f"__{op.__name__}__" - return set_function_name(cmp_method, name, cls) - - def _make_arithmetic_op(op, cls): def index_arithmetic_method(self, other): if isinstance(other, (ABCSeries, ABCDataFrame, ABCTimedeltaIndex)): @@ -5400,17 +5365,38 @@ def drop(self, labels, errors: str_t = "raise"): # -------------------------------------------------------------------- # Generated Arithmetic, Comparison, and Unary Methods - @classmethod - def _add_comparison_methods(cls): + def _cmp_method(self, other, op): """ - Add in comparison methods. + Wrapper used to dispatch comparison operations. """ - cls.__eq__ = _make_comparison_op(operator.eq, cls) - cls.__ne__ = _make_comparison_op(operator.ne, cls) - cls.__lt__ = _make_comparison_op(operator.lt, cls) - cls.__gt__ = _make_comparison_op(operator.gt, cls) - cls.__le__ = _make_comparison_op(operator.le, cls) - cls.__ge__ = _make_comparison_op(operator.ge, cls) + if isinstance(other, (np.ndarray, Index, ABCSeries, ExtensionArray)): + if other.ndim > 0 and len(self) != len(other): + raise ValueError("Lengths must match to compare") + + if is_object_dtype(self.dtype) and isinstance(other, ABCCategorical): + left = type(other)(self._values, dtype=other.dtype) + return op(left, other) + elif is_object_dtype(self.dtype) and isinstance(other, ExtensionArray): + # e.g. PeriodArray + with np.errstate(all="ignore"): + result = op(self._values, other) + + elif is_object_dtype(self.dtype) and not isinstance(self, ABCMultiIndex): + # don't pass MultiIndex + with np.errstate(all="ignore"): + result = ops.comp_method_OBJECT_ARRAY(op, self._values, other) + + elif is_interval_dtype(self.dtype): + with np.errstate(all="ignore"): + result = op(self._values, np.asarray(other)) + + else: + with np.errstate(all="ignore"): + result = ops.comparison_op(self._values, np.asarray(other), op) + + if is_bool_dtype(result): + return result + return ops.invalid_comparison(self, other, op) @classmethod def _add_numeric_methods_binary(cls): @@ -5594,7 +5580,6 @@ def shape(self): Index._add_numeric_methods() Index._add_logical_methods() -Index._add_comparison_methods() def ensure_index_from_sequences(sequences, names=None): diff --git a/pandas/core/ops/__init__.py b/pandas/core/ops/__init__.py index cca0e40698ba2..84319b69d9a35 100644 --- a/pandas/core/ops/__init__.py +++ b/pandas/core/ops/__init__.py @@ -20,13 +20,13 @@ from pandas.core import algorithms from pandas.core.construction import extract_array -from pandas.core.ops.array_ops import ( +from pandas.core.ops.array_ops import ( # noqa:F401 arithmetic_op, + comp_method_OBJECT_ARRAY, comparison_op, get_array_op, logical_op, ) -from pandas.core.ops.array_ops import comp_method_OBJECT_ARRAY # noqa:F401 from pandas.core.ops.common import unpack_zerodim_and_defer from pandas.core.ops.docstrings import ( _arith_doc_FRAME, @@ -323,33 +323,6 @@ def wrapper(left, right): return wrapper -def comp_method_SERIES(cls, op, special): - """ - Wrapper function for Series arithmetic operations, to avoid - code duplication. - """ - assert special # non-special uses flex_method_SERIES - op_name = _get_op_name(op, special) - - @unpack_zerodim_and_defer(op_name) - def wrapper(self, other): - - res_name = get_op_result_name(self, other) - - if isinstance(other, ABCSeries) and not self._indexed_same(other): - raise ValueError("Can only compare identically-labeled Series objects") - - lvalues = extract_array(self, extract_numpy=True) - rvalues = extract_array(other, extract_numpy=True) - - res_values = comparison_op(lvalues, rvalues, op) - - return self._construct_result(res_values, name=res_name) - - wrapper.__name__ = op_name - return wrapper - - def bool_method_SERIES(cls, op, special): """ Wrapper function for Series arithmetic operations, to avoid diff --git a/pandas/core/ops/methods.py b/pandas/core/ops/methods.py index 852157e52d5fe..2b117d5e22186 100644 --- a/pandas/core/ops/methods.py +++ b/pandas/core/ops/methods.py @@ -48,7 +48,6 @@ def _get_method_wrappers(cls): arith_method_SERIES, bool_method_SERIES, comp_method_FRAME, - comp_method_SERIES, flex_comp_method_FRAME, flex_method_SERIES, ) @@ -58,7 +57,7 @@ def _get_method_wrappers(cls): arith_flex = flex_method_SERIES comp_flex = flex_method_SERIES arith_special = arith_method_SERIES - comp_special = comp_method_SERIES + comp_special = None bool_special = bool_method_SERIES elif issubclass(cls, ABCDataFrame): arith_flex = arith_method_FRAME @@ -189,16 +188,18 @@ def _create_methods(cls, arith_method, comp_method, bool_method, special): new_methods["divmod"] = arith_method(cls, divmod, special) new_methods["rdivmod"] = arith_method(cls, rdivmod, special) - new_methods.update( - dict( - eq=comp_method(cls, operator.eq, special), - ne=comp_method(cls, operator.ne, special), - lt=comp_method(cls, operator.lt, special), - gt=comp_method(cls, operator.gt, special), - le=comp_method(cls, operator.le, special), - ge=comp_method(cls, operator.ge, special), + if comp_method is not None: + # Series already has this pinned + new_methods.update( + dict( + eq=comp_method(cls, operator.eq, special), + ne=comp_method(cls, operator.ne, special), + lt=comp_method(cls, operator.lt, special), + gt=comp_method(cls, operator.gt, special), + le=comp_method(cls, operator.le, special), + ge=comp_method(cls, operator.ge, special), + ) ) - ) if bool_method: new_methods.update( diff --git a/pandas/core/series.py b/pandas/core/series.py index 2b972d33d7cdd..5cc163807fac6 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -191,6 +191,7 @@ class Series(base.IndexOpsMixin, generic.NDFrame): hasnans = property( base.IndexOpsMixin.hasnans.func, doc=base.IndexOpsMixin.hasnans.__doc__ ) + __hash__ = generic.NDFrame.__hash__ _mgr: SingleBlockManager div: Callable[["Series", Any], "Series"] rdiv: Callable[["Series", Any], "Series"] @@ -4961,6 +4962,22 @@ def to_period(self, freq=None, copy=True) -> "Series": # Add plotting methods to Series hist = pandas.plotting.hist_series + # ---------------------------------------------------------------------- + # Template-Based Arithmetic/Comparison Methods + + def _cmp_method(self, other, op): + res_name = ops.get_op_result_name(self, other) + + if isinstance(other, Series) and not self._indexed_same(other): + raise ValueError("Can only compare identically-labeled Series objects") + + lvalues = extract_array(self, extract_numpy=True) + rvalues = extract_array(other, extract_numpy=True) + + res_values = ops.comparison_op(lvalues, rvalues, op) + + return self._construct_result(res_values, name=res_name) + Series._add_numeric_operations()