Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX-#5238: Make rmul really rmul instead of mul. #5246

Merged
merged 3 commits into from
Nov 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions modin/core/storage_formats/base/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,14 @@ def mod(self, other, **kwargs): # noqa: PR02
def mul(self, other, **kwargs): # noqa: PR02
return BinaryDefault.register(pandas.DataFrame.mul)(self, other=other, **kwargs)

@doc_utils.doc_binary_method(
    operation="multiplication", sign="*", self_on_right=True
)
def rmul(self, other, **kwargs):  # noqa: PR02
    # No inline docstring here: the decorator above receives the operation
    # name, sign and operand order, presumably to generate one -- TODO confirm.
    # Wrap ``pandas.DataFrame.rmul`` with ``BinaryDefault.register`` and call
    # the result, passing ``other`` by keyword together with any extra kwargs.
    rmul_default = BinaryDefault.register(pandas.DataFrame.rmul)
    return rmul_default(self, other=other, **kwargs)

@doc_utils.add_refer_to("DataFrame.corr")
def corr(self, **kwargs): # noqa: PR02
"""
Expand Down
1 change: 1 addition & 0 deletions modin/core/storage_formats/pandas/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ def to_numpy(self, **kwargs):
# Each name below is bound to ``Binary.register`` applied to the
# ``pandas.DataFrame`` method of the same name; ``rmul`` is registered
# separately from ``mul`` rather than sharing its implementation.
lt = Binary.register(pandas.DataFrame.lt)
mod = Binary.register(pandas.DataFrame.mod)
mul = Binary.register(pandas.DataFrame.mul)
rmul = Binary.register(pandas.DataFrame.rmul)
ne = Binary.register(pandas.DataFrame.ne)
pow = Binary.register(pandas.DataFrame.pow)
radd = Binary.register(pandas.DataFrame.radd)
Expand Down
10 changes: 9 additions & 1 deletion modin/pandas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2322,7 +2322,15 @@ def rmod(
"rmod", other, axis=axis, level=level, fill_value=fill_value
)

rmul = mul
def rmul(
    self, other, axis="columns", level=None, fill_value=None
):  # noqa: PR01, RT01, D200
    """
    Get multiplication of dataframe and `other`, element-wise (binary operator `rmul`).
    """
    # Dispatch under the "rmul" operation name (not "mul") so that the
    # reflected form is what runs; this matters when multiplication is not
    # commutative. ``axis``, ``level`` and ``fill_value`` are forwarded
    # unchanged to ``_binary_op``.
    return self._binary_op(
        "rmul", other, axis=axis, level=level, fill_value=fill_value
    )

def _rolling(
self, window, min_periods, center, win_type, *args, **kwargs
Expand Down
17 changes: 16 additions & 1 deletion modin/pandas/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1525,7 +1525,22 @@ def mul(
broadcast=isinstance(other, Series),
)

rmul = multiply = mul
multiply = mul

def rmul(
    self, other, axis="columns", level=None, fill_value=None
):  # noqa: PR01, RT01, D200
    """
    Get multiplication of ``DataFrame`` and `other`, element-wise (binary operator `rmul`).
    """
    # Dispatch under the "rmul" operation name so the reflected multiply is
    # used instead of aliasing ``mul``. ``broadcast`` is switched on only when
    # ``other`` is a ``Series``; the remaining arguments are forwarded as-is.
    return self._binary_op(
        "rmul",
        other,
        axis=axis,
        level=level,
        fill_value=fill_value,
        broadcast=isinstance(other, Series),
    )

def ne(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200
"""
Expand Down
13 changes: 12 additions & 1 deletion modin/pandas/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1301,7 +1301,18 @@ def mul(self, other, level=None, fill_value=None, axis=0): # noqa: PR01, RT01,
new_other, level=level, fill_value=None, axis=axis
)

multiply = rmul = mul
multiply = mul

def rmul(
    self, other, level=None, fill_value=None, axis=0
):  # noqa: PR01, RT01, D200
    """
    Return multiplication of series and `other`, element-wise (binary operator `rmul`).
    """
    # ``_prepare_inter_op`` yields the pair of operands that the parent
    # class's ``rmul`` is then called with.
    new_self, new_other = self._prepare_inter_op(other)
    # NOTE(review): ``fill_value`` is accepted by this method but ``None`` is
    # what gets forwarded. Kept unchanged to preserve existing behaviour --
    # confirm whether dropping the caller's value is intended.
    return super(Series, new_self).rmul(
        new_other, level=level, fill_value=None, axis=axis
    )

