Skip to content

Commit

Permalink
FIX-#5236: Allow binary operations with custom classes. (#5237)
Browse files Browse the repository at this point in the history
Signed-off-by: mvashishtha <mahesh@ponder.io>
  • Loading branch information
mvashishtha authored Nov 21, 2022
1 parent 4779473 commit 5acf539
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 35 deletions.
68 changes: 33 additions & 35 deletions modin/pandas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,47 +245,42 @@ def _validate_other(
TypeError
If any validation checks fail.
"""
# We skip dtype checking if the other is a scalar.
if is_scalar(other):
if isinstance(other, BasePandasDataset):
return other._query_compiler
if not is_list_like(other):
# We skip dtype checking if the other is a scalar. Note that pandas
# is_scalar can be misleading as it is False for almost all objects,
# even when those objects should be treated as scalars. See e.g.
# https://github.com/modin-project/modin/issues/5236. Therefore, we
# detect scalars by checking that `other` is neither a list-like nor
# another BasePandasDataset.
return other
axis = self._get_axis_number(axis) if axis is not None else 1
result = other
if isinstance(other, BasePandasDataset):
return other._query_compiler
elif is_list_like(other):
if axis == 0:
if len(other) != len(self._query_compiler.index):
raise ValueError(
f"Unable to coerce to Series, length must be {len(self._query_compiler.index)}: "
+ f"given {len(other)}"
)
else:
if len(other) != len(self._query_compiler.columns):
raise ValueError(
f"Unable to coerce to Series, length must be {len(self._query_compiler.columns)}: "
+ f"given {len(other)}"
)
if hasattr(other, "dtype"):
other_dtypes = [other.dtype] * len(other)
elif is_dict_like(other):
other_dtypes = [
type(other[label])
for label in self._query_compiler.get_axis(axis)
# The binary operation is applied for intersection of axis labels
# and dictionary keys. So filtering out extra keys.
if label in other
]
else:
other_dtypes = [type(x) for x in other]
if axis == 0:
if len(other) != len(self._query_compiler.index):
raise ValueError(
f"Unable to coerce to Series, length must be {len(self._query_compiler.index)}: "
+ f"given {len(other)}"
)
else:
other_dtypes = [
type(other)
for _ in range(
len(self._query_compiler.index)
if axis
else len(self._query_compiler.columns)
if len(other) != len(self._query_compiler.columns):
raise ValueError(
f"Unable to coerce to Series, length must be {len(self._query_compiler.columns)}: "
+ f"given {len(other)}"
)
if hasattr(other, "dtype"):
other_dtypes = [other.dtype] * len(other)
elif is_dict_like(other):
other_dtypes = [
type(other[label])
for label in self._query_compiler.get_axis(axis)
# The binary operation is applied for intersection of axis labels
# and dictionary keys. So filtering out extra keys.
if label in other
]
else:
other_dtypes = [type(x) for x in other]
if compare_index:
if not self.index.equals(other.index):
raise TypeError("Cannot perform operation with non-equal index")
Expand All @@ -304,6 +299,9 @@ def _validate_other(
if label in other
]

# TODO(https://github.com/modin-project/modin/issues/5239):
# this spuriously rejects other that is a list including some
# custom type that can be added to self's elements.
if not all(
(is_numeric_dtype(self_dtype) and is_numeric_dtype(other_dtype))
or (is_object_dtype(self_dtype) and is_object_dtype(other_dtype))
Expand Down
17 changes: 17 additions & 0 deletions modin/pandas/test/dataframe/test_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import matplotlib
import modin.pandas as pd

from modin._compat import PandasCompatVersion
from modin.core.dataframe.pandas.partitioning.axis_partition import (
PandasDataframeAxisPartition,
)
Expand All @@ -27,6 +28,7 @@
test_data,
create_test_dfs,
default_to_pandas_ignore_string,
CustomIntegerForAddition,
)
from modin.config import Engine, NPartitions
from modin.test.test_utils import warns_that_defaulting_to_pandas
Expand Down Expand Up @@ -336,3 +338,18 @@ def test_add_string_to_df():
modin_df, pandas_df = create_test_dfs(["a", "b"])
eval_general(modin_df, pandas_df, lambda df: "string" + df)
eval_general(modin_df, pandas_df, lambda df: df + "string")


@pytest.mark.xfail(
    PandasCompatVersion.CURRENT == PandasCompatVersion.PY36,
    reason="Seems to be a bug in pandas 1.1.5. pandas throws ValueError "
    + "for this particular dataframe.",
)
def test_add_custom_class():
    """Check that a DataFrame can be added to an arbitrary object via ``+``.

    Regression test for https://github.com/modin-project/modin/issues/5236:
    the right-hand operand is a user-defined class that only implements
    ``__add__``/``__radd__``, and the modin result is compared with pandas.
    """
    modin_df, pandas_df = create_test_dfs(test_data["int_data"])
    eval_general(
        modin_df,
        pandas_df,
        lambda df: df + CustomIntegerForAddition(4),
    )
11 changes: 11 additions & 0 deletions modin/pandas/test/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
test_data_large_categorical_series_keys,
test_data_large_categorical_series_values,
default_to_pandas_ignore_string,
CustomIntegerForAddition,
)
from modin.config import NPartitions

Expand Down Expand Up @@ -635,6 +636,16 @@ def test_add_suffix(data):
)


def test_add_custom_class():
    """Check that a Series can be added to an arbitrary object via ``+``.

    Regression test for https://github.com/modin-project/modin/issues/5236:
    the right-hand operand is a user-defined class that only implements
    ``__add__``/``__radd__``, and the modin result is compared with pandas.
    """
    modin_series, pandas_series = create_test_series(test_data["int_data"])
    eval_general(
        modin_series,
        pandas_series,
        lambda series: series + CustomIntegerForAddition(4),
    )


@pytest.mark.parametrize("data", test_data_values, ids=test_data_keys)
@pytest.mark.parametrize("func", agg_func_values, ids=agg_func_keys)
def test_agg(data, func):
Expand Down
11 changes: 11 additions & 0 deletions modin/pandas/test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,17 @@
time_parsing_csv_path = "modin/pandas/test/data/test_time_parsing.csv"


class CustomIntegerForAddition:
    """
    Wrapper around an integer that supports ``+`` from either side.

    The class is deliberately not a number, a list-like, or a pandas object,
    so tests can use it to check binary operations against custom classes
    (see https://github.com/modin-project/modin/issues/5236).

    Parameters
    ----------
    value : int
        The integer that is added to the other operand.
    """

    def __init__(self, value: int):
        self.value = value

    def __repr__(self):
        # A readable repr keeps assertion/pytest output meaningful instead of
        # the default ``<... object at 0x...>``.
        return f"{type(self).__name__}({self.value!r})"

    def __add__(self, other):
        # ``self + other``: delegate to the wrapped integer on the left side.
        return self.value + other

    def __radd__(self, other):
        # ``other + self``: keep ``other`` on the left so its own ``__add__``
        # is the one that runs.
        return other + self.value

def categories_equals(left, right):
assert (left.ordered and right.ordered) or (not left.ordered and not right.ordered)
assert_extension_array_equal(left, right)
Expand Down

0 comments on commit 5acf539

Please sign in to comment.