Revert to nnodes instead of min_replicas, max_replicas, replicas
fg91 committed Apr 27, 2023
1 parent da8a94c commit 5552413
Showing 1 changed file with 19 additions and 22 deletions.
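For context, `nnodes` follows the torchrun convention: a plain integer fixes the node count, while a string of the form "<minimum_nodes>:<maximum_nodes>" gives an elastic range. Below is a minimal sketch of that parsing, assuming the same semantics as torch.distributed.run.parse_min_max_nnodes (the helper name is illustrative, not part of the plugin); a usage sketch of the reverted Elastic config follows the diff.

def parse_nnodes(nnodes):
    """Illustrative only: 2 or "2" -> (2, 2); "1:4" -> (1, 4)."""
    parts = str(nnodes).split(":")
    if len(parts) == 1:
        return int(parts[0]), int(parts[0])
    if len(parts) == 2:
        return int(parts[0]), int(parts[1])
    raise ValueError(f"nnodes must be an int or 'min:max', got {nnodes!r}")

assert parse_nnodes(1) == (1, 1)
assert parse_nnodes("2:4") == (2, 4)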
plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py (19 additions, 22 deletions)
@@ -13,7 +13,9 @@
import flytekit
from flytekit import PythonFunctionTask
from flytekit.configuration import SerializationSettings
-from flytekit.extend import TaskPlugins, IgnoreOutputs
+from flytekit.extend import IgnoreOutputs, TaskPlugins
+
+TORCH_IMPORT_ERROR_MESSAGE = "PyTorch is not installed. Please install `flytekitplugins-kfpytorch['elastic']`."


@dataclass
@@ -38,23 +40,18 @@ class Elastic(object):
Use this to run single- or multi-node distributed pytorch elastic training on k8s.
-Single-node elastic training is executed in a k8s pod when `replicas` is set to 1.
+Single-node elastic training is executed in a k8s pod when `nnodes` is set to 1.
Multi-node training is executed otherwise using a `Pytorch Job <https://github.com/kubeflow/training-operator>`_.
Args:
-replicas int: Number of nodes
-min_replicas int: Lower limit for the number of replicas to which the training job can scale down
-max_replicas int: Upper limit for the number of replicas to which the training job can scale up.
-Cannot be smaller than min_replicas.
+nnodes (Union[int, str]): Number of nodes, or the range of nodes in form <minimum_nodes>:<maximum_nodes>.
nproc_per_node (Union[int, str]): Number of workers per node. Supported values are [auto, cpu, gpu, int].
start_method (str): Multiprocessing start method to use when creating workers.
monitor_interval (int): Interval, in seconds, to monitor the state of workers.
max_restarts (int): Maximum number of worker group restarts before failing.
"""

-replicas: int = 1
-min_replicas: Optional[int] = None
-max_replicas: Optional[int] = None
+nnodes: Union[int, str] = 1
nproc_per_node: Union[int, str] = "auto"
start_method: str = "spawn"
monitor_interval: int = 5
@@ -116,19 +113,19 @@ class PytorchElasticFunctionTask(PythonFunctionTask[Elastic]):
_ELASTIC_TASK_TYPE_STANDALONE = "python-task"

def __init__(self, task_config: Elastic, task_function: Callable, **kwargs):
-task_type = self._ELASTIC_TASK_TYPE_STANDALONE if task_config.replicas == 1 else self._ELASTIC_TASK_TYPE
+task_type = self._ELASTIC_TASK_TYPE_STANDALONE if task_config.nnodes == 1 else self._ELASTIC_TASK_TYPE

super(PytorchElasticFunctionTask, self).__init__(
task_config=task_config,
task_type=task_type,
task_function=task_function,
**kwargs,
)
-self.min_replicas = self.task_config.min_replicas or self.task_config.replicas
-self.max_replicas = self.task_config.max_replicas or self.task_config.replicas
-
-if not (self.min_replicas <= self.task_config.replicas <= self.max_replicas):
-raise ValueError("Replica config violates `min_replicas <= replicas <= max_replicas`.")
+try:
+from torch.distributed import run
+except ImportError:
+raise ImportError(TORCH_IMPORT_ERROR_MESSAGE)
+self.min_nodes, self.max_nodes = run.parse_min_max_nnodes(str(self.task_config.nnodes))

"""
c10d is the backend recommended by torch elastic.
@@ -152,7 +149,7 @@ def _execute(self, **kwargs) -> Any:
from torch.distributed import run
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
except ImportError:
raise ImportError("PyTorch is not installed. Please install `flytekitplugins-kfpytorch['elastic']`.")
raise ImportError(TORCH_IMPORT_ERROR_MESSAGE)

if isinstance(self.task_config.nproc_per_node, str):
nproc = run.determine_local_world_size(self.task_config.nproc_per_node)
@@ -161,8 +158,8 @@ def _execute(self, **kwargs) -> Any:

config = LaunchConfig(
run_id=flytekit.current_context().execution_id.name,
-min_nodes=self.min_replicas,
-max_nodes=self.max_replicas,
+min_nodes=self.min_nodes,
+max_nodes=self.max_nodes,
nproc_per_node=nproc,
rdzv_backend=self.rdzv_backend, # rdzv settings
rdzv_endpoint=os.environ.get("PET_RDZV_ENDPOINT", "localhost:0"),
@@ -226,7 +223,7 @@ def execute(self, **kwargs) -> Any:
return self.dynamic_execute(self._execute, **kwargs)

def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]]:
-if self.task_config.replicas == 1:
+if self.task_config.nnodes == 1:
"""
Torch elastic distributed training is executed in a normal k8s pod so that this
works without the kubeflow train operator.
@@ -235,13 +232,13 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]]:
else:
elastic_config = ElasticConfig(
rdzv_backend=self.rdzv_backend,
-min_replicas=self.min_replicas,
-max_replicas=self.max_replicas,
+min_replicas=self.min_nodes,
+max_replicas=self.max_nodes,
nproc_per_node=self.task_config.nproc_per_node,
max_restarts=self.task_config.max_restarts,
)
job = DistributedPyTorchTrainingTask(
-workers=self.task_config.replicas,
+workers=self.max_nodes,
elastic_config=elastic_config,
)
return MessageToDict(job)
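For reference, a minimal usage sketch of the reverted interface. It assumes the plugin is installed as flytekitplugins-kfpytorch['elastic'] and that Elastic is importable from flytekitplugins.kfpytorch; the task bodies are placeholders.

from flytekit import task

from flytekitplugins.kfpytorch import Elastic

# nnodes=1: torch elastic runs inside a single k8s pod ("python-task"), no training operator needed.
@task(task_config=Elastic(nnodes=1, nproc_per_node=2))
def train_single_node() -> None:
    ...

# nnodes="2:4": serialized as a distributed PyTorch job that can scale between 2 and 4 nodes.
@task(task_config=Elastic(nnodes="2:4", nproc_per_node="auto"))
def train_multi_node() -> None:
    ...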
