Skip to content

Commit

Permalink
FIX-modin-project#7381: Fix Series binary operators ignoring fill_value
Browse files Browse the repository at this point in the history
Signed-off-by: Jonathan Shi <jhshi07@gmail.com>
  • Loading branch information
noloerino committed Sep 12, 2024
1 parent 3357709 commit abfd048
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 18 deletions.
54 changes: 54 additions & 0 deletions modin/core/storage_formats/base/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,15 @@ def combine_first(self, other, **kwargs): # noqa: PR02
def eq(self, other, **kwargs): # noqa: PR02
return BinaryDefault.register(pandas.DataFrame.eq)(self, other=other, **kwargs)

def series_eq(self, other, **kwargs):
    """
    Apply ``pandas.Series.eq`` to `self` and `other` through ``BinaryDefault``.

    Parameters
    ----------
    other : object
        Right-hand operand, forwarded unchanged as the ``other`` keyword.
    **kwargs : dict
        Additional keyword arguments forwarded to the registered callable.

    Returns
    -------
    object
        Whatever the callable built by ``BinaryDefault.register`` returns.
    """
    # NOTE(review): both squeeze flags are always True; presumably they turn
    # single-column operands into Series before the operator runs. Confirm that
    # ``BinaryDefault`` tolerates an `other` that cannot be squeezed (a scalar).
    series_equal = BinaryDefault.register(pandas.Series.eq)
    return series_equal(
        self, other=other, squeeze_self=True, squeeze_other=True, **kwargs
    )

@doc_utils.add_refer_to("DataFrame.equals")
def equals(self, other):  # noqa: PR01, RT01
    # No inline docstring: documentation is presumably attached by the
    # ``add_refer_to`` decorator -- confirm before adding one here.
    frame_equals = BinaryDefault.register(pandas.DataFrame.equals)
    return frame_equals(self, other=other)
Expand Down Expand Up @@ -685,24 +694,60 @@ def divmod(self, other, **kwargs):
def ge(self, other, **kwargs): # noqa: PR02
return BinaryDefault.register(pandas.DataFrame.ge)(self, other=other, **kwargs)

def series_ge(self, other, **kwargs):
    """
    Apply ``pandas.Series.ge`` to `self` and `other` through ``BinaryDefault``.

    Parameters
    ----------
    other : object
        Right-hand operand, forwarded unchanged as the ``other`` keyword.
    **kwargs : dict
        Additional keyword arguments forwarded to the registered callable.

    Returns
    -------
    object
        Whatever the callable built by ``BinaryDefault.register`` returns.
    """
    # NOTE(review): both squeeze flags are always True; presumably they turn
    # single-column operands into Series before the operator runs. Confirm that
    # ``BinaryDefault`` tolerates an `other` that cannot be squeezed (a scalar).
    series_greater_equal = BinaryDefault.register(pandas.Series.ge)
    return series_greater_equal(
        self, other=other, squeeze_self=True, squeeze_other=True, **kwargs
    )

@doc_utils.doc_binary_method(
    operation="greater than comparison", sign=">", op_type="comparison"
)
def gt(self, other, **kwargs):  # noqa: PR02
    # No inline docstring: documentation is presumably generated by the
    # ``doc_binary_method`` decorator -- confirm before adding one here.
    frame_greater = BinaryDefault.register(pandas.DataFrame.gt)
    return frame_greater(self, other=other, **kwargs)

def series_gt(self, other, **kwargs):
    """
    Apply ``pandas.Series.gt`` to `self` and `other` through ``BinaryDefault``.

    Parameters
    ----------
    other : object
        Right-hand operand, forwarded unchanged as the ``other`` keyword.
    **kwargs : dict
        Additional keyword arguments forwarded to the registered callable.

    Returns
    -------
    object
        Whatever the callable built by ``BinaryDefault.register`` returns.
    """
    # NOTE(review): both squeeze flags are always True; presumably they turn
    # single-column operands into Series before the operator runs. Confirm that
    # ``BinaryDefault`` tolerates an `other` that cannot be squeezed (a scalar).
    series_greater = BinaryDefault.register(pandas.Series.gt)
    return series_greater(
        self, other=other, squeeze_self=True, squeeze_other=True, **kwargs
    )

@doc_utils.doc_binary_method(
    operation="less than or equal comparison", sign="<=", op_type="comparison"
)
def le(self, other, **kwargs):  # noqa: PR02
    # No inline docstring: documentation is presumably generated by the
    # ``doc_binary_method`` decorator -- confirm before adding one here.
    frame_less_equal = BinaryDefault.register(pandas.DataFrame.le)
    return frame_less_equal(self, other=other, **kwargs)