def ne(self, other, level=None, fill_value=None, axis=0): # noqa: PR01, RT01, D200
"""
Expand Down
20 changes: 20 additions & 0 deletions modin/pandas/test/dataframe/test_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
create_test_dfs,
default_to_pandas_ignore_string,
CustomIntegerForAddition,
NonCommutativeMultiplyInteger,
)
from modin.config import Engine, NPartitions
from modin.test.test_utils import warns_that_defaulting_to_pandas
Expand Down Expand Up @@ -353,3 +354,22 @@ def test_add_custom_class():
*create_test_dfs(test_data["int_data"]),
lambda df: df + CustomIntegerForAddition(4),
)


def test_non_commutative_multiply_pandas():
    # The non commutative integer class implementation is tricky. Check that
    # multiplying such an integer with a pandas dataframe is really not
    # commutative.
    # NOTE(review): this is only a sanity check of pandas itself if ``pd`` is
    # plain pandas in this module -- confirm which library the alias refers to.
    pandas_df = pd.DataFrame([[1]], dtype=int)
    integer = NonCommutativeMultiplyInteger(2)
    # Left-multiply and right-multiply must give different frames.
    assert not (integer * pandas_df).equals(pandas_df * integer)


def test_non_commutative_multiply():
    # This test checks that mul and rmul do different things when
    # multiplication is not commutative, e.g. when multiplying by an object
    # whose __mul__ and __rmul__ give different results.
    # For context see https://github.com/modin-project/modin/issues/5238
    modin_df, pandas_df = create_test_dfs([1], dtype=int)
    integer = NonCommutativeMultiplyInteger(2)
    # Exercise both operand orders so that each of rmul and mul is covered.
    eval_general(modin_df, pandas_df, lambda s: integer * s)
    eval_general(modin_df, pandas_df, lambda s: s * integer)
25 changes: 24 additions & 1 deletion modin/pandas/test/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
test_data_large_categorical_series_values,
default_to_pandas_ignore_string,
CustomIntegerForAddition,
NonCommutativeMultiplyInteger,
)
from modin.config import NPartitions

Expand Down Expand Up @@ -4258,11 +4259,33 @@ def test_encode(data, encoding_type):


@pytest.mark.parametrize("data", test_string_data_values, ids=test_string_data_keys)
def test_add_string_to_series(data):
def test_non_commutative_add_string_to_series(data):
    # String concatenation is order-sensitive, so ``radd`` and ``add`` have to
    # produce different results; both operand orders are compared below.
    # Background: https://github.com/modin-project/modin/issues/4908
    def string_on_left(series):
        return "string" + series

    def string_on_right(series):
        return series + "string"

    # A fresh pair of test series is built for each direction.
    eval_general(*create_test_series(data), string_on_left)
    eval_general(*create_test_series(data), string_on_right)


def test_non_commutative_multiply_pandas():
    # The non commutative integer class implementation is tricky. Check that
    # multiplying such an integer with a pandas series is really not
    # commutative.
    # NOTE(review): this is only a sanity check of pandas itself if ``pd`` is
    # plain pandas in this module -- confirm which library the alias refers to.
    # Build an actual Series (not a DataFrame) so the check matches what the
    # series tests exercise.
    pandas_series = pd.Series([1], dtype=int)
    integer = NonCommutativeMultiplyInteger(2)
    assert not (integer * pandas_series).equals(pandas_series * integer)
mvashishtha marked this conversation as resolved.
Show resolved Hide resolved


def test_non_commutative_multiply():
    # This test checks that mul and rmul do different things when
    # multiplication is not commutative, e.g. when multiplying by an object
    # whose __mul__ and __rmul__ give different results.
    # For context see https://github.com/modin-project/modin/issues/5238
    modin_series, pandas_series = create_test_series(1, dtype=int)
    integer = NonCommutativeMultiplyInteger(2)
    # Exercise both operand orders so that each of rmul and mul is covered.
    eval_general(modin_series, pandas_series, lambda s: integer * s)
    eval_general(modin_series, pandas_series, lambda s: s * integer)


@pytest.mark.parametrize(
"is_sparse_data", [True, False], ids=["is_sparse", "is_not_sparse"]
)
Expand Down
32 changes: 32 additions & 0 deletions modin/pandas/test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,38 @@ def __radd__(self, other):
return other + self.value


class NonCommutativeMultiplyInteger:
    """int-like class with non-commutative multiply operation.

    We need to test that rmul and mul do different things when
    multiplication is not commutative, but almost all multiplication is
    commutative. This class' fake multiplication overloads are not commutative
    when you multiply an instance of this class with pandas.Series, which
    does not know how to __mul__ with this class. e.g.

    NonCommutativeMultiplyInteger(2) * pd.Series(1, dtype=int) == pd.Series(2, dtype=int)
    pd.Series(1, dtype=int) * NonCommutativeMultiplyInteger(2) == pd.Series(3, dtype=int)

    Parameters
    ----------
    value : int
        The wrapped integer.

    Raises
    ------
    TypeError
        If `value` is not an ``int``.
    """

    def __init__(self, value: int):
        if not isinstance(value, int):
            raise TypeError(
                f"must initialize with integer, but got {value} of type {type(value)}"
            )
        self.value = value

    def __repr__(self):
        # Readable representation so failing comparisons show the wrapped value.
        return f"{type(self).__name__}({self.value!r})"

    def __mul__(self, other):
        # Note that we need to check other is an int, otherwise when we (left) mul
        # this with a series, we'll just multiply self.value by the series, whereas
        # we want to make the series do an rmul instead.
        if not isinstance(other, int):
            return NotImplemented
        return self.value * other

    def __rmul__(self, other):
        # Deliberately differs from __mul__ (extra + 1) so that the two operand
        # orders are distinguishable in test results.
        return self.value * other + 1
YarShev marked this conversation as resolved.
Show resolved Hide resolved


def categories_equals(left, right):
assert (left.ordered and right.ordered) or (not left.ordered and not right.ordered)
assert_extension_array_equal(left, right)
Expand Down