Skip to content

Commit

Permalink
FIX-modin-project#6022: move '_DEPLOY_AXIS_FUNC', '_DEPLOY_SPLIT_FUNC…
Browse files Browse the repository at this point in the history
…', '_DRAIN' into Ray virtual partition

Signed-off-by: Anatoly Myachev <anatoly.myachev@intel.com>
  • Loading branch information
anmyachev committed Apr 19, 2023
1 parent bbc7697 commit 6bacdfd
Showing 1 changed file with 29 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,6 @@
from modin.utils import _inherit_docstrings


# If Ray has not been initialized yet by Modin,
# it will be initialized when calling `RayWrapper.put`.
_DEPLOY_AXIS_FUNC = RayWrapper.put(PandasDataframeAxisPartition.deploy_axis_func)
_DEPLOY_SPLIT_FUNC = RayWrapper.put(PandasDataframeAxisPartition.deploy_splitting_func)
_DRAIN = RayWrapper.put(PandasDataframeAxisPartition.drain)


class PandasOnRayDataframeVirtualPartition(PandasDataframeAxisPartition):
"""
The class implements the interface in ``PandasDataframeAxisPartition``.
Expand Down Expand Up @@ -60,6 +53,32 @@ class PandasOnRayDataframeVirtualPartition(PandasDataframeAxisPartition):
instance_type = ray.ObjectRef
axis = None

_DEPLOY_AXIS_FUNC = None
_DEPLOY_SPLIT_FUNC = None
_DRAIN_FUNC = None

@classmethod
def _get_deploy_axis_func(cls):
if cls._DEPLOY_AXIS_FUNC is None:
cls._DEPLOY_AXIS_FUNC = RayWrapper.put(
PandasDataframeAxisPartition.deploy_axis_func
)
return cls._DEPLOY_AXIS_FUNC

@classmethod
def _get_deploy_split_func(cls):
if cls._DEPLOY_SPLIT_FUNC is None:
cls._DEPLOY_SPLIT_FUNC = RayWrapper.put(
PandasDataframeAxisPartition.deploy_splitting_func
)
return cls._DEPLOY_SPLIT_FUNC

@classmethod
def _get_drain_func(cls):
if cls._DRAIN_FUNC is None:
cls._DRAIN_FUNC = RayWrapper.put(PandasDataframeAxisPartition.drain)
return cls._DRAIN_FUNC

def __init__(
self,
list_of_partitions,
Expand Down Expand Up @@ -169,7 +188,7 @@ def deploy_splitting_func(
if extract_metadata
else num_splits,
).remote(
_DEPLOY_SPLIT_FUNC,
cls._DEPLOY_SPLIT_FUNC,
axis,
func,
f_args,
Expand Down Expand Up @@ -230,7 +249,7 @@ def deploy_axis_func(
* (1 + cls._PARTITIONS_METADATA_LEN),
**({"max_retries": max_retries} if max_retries is not None else {}),
).remote(
_DEPLOY_AXIS_FUNC,
cls._DEPLOY_AXIS_FUNC,
axis,
func,
f_args,
Expand Down Expand Up @@ -486,7 +505,7 @@ def drain_call_queue(self, num_splits=None):
_ = self.list_of_blocks
return
drained = super(PandasOnRayDataframeVirtualPartition, self).apply(
_DRAIN, num_splits=num_splits, call_queue=self.call_queue
self._DRAIN, num_splits=num_splits, call_queue=self.call_queue
)
self._list_of_block_partitions = drained
self.call_queue = []
Expand Down

0 comments on commit 6bacdfd

Please sign in to comment.