diff --git a/modin/core/dataframe/pandas/dataframe/dataframe.py b/modin/core/dataframe/pandas/dataframe/dataframe.py
index 5be2408bc1a..3712d7fc1fd 100644
--- a/modin/core/dataframe/pandas/dataframe/dataframe.py
+++ b/modin/core/dataframe/pandas/dataframe/dataframe.py
@@ -4673,6 +4673,9 @@ def case_when(self, caselist):
         """
         Replace values where the conditions are True.
 
+        This is Series.case_when() implementation and, thus, it's designed to work
+        only with single-column DataFrames.
+
         Parameters
         ----------
         caselist : list of tuples
@@ -4681,38 +4684,30 @@ def case_when(self, caselist):
         -------
         PandasDataframe
         """
-        # For Dask the callables must be wrapped for each partition, otherwise
-        # the execution could fail with CancelledError.
-        single_wrap = Engine.get() != "Dask"
-        cls = type(self)
-        wrapper_put = self._partition_mgr_cls._execution_wrapper.put
-        if (
-            not single_wrap
-            or (remote_fn := getattr(cls, "_CASE_WHEN_FN", None)) is None
-        ):
+        # The import is here to avoid an incorrect module initialization when running tests.
+        # This module is loaded before `pytest_configure()` is called. If `pytest_configure()`
+        # changes the engine, the `remote_function` decorator will not be valid.
+        from modin.core.execution.utils import remote_function
 
-            def case_when(df, name, caselist):  # pragma: no cover
-                caselist = [
-                    tuple(
-                        (
-                            data.squeeze(axis=1)
-                            if isinstance(data, pandas.DataFrame)
-                            else data
-                        )
-                        for data in case_tuple
+        @remote_function
+        def remote_fn(df, name, caselist):  # pragma: no cover
+            caselist = [
+                tuple(
+                    (
+                        data.squeeze(axis=1)
+                        if isinstance(data, pandas.DataFrame)
+                        else data
                     )
-                    for case_tuple in caselist
-                ]
-                return pandas.DataFrame({name: df.squeeze(axis=1).case_when(caselist)})
-
-            if single_wrap:
-                cls._CASE_WHEN_FN = remote_fn = wrapper_put(case_when)
-            else:
-                remote_fn = case_when
+                    for data in case_tuple
+                )
+                for case_tuple in caselist
+            ]
+            return pandas.DataFrame({name: df.squeeze(axis=1).case_when(caselist)})
 
-        name = self.columns[0]
-        use_map = single_wrap
+        cls = type(self)
+        use_map = True
         is_trivial_idx = None
+        name = self.columns[0]
         # Lists of modin frames: first for conditions, second for replacements
         modin_lists = [[], []]
         # Fill values for conditions and replacements respectively
@@ -4726,8 +4721,7 @@ def case_when(df, name, caselist):  # pragma: no cover
                 if isinstance(data, cls):
                     modin_list.append(data)
                 elif callable(data):
-                    if single_wrap:
-                        data = wrapper_put(data)
+                    data = remote_function(data)
                 elif isinstance(data, pandas.Series):
                     use_map = False
                     if is_trivial_idx is None:
@@ -4739,7 +4733,8 @@ def case_when(df, name, caselist):  # pragma: no cover
                         diff = length - len(data)
                         if diff > 0:
                             data = pandas.concat(
-                                [data, pandas.Series([fill_value] * diff)]
+                                [data, pandas.Series([fill_value] * diff)],
+                                ignore_index=True,
                             )
                     else:
                         data = data.reindex(self_idx, fill_value=fill_value)
@@ -4802,9 +4797,6 @@ def map_data(
                 return data._partitions[part_idx][0]._data
             if isinstance(data, pandas.Series):
                 return data[data_offset : data_offset + part_len]
-            # As mentioned above, this is required for Dask
-            if not single_wrap and callable(data):
-                return wrapper_put(data)
             return (
                 data[data_offset : data_offset + part_len]
                 if is_list_like(data)
@@ -4812,13 +4804,13 @@ def map_data(
             )
 
         parts = [p[0] for p in self._partitions]
-        lengths = self._get_lengths(parts, Axis.ROW_WISE)
+        lengths = self.row_lengths
         new_parts = []
         data_offset = 0
 
         # Split the data and apply the remote function to each partition
         # with the corresponding chunk of data
-        for i, part, part_len in zip(range(0, len(parts)), parts, lengths):
+        for i, part, part_len in zip(range(len(parts)), parts, lengths):
             cases = [
                 tuple(
                     map_data(i, part_len, data, data_offset, fill_value)
@@ -4828,7 +4820,7 @@ def map_data(
             ]
             new_parts.append(
                 part.add_to_apply_calls(
-                    remote_fn if single_wrap else wrapper_put(remote_fn),
+                    remote_fn,
                     name,
                     cases,
                     length=part_len,
diff --git a/modin/core/storage_formats/base/query_compiler.py b/modin/core/storage_formats/base/query_compiler.py
index aaa0a4a6a31..116e0bd9502 100644
--- a/modin/core/storage_formats/base/query_compiler.py
+++ b/modin/core/storage_formats/base/query_compiler.py
@@ -6716,7 +6716,7 @@ def case_when(self, caselist):  # noqa: PR01, RT01, D200
         Replace values where the conditions are True.
         """
         # A workaround for https://github.com/modin-project/modin/issues/7041
-        qc_type = BaseQueryCompiler
+        qc_type = type(self)
         caselist = [
             tuple(
                 data.to_pandas().squeeze(axis=1) if isinstance(data, qc_type) else data
diff --git a/modin/core/storage_formats/pandas/query_compiler.py b/modin/core/storage_formats/pandas/query_compiler.py
index 462473b239a..95ff6a33522 100644
--- a/modin/core/storage_formats/pandas/query_compiler.py
+++ b/modin/core/storage_formats/pandas/query_compiler.py
@@ -4502,7 +4502,7 @@ def compare(self, other, **kwargs):
         )
 
     def case_when(self, caselist):
-        qc_type = BaseQueryCompiler
+        qc_type = type(self)
         caselist = [
             tuple(
                 data._modin_frame if isinstance(data, qc_type) else data
diff --git a/modin/pandas/test/test_series.py b/modin/pandas/test/test_series.py
index 0686d580873..a78cb5706b4 100644
--- a/modin/pandas/test/test_series.py
+++ b/modin/pandas/test/test_series.py
@@ -28,7 +28,7 @@
 from pandas.errors import SpecificationError
 
 import modin.pandas as pd
-from modin.config import NPartitions, StorageFormat
+from modin.config import Engine, NPartitions, StorageFormat
 from modin.pandas.io import to_pandas
 from modin.pandas.testing import assert_series_equal
 from modin.test.test_utils import warns_that_defaulting_to_pandas
@@ -4723,13 +4723,21 @@ def permutations(values):
     "caselist",
     _case_when_caselists(),
 )
+@pytest.mark.skipif(
+    Engine.get() == "Dask",
+    reason="https://github.com/modin-project/modin/issues/7148",
+)
 def test_case_when(base, caselist):
     pandas_result = base.case_when(caselist)
     modin_bases = [pd.Series(base)]
 
     # 'base' and serieses from 'caselist' must have equal lengths, however in this test we want
-    # to verify that 'case_when' works correctly even if partitioning of 'base' and 'caselist' isn't equal
-    if StorageFormat.get() != "Hdk":  # HDK always uses a single partition.
+    # to verify that 'case_when' works correctly even if partitioning of 'base' and 'caselist' isn't equal.
+    # HDK and BaseOnPython always use a single partition, thus skipping this test for them.
+    if (
+        StorageFormat.get() != "Hdk"
+        and f"{StorageFormat.get()}On{Engine.get()}" != "BaseOnPython"
+    ):
         modin_base_repart = construct_modin_df_by_scheme(
             base.to_frame(),
             partitioning_scheme={"row_lengths": [14, 14, 12], "column_widths": [1]},
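
For context, a minimal usage sketch of the user-facing Series.case_when() API that this patch exercises (requires pandas 2.2+). The series values, variable names, and caselist entries below are illustrative assumptions, not taken from the PR or its tests.

    import modin.pandas as pd

    # Illustrative series; Series.case_when() replaces values wherever a condition holds.
    s = pd.Series([6, 7, 8, 9], name="numbers")
    result = s.case_when(
        caselist=[
            (s > 8, 0),             # condition as a boolean Series, scalar replacement
            (lambda x: x < 7, -1),  # condition as a callable evaluated against 's'
        ]
    )
    print(result)

Both condition forms shown here (a Modin Series and a callable) go through the code paths this diff changes: Modin query compilers are unwrapped via the type(self) check, and callables are wrapped with remote_function before being shipped to partitions.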