Skip to content

Commit

Permalink
De-duplicate dispatch code, remove unreachable branches (#22068)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored and jreback committed Aug 8, 2018
1 parent ed35aef commit 81f386c
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 50 deletions.
2 changes: 1 addition & 1 deletion pandas/_libs/tslibs/timedeltas.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -930,7 +930,7 @@ cdef class _Timedelta(timedelta):
def nanoseconds(self):
"""
Return the number of nanoseconds (n), where 0 <= n < 1 microsecond.
Returns
-------
int
Expand Down
42 changes: 9 additions & 33 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -4911,20 +4911,7 @@ def _arith_op(left, right):

if this._is_mixed_type or other._is_mixed_type:
# iterate over columns
if this.columns.is_unique:
# unique columns
result = {col: _arith_op(this[col], other[col])
for col in this}
result = self._constructor(result, index=new_index,
columns=new_columns, copy=False)
else:
# non-unique columns
result = {i: _arith_op(this.iloc[:, i], other.iloc[:, i])
for i, col in enumerate(this.columns)}
result = self._constructor(result, index=new_index, copy=False)
result.columns = new_columns
return result

return ops.dispatch_to_series(this, other, _arith_op)
else:
result = _arith_op(this.values, other.values)

Expand Down Expand Up @@ -4958,27 +4945,16 @@ def _compare_frame(self, other, func, str_rep):
# _compare_frame assumes self._indexed_same(other)

import pandas.core.computation.expressions as expressions
# unique
if self.columns.is_unique:

def _compare(a, b):
return {col: func(a[col], b[col]) for col in a.columns}

new_data = expressions.evaluate(_compare, str_rep, self, other)
return self._constructor(data=new_data, index=self.index,
columns=self.columns, copy=False)
# non-unique
else:

def _compare(a, b):
return {i: func(a.iloc[:, i], b.iloc[:, i])
for i, col in enumerate(a.columns)}
def _compare(a, b):
return {i: func(a.iloc[:, i], b.iloc[:, i])
for i in range(len(a.columns))}

new_data = expressions.evaluate(_compare, str_rep, self, other)
result = self._constructor(data=new_data, index=self.index,
copy=False)
result.columns = self.columns
return result
new_data = expressions.evaluate(_compare, str_rep, self, other)
result = self._constructor(data=new_data, index=self.index,
copy=False)
result.columns = self.columns
return result

def combine(self, other, func, fill_value=None, overwrite=True):
"""
Expand Down
61 changes: 45 additions & 16 deletions pandas/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,6 +1143,7 @@ def na_op(x, y):
result[mask] = op(x[mask], com.values_from_object(y[mask]))
else:
assert isinstance(x, np.ndarray)
assert is_scalar(y)
result = np.empty(len(x), dtype=x.dtype)
mask = notna(x)
result[mask] = op(x[mask], y)
Expand Down Expand Up @@ -1189,6 +1190,7 @@ def wrapper(left, right):

elif (is_extension_array_dtype(left) or
is_extension_array_dtype(right)):
# TODO: should this include `not is_scalar(right)`?
return dispatch_to_extension_op(op, left, right)

elif is_datetime64_dtype(left) or is_datetime64tz_dtype(left):
Expand Down Expand Up @@ -1278,13 +1280,11 @@ def na_op(x, y):
# should have guarantees on what x, y can be type-wise
# Extension Dtypes are not called here

# dispatch to the categorical if we have a categorical
# in either operand
if is_categorical_dtype(y) and not is_scalar(y):
# The `not is_scalar(y)` check excludes the string "category"
return op(y, x)
# Checking that cases that were once handled here are no longer
# reachable.
assert not (is_categorical_dtype(y) and not is_scalar(y))

elif is_object_dtype(x.dtype):
if is_object_dtype(x.dtype):
result = _comp_method_OBJECT_ARRAY(op, x, y)

elif is_datetimelike_v_numeric(x, y):
Expand Down Expand Up @@ -1342,7 +1342,7 @@ def wrapper(self, other, axis=None):
return self._constructor(res_values, index=self.index,
name=res_name)

if is_datetime64_dtype(self) or is_datetime64tz_dtype(self):
elif is_datetime64_dtype(self) or is_datetime64tz_dtype(self):
# Dispatch to DatetimeIndex to ensure identical
# Series/Index behavior
if (isinstance(other, datetime.date) and
Expand Down Expand Up @@ -1384,8 +1384,9 @@ def wrapper(self, other, axis=None):
name=res_name)

elif (is_extension_array_dtype(self) or
(is_extension_array_dtype(other) and
not is_scalar(other))):
(is_extension_array_dtype(other) and not is_scalar(other))):
# Note: the `not is_scalar(other)` condition rules out
# e.g. other == "category"
return dispatch_to_extension_op(op, self, other)

elif isinstance(other, ABCSeries):
Expand All @@ -1408,13 +1409,6 @@ def wrapper(self, other, axis=None):
# is not.
return result.__finalize__(self).rename(res_name)

elif isinstance(other, pd.Categorical):
# ordering of checks matters; by this point we know
# that not is_categorical_dtype(self)
res_values = op(self.values, other)
return self._constructor(res_values, index=self.index,
name=res_name)

elif is_scalar(other) and isna(other):
# numpy does not like comparisons vs None
if op is operator.ne:
Expand Down Expand Up @@ -1544,6 +1538,41 @@ def flex_wrapper(self, other, level=None, fill_value=None, axis=0):
# -----------------------------------------------------------------------------
# DataFrame

def dispatch_to_series(left, right, func):
    """
    Evaluate the frame operation func(left, right) by evaluating
    column-by-column, dispatching to the Series implementation.

    Parameters
    ----------
    left : DataFrame
    right : scalar or DataFrame
    func : arithmetic or comparison operator

    Returns
    -------
    DataFrame

    Raises
    ------
    NotImplementedError
        If ``right`` is neither a scalar nor a DataFrame.
    """
    # Note: we use iloc to access columns for compat with cases
    # with non-unique columns.
    if lib.is_scalar(right):
        new_data = {i: func(left.iloc[:, i], right)
                    for i in range(len(left.columns))}
    elif isinstance(right, ABCDataFrame):
        # Column-by-column evaluation requires already-aligned frames;
        # callers are responsible for aligning before dispatching here.
        assert right._indexed_same(left)
        new_data = {i: func(left.iloc[:, i], right.iloc[:, i])
                    for i in range(len(left.columns))}
    else:
        # Remaining cases have less-obvious dispatch rules
        raise NotImplementedError("Cannot dispatch operation for input "
                                  "of type {!r}".format(type(right).__name__))

    result = left._constructor(new_data, index=left.index, copy=False)
    # Pin columns instead of passing to constructor for compat with
    # non-unique columns case
    result.columns = left.columns
    return result


def _combine_series_frame(self, other, func, fill_value=None, axis=None,
level=None, try_cast=True):
"""
Expand Down

0 comments on commit 81f386c

Please sign in to comment.