def series_le(self, other, **kwargs):
    """
    Apply ``pandas.Series.le`` to `self` and `other` through ``BinaryDefault``.

    Parameters
    ----------
    other : object
        Right-hand operand, forwarded unchanged as the ``other`` keyword.
    **kwargs : dict
        Additional keyword arguments forwarded to the registered callable.

    Returns
    -------
    object
        Whatever the callable built by ``BinaryDefault.register`` returns.
    """
    # NOTE(review): both squeeze flags are always True; presumably they turn
    # single-column operands into Series before the operator runs. Confirm that
    # ``BinaryDefault`` tolerates an `other` that cannot be squeezed (a scalar).
    series_less_equal = BinaryDefault.register(pandas.Series.le)
    return series_less_equal(
        self, other=other, squeeze_self=True, squeeze_other=True, **kwargs
    )

@doc_utils.doc_binary_method(
    operation="less than comparison", sign="<", op_type="comparison"
)
def lt(self, other, **kwargs):  # noqa: PR02
    # No inline docstring: documentation is presumably generated by the
    # ``doc_binary_method`` decorator -- confirm before adding one here.
    frame_less = BinaryDefault.register(pandas.DataFrame.lt)
    return frame_less(self, other=other, **kwargs)

def series_lt(self, other, **kwargs):
    """
    Apply ``pandas.Series.lt`` to `self` and `other` through ``BinaryDefault``.

    Parameters
    ----------
    other : object
        Right-hand operand, forwarded unchanged as the ``other`` keyword.
    **kwargs : dict
        Additional keyword arguments forwarded to the registered callable.

    Returns
    -------
    object
        Whatever the callable built by ``BinaryDefault.register`` returns.
    """
    # NOTE(review): both squeeze flags are always True; presumably they turn
    # single-column operands into Series before the operator runs. Confirm that
    # ``BinaryDefault`` tolerates an `other` that cannot be squeezed (a scalar).
    series_less = BinaryDefault.register(pandas.Series.lt)
    return series_less(
        self, other=other, squeeze_self=True, squeeze_other=True, **kwargs
    )

@doc_utils.doc_binary_method(operation="modulo", sign="%")
def mod(self, other, **kwargs):  # noqa: PR02
    # No inline docstring: documentation is presumably generated by the
    # ``doc_binary_method`` decorator -- confirm before adding one here.
    frame_modulo = BinaryDefault.register(pandas.DataFrame.mod)
    return frame_modulo(self, other=other, **kwargs)
Expand Down Expand Up @@ -818,6 +863,15 @@ def dot(self, other, **kwargs): # noqa: PR02
def ne(self, other, **kwargs): # noqa: PR02
return BinaryDefault.register(pandas.DataFrame.ne)(self, other=other, **kwargs)

def series_ne(self, other, **kwargs):
    """
    Apply ``pandas.Series.ne`` to `self` and `other` through ``BinaryDefault``.

    Parameters
    ----------
    other : object
        Right-hand operand, forwarded unchanged as the ``other`` keyword.
    **kwargs : dict
        Additional keyword arguments forwarded to the registered callable.

    Returns
    -------
    object
        Whatever the callable built by ``BinaryDefault.register`` returns.
    """
    # NOTE(review): both squeeze flags are always True; presumably they turn
    # single-column operands into Series before the operator runs. Confirm that
    # ``BinaryDefault`` tolerates an `other` that cannot be squeezed (a scalar).
    series_not_equal = BinaryDefault.register(pandas.Series.ne)
    return series_not_equal(
        self, other=other, squeeze_self=True, squeeze_other=True, **kwargs
    )

@doc_utils.doc_binary_method(operation="exponential power", sign="**")
def pow(self, other, **kwargs):  # noqa: PR02
    # No inline docstring: documentation is presumably generated by the
    # ``doc_binary_method`` decorator -- confirm before adding one here.
    frame_power = BinaryDefault.register(pandas.DataFrame.pow)
    return frame_power(self, other=other, **kwargs)
Expand Down
15 changes: 15 additions & 0 deletions modin/core/storage_formats/pandas/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,13 @@ def caller(df, *args, **kwargs):
return caller


def _series_logical_binop(func):
"""
Build a callable function to pass to Binary.register for Series logical operators.
"""
return lambda x, y, **kwargs: func(x.squeeze(axis=1), y.squeeze(axis=1), **kwargs).to_frame()


