Enable torch elastic training (torchrun) #1603

Merged · 33 commits · May 3, 2023
Commits
9527795
Add torch elastic task type
Apr 5, 2023
c0616a8
updated
kumare3 Apr 7, 2023
c9edd3d
updated
kumare3 Apr 8, 2023
153e0d6
Don't pass min_nodes, max_nodes but nnodes only
Apr 8, 2023
264b6c6
Add docstrings
Apr 8, 2023
1a68e44
Cleanup test
Apr 8, 2023
4e73fb4
Replace wrong occurences of mpi in torch-elastic plugin
Apr 8, 2023
da17096
Extend kfpytorch plugin README
Apr 10, 2023
2073b0a
Move elastic task into existing kf-pytorch plugin
Apr 10, 2023
40f8f46
Add Elastic config to plugin's __init__
Apr 10, 2023
90c966e
Set elastic config in pytorchjob proto
Apr 10, 2023
d9567b5
removed unnecessary model files and simplified codebase
kumare3 Apr 11, 2023
fd2dbaf
Configure rdzv endpoint
Apr 11, 2023
e6a74ba
Configure worker count also for elastic training
Apr 11, 2023
007ab27
Fix exception scope and handle non-rank-0 outputs
fg91 Apr 16, 2023
a5b792f
Remove todo item about handling exception scope, now done
fg91 Apr 16, 2023
252e9d9
Add note about c10d backend
fg91 Apr 16, 2023
daa387f
Let user set min and max replicas explicitly and not via nnodes
fg91 Apr 22, 2023
5664aae
Catch torch import error and configure for flytekitplugins-kfpytorch…
fg91 Apr 22, 2023
d45a23e
Add more tests
fg91 Apr 22, 2023
2864d3b
Update plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
fg91 Apr 22, 2023
b991e90
Lint
fg91 Apr 23, 2023
f0d8f6e
Explicitly add flyteidl version to plugin
kumare3 Apr 24, 2023
16bc6e5
changing import
kumare3 Apr 25, 2023
77fceb0
removed dynamic execute
kumare3 Apr 25, 2023
c96fa30
Revert "removed dynamic execute"
kumare3 Apr 26, 2023
a31fb47
updated ignoreoutputs
kumare3 Apr 26, 2023
da8a94c
requirements rebuilt
kumare3 Apr 26, 2023
28ade30
Revert back to nnodes instead of min_replicas, max_replicas, replicas
fg91 Apr 27, 2023
27fcb29
Don't handle dynamic execution scope
fg91 Apr 27, 2023
00aeb3c
Amend test to new nnodes api
fg91 Apr 27, 2023
312a436
Merge branch 'master' into fabio/feat/torch-elastic-plugin-fix
kumare3 May 3, 2023
a68f8c4
updated types
kumare3 May 3, 2023
3 changes: 3 additions & 0 deletions plugins/flytekit-kf-pytorch/README.md
@@ -2,6 +2,9 @@

This plugin uses the Kubeflow PyTorch Operator and provides an extremely simplified interface for executing distributed training using various PyTorch backends.

This plugin can execute torch elastic training, which is equivalent to running `torchrun`. Elastic training can be executed
in a single Pod (without requiring the PyTorch operator; see below) as well as in a distributed, multi-node manner.
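
As a rough sketch of how this could look (the task body and argument values below are placeholders), a task opts into elastic training via the `Elastic` task config:

```python
from flytekit import task
from flytekitplugins.kfpytorch import Elastic


@task(
    task_config=Elastic(
        nnodes=2,  # or a range such as "1:2"
        nproc_per_node=4,  # workers per node: "auto", "cpu", "gpu", or an int
    )
)
def train() -> float:
    ...
```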

To install the plugin, run the following command:

```bash
plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/__init__.py
@@ -10,4 +10,4 @@
PyTorch
"""

from .task import PyTorch
from .task import Elastic, PyTorch
23 changes: 0 additions & 23 deletions plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/models.py

This file was deleted.

201 changes: 196 additions & 5 deletions plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
@@ -2,16 +2,20 @@
This Plugin adds the capability of running distributed pytorch training to Flyte using backend plugins, natively on
Kubernetes. It leverages `Pytorch Job <https://github.com/kubeflow/pytorch-operator>`_ Plugin from kubeflow.
"""
import os
from dataclasses import dataclass
from typing import Any, Callable, Dict
from typing import Any, Callable, Dict, Optional, Union

import cloudpickle
from flyteidl.plugins.pytorch_pb2 import DistributedPyTorchTrainingTask, ElasticConfig
from google.protobuf.json_format import MessageToDict

import flytekit
from flytekit import PythonFunctionTask
from flytekit.configuration import SerializationSettings
from flytekit.extend import TaskPlugins
from flytekit.extend import IgnoreOutputs, TaskPlugins

from .models import PyTorchJob
TORCH_IMPORT_ERROR_MESSAGE = "PyTorch is not installed. Please install `flytekitplugins-kfpytorch['elastic']`."


@dataclass
@@ -29,6 +33,31 @@ class PyTorch(object):
num_workers: int


@dataclass
class Elastic(object):
"""
Configuration for `torch elastic training <https://pytorch.org/docs/stable/elastic/run.html>`_.

Use this to run single- or multi-node distributed pytorch elastic training on k8s.

Single-node elastic training is executed in a k8s pod when `nnodes` is set to 1.
Multi-node training is executed otherwise using a `Pytorch Job <https://github.com/kubeflow/training-operator>`_.

Args:
nnodes (Union[int, str]): Number of nodes, or a range of nodes in the form <minimum_nodes>:<maximum_nodes>.
nproc_per_node (Union[int, str]): Number of workers per node. Supported values are [auto, cpu, gpu, int].
start_method (str): Multiprocessing start method to use when creating workers.
monitor_interval (int): Interval, in seconds, to monitor the state of workers.
max_restarts (int): Maximum number of worker group restarts before failing.
"""

nnodes: Union[int, str] = 1
nproc_per_node: Union[int, str] = "auto"
start_method: str = "spawn"
monitor_interval: int = 5
max_restarts: int = 0
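
# A rough usage sketch (illustrative only, not part of the plugin code): `Elastic(nnodes=2)` requests a
# two-node elastic PyTorchJob, `Elastic(nnodes="1:2", nproc_per_node=8)` lets the node count scale between
# one and two nodes with eight worker processes per node, and `Elastic(nnodes=1)` runs entirely in a single
# pod without the training operator.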


class PyTorchFunctionTask(PythonFunctionTask[PyTorch]):
"""
Plugin that submits a PyTorchJob (see https://github.com/kubeflow/pytorch-operator)
@@ -46,9 +75,171 @@ def __init__(self, task_config: PyTorch, task_function: Callable, **kwargs):
)

def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
job = PyTorchJob(workers_count=self.task_config.num_workers)
return MessageToDict(job.to_flyte_idl())
job = DistributedPyTorchTrainingTask(workers=self.task_config.num_workers)
return MessageToDict(job)


# Register the Pytorch Plugin into the flytekit core plugin system
TaskPlugins.register_pythontask_plugin(PyTorch, PyTorchFunctionTask)


def spawn_helper(fn: bytes, kwargs) -> Any:
"""Help to spawn worker processes.

The purpose of this function is to 1) be pickleable so that it can be used with
the multiprocessing start method `spawn` and 2) to call a cloudpickle-serialized
function passed to it. This function itself doesn't have to be pickleable. Without
such a helper task functions, which are not pickleable, couldn't be used with the
start method `spawn`.

Args:
fn (bytes): Cloudpickle-serialized target function to be executed in the worker process.

Returns:
The return value of the received target function.
"""
fn = cloudpickle.loads(fn)
return_val = fn(**kwargs)
return return_val
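
# Illustrative round trip (hypothetical target function and kwargs):
#
#     payload = cloudpickle.dumps(my_task_function)  # serialize in the parent process
#     spawn_helper(payload, {"x": 1})                # deserialize and call my_task_function(x=1) in the worker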


class PytorchElasticFunctionTask(PythonFunctionTask[Elastic]):
"""
Plugin for distributed training with torch elastic/torchrun (see
https://pytorch.org/docs/stable/elastic/run.html).
"""

_ELASTIC_TASK_TYPE = "pytorch"
_ELASTIC_TASK_TYPE_STANDALONE = "python-task"

def __init__(self, task_config: Elastic, task_function: Callable, **kwargs):
task_type = self._ELASTIC_TASK_TYPE_STANDALONE if task_config.nnodes == 1 else self._ELASTIC_TASK_TYPE

super(PytorchElasticFunctionTask, self).__init__(
task_config=task_config,
task_type=task_type,
task_function=task_function,
**kwargs,
)
try:
from torch.distributed import run
except ImportError:
raise ImportError(TORCH_IMPORT_ERROR_MESSAGE)
self.min_nodes, self.max_nodes = run.parse_min_max_nnodes(str(self.task_config.nnodes))

"""
c10d is the backend recommended by torch elastic.
https://pytorch.org/docs/stable/elastic/run.html#note-on-rendezvous-backend

For c10d, no backend server has to be deployed.
https://pytorch.org/docs/stable/elastic/run.html#deployment
Instead, the workers will use the master's address as the rendezvous point.
"""
self.rdzv_backend = "c10d"

def _execute(self, **kwargs) -> Any:
"""
This helper method will be invoked to execute the task.


Returns:
The result of rank zero.
"""
try:
from torch.distributed import run
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
except ImportError:
raise ImportError(TORCH_IMPORT_ERROR_MESSAGE)

if isinstance(self.task_config.nproc_per_node, str):
nproc = run.determine_local_world_size(self.task_config.nproc_per_node)
else:
nproc = self.task_config.nproc_per_node

config = LaunchConfig(
run_id=flytekit.current_context().execution_id.name,
min_nodes=self.min_nodes,
max_nodes=self.max_nodes,
nproc_per_node=nproc,
rdzv_backend=self.rdzv_backend, # rdzv settings
rdzv_endpoint=os.environ.get("PET_RDZV_ENDPOINT", "localhost:0"),
max_restarts=self.task_config.max_restarts,
monitor_interval=self.task_config.monitor_interval,
start_method=self.task_config.start_method,
)

if self.task_config.start_method == "spawn":
"""
We use cloudpickle to serialize the non-pickleable task function.
The torch elastic launcher then launches the spawn_helper function (which is pickleable)
instead of the task function. This helper function, in the child-process, then deserializes
the task function, again with cloudpickle, and executes it.
"""
launcher_target_func = spawn_helper

dumped_target_function = cloudpickle.dumps(self._task_function)
launcher_args = (dumped_target_function, kwargs)
elif self.task_config.start_method == "fork":
"""
The torch elastic launcher doesn't support passing kwargs to the target function,
only args. Flyte only works with kwargs. Thus, we create a closure which already has
the task kwargs bound. We tell the torch elastic launcher to start this function in
the child processes.
"""

def fn_partial():
"""Closure of the task function with kwargs already bound."""
return self._task_function(**kwargs)

launcher_target_func = fn_partial
launcher_args = ()

else:
raise Exception("Bad start method")

out = elastic_launch(
config=config,
entrypoint=launcher_target_func,
)(*launcher_args)

# `out` is a dictionary of rank (not local rank) -> result
# Rank 0 returns the result of the task function
if 0 in out:
return out[0]
else:
raise IgnoreOutputs()

def execute(self, **kwargs) -> Any:
"""
This method will be invoked to execute the task.

Handles the exception scope for the `_execute` method.
"""
from flytekit.exceptions import scopes as exception_scopes

return exception_scopes.user_entry_point(self._execute)(**kwargs)

def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]]:
if self.task_config.nnodes == 1:
"""
Torch elastic distributed training is executed in a normal k8s pod so that this
works without the kubeflow train operator.
"""
return super().get_custom(settings)
else:
elastic_config = ElasticConfig(
rdzv_backend=self.rdzv_backend,
min_replicas=self.min_nodes,
max_replicas=self.max_nodes,
nproc_per_node=self.task_config.nproc_per_node,
max_restarts=self.task_config.max_restarts,
)
job = DistributedPyTorchTrainingTask(
workers=self.max_nodes,
elastic_config=elastic_config,
)
return MessageToDict(job)


# Register the PytorchElastic Plugin into the flytekit core plugin system
TaskPlugins.register_pythontask_plugin(Elastic, PytorchElasticFunctionTask)