diff --git a/modin/core/execution/ray/implementations/pandas_on_ray/partitioning/virtual_partition.py b/modin/core/execution/ray/implementations/pandas_on_ray/partitioning/virtual_partition.py index 65f2704d054..71058d97921 100644 --- a/modin/core/execution/ray/implementations/pandas_on_ray/partitioning/virtual_partition.py +++ b/modin/core/execution/ray/implementations/pandas_on_ray/partitioning/virtual_partition.py @@ -26,13 +26,6 @@ from modin.utils import _inherit_docstrings -# If Ray has not been initialized yet by Modin, -# it will be initialized when calling `RayWrapper.put`. -_DEPLOY_AXIS_FUNC = RayWrapper.put(PandasDataframeAxisPartition.deploy_axis_func) -_DEPLOY_SPLIT_FUNC = RayWrapper.put(PandasDataframeAxisPartition.deploy_splitting_func) -_DRAIN = RayWrapper.put(PandasDataframeAxisPartition.drain) - - class PandasOnRayDataframeVirtualPartition(PandasDataframeAxisPartition): """ The class implements the interface in ``PandasDataframeAxisPartition``. @@ -60,6 +53,32 @@ class PandasOnRayDataframeVirtualPartition(PandasDataframeAxisPartition): instance_type = ray.ObjectRef axis = None + _DEPLOY_AXIS_FUNC = None + _DEPLOY_SPLIT_FUNC = None + _DRAIN_FUNC = None + + @classmethod + def _get_deploy_axis_func(cls): + if cls._DEPLOY_AXIS_FUNC is None: + cls._DEPLOY_AXIS_FUNC = RayWrapper.put( + PandasDataframeAxisPartition.deploy_axis_func + ) + return cls._DEPLOY_AXIS_FUNC + + @classmethod + def _get_deploy_split_func(cls): + if cls._DEPLOY_SPLIT_FUNC is None: + cls._DEPLOY_SPLIT_FUNC = RayWrapper.put( + PandasDataframeAxisPartition.deploy_splitting_func + ) + return cls._DEPLOY_SPLIT_FUNC + + @classmethod + def _get_drain_func(cls): + if cls._DRAIN_FUNC is None: + cls._DRAIN_FUNC = RayWrapper.put(PandasDataframeAxisPartition.drain) + return cls._DRAIN_FUNC + def __init__( self, list_of_partitions, @@ -169,7 +188,7 @@ def deploy_splitting_func( if extract_metadata else num_splits, ).remote( - _DEPLOY_SPLIT_FUNC, + cls._DEPLOY_SPLIT_FUNC, axis, func, f_args, @@ -230,7 +249,7 @@ def deploy_axis_func( * (1 + cls._PARTITIONS_METADATA_LEN), **({"max_retries": max_retries} if max_retries is not None else {}), ).remote( - _DEPLOY_AXIS_FUNC, + cls._DEPLOY_AXIS_FUNC, axis, func, f_args, @@ -486,7 +505,7 @@ def drain_call_queue(self, num_splits=None): _ = self.list_of_blocks return drained = super(PandasOnRayDataframeVirtualPartition, self).apply( - _DRAIN, num_splits=num_splits, call_queue=self.call_queue + self._DRAIN, num_splits=num_splits, call_queue=self.call_queue ) self._list_of_block_partitions = drained self.call_queue = []