@_inherit_docstrings(BaseQueryCompiler)
class PandasQueryCompiler(BaseQueryCompiler, QueryCompilerCaster):
"""
Expand Down Expand Up @@ -522,6 +529,14 @@ def to_numpy(self, **kwargs):
sort=False,
)

# Series comparison operators accept a `fill_value` argument that their
# DataFrame counterparts do not, so each one is registered separately here.
series_eq = Binary.register(
    _series_logical_binop(pandas.Series.eq), infer_dtypes="bool"
)
series_ge = Binary.register(
    _series_logical_binop(pandas.Series.ge), infer_dtypes="bool"
)
series_gt = Binary.register(
    _series_logical_binop(pandas.Series.gt), infer_dtypes="bool"
)
series_le = Binary.register(
    _series_logical_binop(pandas.Series.le), infer_dtypes="bool"
)
series_lt = Binary.register(
    _series_logical_binop(pandas.Series.lt), infer_dtypes="bool"
)
series_ne = Binary.register(
    _series_logical_binop(pandas.Series.ne), infer_dtypes="bool"
)

# Needed for numpy API
_logical_and = Binary.register(
lambda df, other, *args, **kwargs: pandas.DataFrame(
Expand Down
11 changes: 11 additions & 0 deletions modin/pandas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,17 @@ def _binary_op(self, op, other, **kwargs) -> Self:
]
if op in exclude_list:
kwargs.pop("axis")
# Series logical operations take an additional fill_value argument that DF does not
series_specialize_list = [
"eq",
"ge",
"gt",
"le",
"lt",
"ne",
]
if not self._is_dataframe and op in series_specialize_list:
op = "series_" + op
new_query_compiler = getattr(self._query_compiler, op)(other, **kwargs)
return self._create_or_update_from_compiler(new_query_compiler)

Expand Down
36 changes: 18 additions & 18 deletions modin/pandas/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1031,7 +1031,7 @@ def eq(
Return Equal to of series and `other`, element-wise (binary operator `eq`).
"""
new_self, new_other = self._prepare_inter_op(other)
return super(Series, new_self).eq(new_other, level=level, axis=axis)
return new_self._binary_op("eq", new_other, level=level, fill_value=fill_value, axis=axis)

def equals(self, other) -> bool: # noqa: PR01, RT01, D200
"""
Expand Down Expand Up @@ -1130,7 +1130,7 @@ def floordiv(
"""
new_self, new_other = self._prepare_inter_op(other)
return super(Series, new_self).floordiv(
new_other, level=level, fill_value=None, axis=axis
new_other, level=level, fill_value=fill_value, axis=axis
)

