Skip to content

Commit

Permalink
Apply changes from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreyPavlenko committed Apr 4, 2024
1 parent 97bad86 commit 7910d48
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 42 deletions.
66 changes: 29 additions & 37 deletions modin/core/dataframe/pandas/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4673,6 +4673,9 @@ def case_when(self, caselist):
"""
Replace values where the conditions are True.
This is the Series.case_when() implementation and, thus, it is designed to work
only with single-column DataFrames.
Parameters
----------
caselist : list of tuples
Expand All @@ -4681,38 +4684,30 @@ def case_when(self, caselist):
-------
PandasDataframe
"""
# For Dask the callables must be wrapped for each partition, otherwise
# the execution could fail with CancelledError.
single_wrap = Engine.get() != "Dask"
cls = type(self)
wrapper_put = self._partition_mgr_cls._execution_wrapper.put
if (
not single_wrap
or (remote_fn := getattr(cls, "_CASE_WHEN_FN", None)) is None
):
# The import is here to avoid an incorrect module initialization when running tests.
# This module is loaded before `pytest_configure()` is called. If `pytest_configure()`
# changes the engine, the `remote_function` decorator will not be valid.
from modin.core.execution.utils import remote_function

def case_when(df, name, caselist): # pragma: no cover
caselist = [
tuple(
(
data.squeeze(axis=1)
if isinstance(data, pandas.DataFrame)
else data
)
for data in case_tuple
@remote_function
def remote_fn(df, name, caselist): # pragma: no cover
caselist = [
tuple(
(
data.squeeze(axis=1)
if isinstance(data, pandas.DataFrame)
else data
)
for case_tuple in caselist
]
return pandas.DataFrame({name: df.squeeze(axis=1).case_when(caselist)})

if single_wrap:
cls._CASE_WHEN_FN = remote_fn = wrapper_put(case_when)
else:
remote_fn = case_when
for data in case_tuple
)
for case_tuple in caselist
]
return pandas.DataFrame({name: df.squeeze(axis=1).case_when(caselist)})

name = self.columns[0]
use_map = single_wrap
cls = type(self)
use_map = True
is_trivial_idx = None
name = self.columns[0]
# Lists of modin frames: first for conditions, second for replacements
modin_lists = [[], []]
# Fill values for conditions and replacements respectively
Expand All @@ -4726,8 +4721,7 @@ def case_when(df, name, caselist): # pragma: no cover
if isinstance(data, cls):
modin_list.append(data)
elif callable(data):
if single_wrap:
data = wrapper_put(data)
data = remote_function(data)
elif isinstance(data, pandas.Series):
use_map = False
if is_trivial_idx is None:
Expand All @@ -4739,7 +4733,8 @@ def case_when(df, name, caselist): # pragma: no cover
diff = length - len(data)
if diff > 0:
data = pandas.concat(
[data, pandas.Series([fill_value] * diff)]
[data, pandas.Series([fill_value] * diff)],
ignore_index=True,
)
else:
data = data.reindex(self_idx, fill_value=fill_value)

Check warning on line 4740 in modin/core/dataframe/pandas/dataframe/dataframe.py

View check run for this annotation

Codecov / codecov/patch

modin/core/dataframe/pandas/dataframe/dataframe.py#L4740

Added line #L4740 was not covered by tests
Expand Down Expand Up @@ -4802,23 +4797,20 @@ def map_data(
return data._partitions[part_idx][0]._data
if isinstance(data, pandas.Series):
return data[data_offset : data_offset + part_len]
# As mentioned above, this is required for Dask
if not single_wrap and callable(data):
return wrapper_put(data)
return (
data[data_offset : data_offset + part_len]
if is_list_like(data)
else data
)

parts = [p[0] for p in self._partitions]
lengths = self._get_lengths(parts, Axis.ROW_WISE)
lengths = self.row_lengths
new_parts = []
data_offset = 0

# Split the data and apply the remote function to each partition
# with the corresponding chunk of data
for i, part, part_len in zip(range(0, len(parts)), parts, lengths):
for i, part, part_len in zip(range(len(parts)), parts, lengths):
cases = [
tuple(
map_data(i, part_len, data, data_offset, fill_value)
Expand All @@ -4828,7 +4820,7 @@ def map_data(
]
new_parts.append(
part.add_to_apply_calls(
remote_fn if single_wrap else wrapper_put(remote_fn),
remote_fn,
name,
cases,
length=part_len,
Expand Down
2 changes: 1 addition & 1 deletion modin/core/storage_formats/base/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6716,7 +6716,7 @@ def case_when(self, caselist): # noqa: PR01, RT01, D200
Replace values where the conditions are True.
"""
# A workaround for https://github.com/modin-project/modin/issues/7041
qc_type = BaseQueryCompiler
qc_type = type(self)
caselist = [
tuple(
data.to_pandas().squeeze(axis=1) if isinstance(data, qc_type) else data
Expand Down
2 changes: 1 addition & 1 deletion modin/core/storage_formats/pandas/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4502,7 +4502,7 @@ def compare(self, other, **kwargs):
)

def case_when(self, caselist):
qc_type = BaseQueryCompiler
qc_type = type(self)
caselist = [
tuple(
data._modin_frame if isinstance(data, qc_type) else data
Expand Down
14 changes: 11 additions & 3 deletions modin/pandas/test/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from pandas.errors import SpecificationError

import modin.pandas as pd
from modin.config import NPartitions, StorageFormat
from modin.config import Engine, NPartitions, StorageFormat
from modin.pandas.io import to_pandas
from modin.pandas.testing import assert_series_equal
from modin.test.test_utils import warns_that_defaulting_to_pandas
Expand Down Expand Up @@ -4723,13 +4723,21 @@ def permutations(values):
"caselist",
_case_when_caselists(),
)
@pytest.mark.skipif(
Engine.get() == "Dask",
reason="https://github.com/modin-project/modin/issues/7148",
)
def test_case_when(base, caselist):
pandas_result = base.case_when(caselist)
modin_bases = [pd.Series(base)]

# 'base' and the series from 'caselist' must have equal lengths; however, in this test we want
# to verify that 'case_when' works correctly even if partitioning of 'base' and 'caselist' isn't equal
if StorageFormat.get() != "Hdk": # HDK always uses a single partition.
# to verify that 'case_when' works correctly even if partitioning of 'base' and 'caselist' isn't equal.
# HDK and BaseOnPython always use a single partition, thus skipping this test for them.
if (
StorageFormat.get() != "Hdk"
and f"{StorageFormat.get()}On{Engine.get()}" != "BaseOnPython"
):
modin_base_repart = construct_modin_df_by_scheme(
base.to_frame(),
partitioning_scheme={"row_lengths": [14, 14, 12], "column_widths": [1]},
Expand Down

0 comments on commit 7910d48

Please sign in to comment.