def ge(
Expand All @@ -1140,7 +1140,7 @@ def ge(
Return greater than or equal to of series and `other`, element-wise (binary operator `ge`).
"""
new_self, new_other = self._prepare_inter_op(other)
return super(Series, new_self).ge(new_other, level=level, axis=axis)
return new_self._binary_op("ge", new_other, level=level, fill_value=fill_value, axis=axis)

def groupby(
self,
Expand Down Expand Up @@ -1188,7 +1188,7 @@ def gt(
Return greater than of series and `other`, element-wise (binary operator `gt`).
"""
new_self, new_other = self._prepare_inter_op(other)
return super(Series, new_self).gt(new_other, level=level, axis=axis)
return new_self._binary_op("gt", new_other, level=level, fill_value=fill_value, axis=axis)

def hist(
self,
Expand Down Expand Up @@ -1306,7 +1306,7 @@ def le(
Return less than or equal to of series and `other`, element-wise (binary operator `le`).
"""
new_self, new_other = self._prepare_inter_op(other)
return super(Series, new_self).le(new_other, level=level, axis=axis)
return new_self._binary_op("le", new_other, level=level, fill_value=fill_value, axis=axis)

def lt(
self, other, level=None, fill_value=None, axis=0
Expand All @@ -1315,7 +1315,7 @@ def lt(
Return less than of series and `other`, element-wise (binary operator `lt`).
"""
new_self, new_other = self._prepare_inter_op(other)
return super(Series, new_self).lt(new_other, level=level, axis=axis)
return new_self._binary_op("lt", new_other, level=level, fill_value=fill_value, axis=axis)

def map(self, arg, na_action=None) -> Series: # noqa: PR01, RT01, D200
"""
Expand Down Expand Up @@ -1402,7 +1402,7 @@ def mod(
"""
new_self, new_other = self._prepare_inter_op(other)
return super(Series, new_self).mod(
new_other, level=level, fill_value=None, axis=axis
new_other, level=level, fill_value=fill_value, axis=axis
)

def mode(self, dropna=True) -> Series: # noqa: PR01, RT01, D200
Expand All @@ -1419,7 +1419,7 @@ def mul(
"""
new_self, new_other = self._prepare_inter_op(other)
return super(Series, new_self).mul(
new_other, level=level, fill_value=None, axis=axis
new_other, level=level, fill_value=fill_value, axis=axis
)

multiply = mul
Expand All @@ -1432,7 +1432,7 @@ def rmul(
"""
new_self, new_other = self._prepare_inter_op(other)
return super(Series, new_self).rmul(
new_other, level=level, fill_value=None, axis=axis
new_other, level=level, fill_value=fill_value, axis=axis
)

def ne(
Expand All @@ -1442,7 +1442,7 @@ def ne(
Return not equal to of series and `other`, element-wise (binary operator `ne`).
"""
new_self, new_other = self._prepare_inter_op(other)
return super(Series, new_self).ne(new_other, level=level, axis=axis)
return new_self._binary_op("ne", new_other, level=level, fill_value=fill_value, axis=axis)

def nlargest(self, n=5, keep="first") -> Series: # noqa: PR01, RT01, D200
"""
Expand Down Expand Up @@ -1562,7 +1562,7 @@ def pow(
"""
new_self, new_other = self._prepare_inter_op(other)
return super(Series, new_self).pow(
new_other, level=level, fill_value=None, axis=axis
new_other, level=level, fill_value=fill_value, axis=axis
)

@_inherit_docstrings(pandas.Series.prod, apilink="pandas.Series.prod")
Expand Down Expand Up @@ -1763,7 +1763,7 @@ def rfloordiv(
"""
new_self, new_other = self._prepare_inter_op(other)
return super(Series, new_self).rfloordiv(
new_other, level=level, fill_value=None, axis=axis
new_other, level=level, fill_value=fill_value, axis=axis
)

def rmod(
Expand All @@ -1774,7 +1774,7 @@ def rmod(
"""
new_self, new_other = self._prepare_inter_op(other)
return super(Series, new_self).rmod(
new_other, level=level, fill_value=None, axis=axis
new_other, level=level, fill_value=fill_value, axis=axis
)

def rpow(
Expand All @@ -1785,7 +1785,7 @@ def rpow(
"""
new_self, new_other = self._prepare_inter_op(other)
return super(Series, new_self).rpow(
new_other, level=level, fill_value=None, axis=axis
new_other, level=level, fill_value=fill_value, axis=axis
)

def rsub(
Expand All @@ -1796,7 +1796,7 @@ def rsub(
"""
new_self, new_other = self._prepare_inter_op(other)
return super(Series, new_self).rsub(
new_other, level=level, fill_value=None, axis=axis
new_other, level=level, fill_value=fill_value, axis=axis
)

def rtruediv(
Expand All @@ -1807,7 +1807,7 @@ def rtruediv(
"""
new_self, new_other = self._prepare_inter_op(other)
return super(Series, new_self).rtruediv(
new_other, level=level, fill_value=None, axis=axis
new_other, level=level, fill_value=fill_value, axis=axis
)

rdiv = rtruediv
Expand Down Expand Up @@ -1955,7 +1955,7 @@ def sub(
"""
new_self, new_other = self._prepare_inter_op(other)
return super(Series, new_self).sub(
new_other, level=level, fill_value=None, axis=axis
new_other, level=level, fill_value=fill_value, axis=axis
)

subtract = sub
Expand Down Expand Up @@ -2130,7 +2130,7 @@ def truediv(
"""
new_self, new_other = self._prepare_inter_op(other)
return super(Series, new_self).truediv(
new_other, level=level, fill_value=None, axis=axis
new_other, level=level, fill_value=fill_value, axis=axis
)

div = divide = truediv
Expand Down
33 changes: 33 additions & 0 deletions modin/tests/pandas/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -5094,3 +5094,36 @@ def test__reduce__():
.rename("league")
)
df_equals(result_md, result_pd)


@pytest.mark.parametrize(
    "op",
    [
        "add",
        "radd",
        "divmod",
        "eq",
        "floordiv",
        "ge",
        "gt",
        "le",
        "lt",
        "mod",
        "mul",
        "rmul",
        "ne",
        "pow",
        "rdivmod",
        "rfloordiv",
        "rmod",
        "rpow",
        "rsub",
        "rtruediv",
        "sub",
        "truediv",
    ],
)
def test_binary_with_fill_value_issue_7381(op):
    """Check that Series binary operations respect the ``fill_value`` flag."""
    # ``create_test_series`` presumably returns a (modin, pandas) pair built
    # from the same data -- the two halves are compared at the end.
    modin_lhs, pandas_lhs = create_test_series([0, 1, 2, 3])
    # The right-hand operand is deliberately shorter than the left-hand one.
    modin_rhs, pandas_rhs = create_test_series([0])
    fill = 2
    modin_result = getattr(modin_lhs, op)(modin_rhs, fill_value=fill)
    pandas_result = getattr(pandas_lhs, op)(pandas_rhs, fill_value=fill)
    df_equals(modin_result, pandas_result)

0 comments on commit abfd048

Please sign in to comment.