Enable torch elastic training (torchrun) #1603

Merged
33 commits merged on May 3, 2023

Changes from 25 commits

Commits (33)
9527795
Add torch elastic task type
Apr 5, 2023
c0616a8
updated
kumare3 Apr 7, 2023
c9edd3d
updated
kumare3 Apr 8, 2023
153e0d6
Don't pass min_nodes, max_nodes but nnodes only
Apr 8, 2023
264b6c6
Add docstrings
Apr 8, 2023
1a68e44
Cleanup test
Apr 8, 2023
4e73fb4
Replace wrong occurrences of mpi in torch-elastic plugin
Apr 8, 2023
da17096
Extend kfpytorch plugin README
Apr 10, 2023
2073b0a
Move elastic task into existing kf-pytorch plugin
Apr 10, 2023
40f8f46
Add Elastic config to plugin's __init__
Apr 10, 2023
90c966e
Set elastic config in pytorchjob proto
Apr 10, 2023
d9567b5
removed unnecessary model files and simplified codebase
kumare3 Apr 11, 2023
fd2dbaf
Configure rdzv endpoint
Apr 11, 2023
e6a74ba
Configure worker count also for elastic training
Apr 11, 2023
007ab27
Fix exception scope and handle non-rank-0 outputs
fg91 Apr 16, 2023
a5b792f
Remove todo item about handling exception scope, now done
fg91 Apr 16, 2023
252e9d9
Add note about c10d backend
fg91 Apr 16, 2023
daa387f
Let user set min and max replicas explicitly and not via nnodes
fg91 Apr 22, 2023
5664aae
Catch torch import error and configure for flytekitplugins-kfpytorch…
fg91 Apr 22, 2023
d45a23e
Add more tests
fg91 Apr 22, 2023
2864d3b
Update plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
fg91 Apr 22, 2023
b991e90
Lint
fg91 Apr 23, 2023
f0d8f6e
Explicitly add flyteidl version to plugin
kumare3 Apr 24, 2023
16bc6e5
changing import
kumare3 Apr 25, 2023
77fceb0
removed dynamic execute
kumare3 Apr 25, 2023
c96fa30
Revert "removed dynamic execute"
kumare3 Apr 26, 2023
a31fb47
updated ignoreoutputs
kumare3 Apr 26, 2023
da8a94c
requirements rebuilt
kumare3 Apr 26, 2023
28ade30
Revert back to nnodes instead of min_replicas, max_replicas, replicas
fg91 Apr 27, 2023
27fcb29
Don't handle dynamic execution scope
fg91 Apr 27, 2023
00aeb3c
Amend test to new nnodes api
fg91 Apr 27, 2023
312a436
Merge branch 'master' into fabio/feat/torch-elastic-plugin-fix
kumare3 May 3, 2023
a68f8c4
updated types
kumare3 May 3, 2023
3 changes: 3 additions & 0 deletions plugins/flytekit-kf-pytorch/README.md
@@ -2,6 +2,9 @@

This plugin uses the Kubeflow PyTorch Operator and provides an extremely simplified interface for executing distributed training using various PyTorch backends.

This plugin can execute torch elastic training, which is equivalent to running `torchrun`. Elastic training can be executed
in a single pod (without requiring the PyTorch operator; see below) as well as in a distributed multi-node manner.
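For illustration, a minimal usage sketch based on the `Elastic` config added in this PR (the task body and the specific argument values are placeholders, not taken from the diff):

```python
from flytekit import task, workflow
from flytekitplugins.kfpytorch import Elastic


@task(task_config=Elastic(replicas=2, nproc_per_node=2))
def train() -> None:
    # With replicas=1 this runs torchrun-style training inside a single pod;
    # with replicas > 1 it is submitted as a PyTorchJob with an elastic policy.
    ...


@workflow
def wf() -> None:
    train()
```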

To install the plugin, run the following command:

```bash
2 changes: 1 addition & 1 deletion plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/__init__.py
@@ -10,4 +10,4 @@
PyTorch
"""

from .task import PyTorch
from .task import Elastic, PyTorch
23 changes: 0 additions & 23 deletions plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/models.py

This file was deleted.

196 changes: 190 additions & 6 deletions plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
@@ -2,16 +2,18 @@
This Plugin adds the capability of running distributed pytorch training to Flyte using backend plugins, natively on
Kubernetes. It leverages `Pytorch Job <https://github.com/kubeflow/pytorch-operator>`_ Plugin from kubeflow.
"""
import os
from dataclasses import dataclass
from typing import Any, Callable, Dict
from typing import Any, Callable, Dict, Optional, Union

import cloudpickle
from flyteidl.plugins.pytorch_pb2 import DistributedPyTorchTrainingTask, ElasticConfig
from google.protobuf.json_format import MessageToDict

import flytekit
from flytekit import PythonFunctionTask
from flytekit.configuration import SerializationSettings
from flytekit.extend import TaskPlugins

from .models import PyTorchJob
from flytekit.extend import IgnoreOutputs, TaskPlugins


@dataclass
@@ -29,6 +31,36 @@ class PyTorch(object):
num_workers: int


@dataclass
class Elastic(object):
"""
Configuration for `torch elastic training <https://pytorch.org/docs/stable/elastic/run.html>`_.

Use this to run single- or multi-node distributed pytorch elastic training on k8s.

Single-node elastic training is executed in a k8s pod when `replicas` is set to 1.
Otherwise, multi-node training is executed using a `Pytorch Job <https://github.com/kubeflow/training-operator>`_.

Args:
replicas (int): Number of nodes.
min_replicas (int): Lower limit for the number of replicas to which the training job can scale down.
max_replicas (int): Upper limit for the number of replicas to which the training job can scale up.
Cannot be smaller than min_replicas.
nproc_per_node (Union[int, str]): Number of workers per node. Supported values are [auto, cpu, gpu, int].
start_method (str): Multiprocessing start method to use when creating workers.
monitor_interval (int): Interval, in seconds, to monitor the state of workers.
max_restarts (int): Maximum number of worker group restarts before failing.
"""

replicas: int = 1
kumare3 (Contributor):
can we just get rid of replicas?
and just make num_workers = max_replicas?
is that not right?
is replicas supposed to be desired_replicas?

fg91 (Member Author), Apr 25, 2023:
> can we just get rid of replicas?
> and just make num_workers = max_replicas?

Yes, we can get this behaviour by reverting daa387f.

I brought this up here because I was very unsure whether we should expose pytorch's nnodes to the user or Kubeflow's minReplicas, maxReplicas, replicas.

> is replicas supposed to be desired_replicas?

Kubeflow unfortunately provides zero docs about this. This is how they set defaults for minReplicas and maxReplicas.

They have two elastic examples. In the first one, they specify:

```yaml
kind: PyTorchJob
  ...
spec:
  elasticPolicy:
    minReplicas: 1
    maxReplicas: 2
    ...
  pytorchReplicaSpecs:
    Worker:
      replicas: 2
```

This is basically what you proposed and what we had before daa387f:

> and just make num_workers = max_replicas?

In the second example, however, maxReplicas > replicas:

```yaml
kind: PyTorchJob
  ...
spec:
  elasticPolicy:
    minReplicas: 1
    maxReplicas: 3
    ...
    metrics:
      - type: Resource
        resource:
          name: cpu
          target:
            type: Utilization
            averageUtilization: 80
  pytorchReplicaSpecs:
    Worker:
      replicas: 2
```

They use HPA to scale up nodes based on the configured metrics.

But I now realize that maxReplicas > replicas only makes sense if you use HPA; otherwise there is no way to get more than the initial replicas. Do you agree?

That being said, using metrics/HPA to dynamically scale up nodes appears, from everything I can see in the torch docs, to be a Kubeflow design decision and does not come from torchrun itself. The pytorch docs only talk about membership changes in general since they don't take care of provisioning nodes at all anyway:

> Node departure (scale-down): The agent is notified of the departure, all existing workers are stopped, a new WorkerGroup is formed, and all workers are started with a new RANK and WORLD_SIZE.
>
> Node arrival (scale-up): The new node is admitted to the job, all existing workers are stopped, a new WorkerGroup is formed, and all workers are started with a new RANK and WORLD_SIZE.


My personal opinion:

Being able to continue training if one node dies is a nice feature.

I personally wouldn't want to use e.g. CPU metrics + HPA to add a new distributed training worker: since HPA might scale nodes up and down all the time, the worker groups might end up getting restarted (on every membership change) far more often than necessary. And training code often needs a bit of time for setting things up, loading the model, ... before the actual training starts. So I'm not convinced of this feature...

To summarize:

If, after digging into this again, I now understand correctly, we would have to expose the metrics as well if we want to allow our users to set max_replicas > replicas. I suggest not doing this, at least in version 1, and reverting back to only exposing nnodes.

What do you think, @kumare3?

kumare3 (Contributor):
@fg91 I seem to agree, but I was thinking: instead of nnodes, keep min_replica and max_replica? Why not? Though I will go with whatever you decide.
I actually am not really sure if HPA works with gloo correctly. It's not really easy to change the membership and have it work correctly. Let's reserve this for later and get the basic version out. If you make the change I can +1 tonight and merge. I have been testing on a single node and it seems to work great.

fg91 (Member Author), Apr 26, 2023:
I have a slight preference for nnodes=1 / "1:2" over min_replicas, max_replicas for the following reasons:

  • It doesn't introduce a third naming scheme, in addition to torchrun's nnodes and Kubeflow's replicas, minReplicas, maxReplicas, that neither group of users knows.
  • I think most users will want a fixed number of workers and use torchrun simply because a project like Alpaca or a library like ignite assumes it, or because they want to do distributed training on a single node without the operator. It should be as simple as possible for those users who don't think about min/max, and I feel that nnodes=2 is simpler for them than min_replicas=2, max_replicas=2. (Plus, if they come from pytorch, they already know the syntax.)

If it's ok for you I will change back to nnodes and ping you again for review :)
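To illustrate the two options being discussed, a sketch (the `nnodes` signature is hypothetical at this point in the diff; it only lands with the later commit 28ade30, "Revert back to nnodes instead of min_replicas, max_replicas, replicas"):

```python
from flytekitplugins.kfpytorch import Elastic

# Kubeflow-style bounds, as exposed by the revision of the diff shown on this page:
elastic_cfg = Elastic(replicas=2, min_replicas=2, max_replicas=2, nproc_per_node=4)

# torchrun-style node range, roughly `torchrun --nnodes=1:2`; hypothetical here,
# adopted later in the PR:
# elastic_cfg = Elastic(nnodes="1:2", nproc_per_node=4)
```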

min_replicas: Optional[int] = None
max_replicas: Optional[int] = None
nproc_per_node: Union[int, str] = "auto"
start_method: str = "spawn"
monitor_interval: int = 5
max_restarts: int = 0


class PyTorchFunctionTask(PythonFunctionTask[PyTorch]):
"""
Plugin that submits a PyTorchJob (see https://github.com/kubeflow/pytorch-operator)
@@ -46,9 +78,161 @@ def __init__(self, task_config: PyTorch, task_function: Callable, **kwargs):
)

def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
job = PyTorchJob(workers_count=self.task_config.num_workers)
return MessageToDict(job.to_flyte_idl())
job = DistributedPyTorchTrainingTask(workers=self.task_config.num_workers)
return MessageToDict(job)


# Register the Pytorch Plugin into the flytekit core plugin system
TaskPlugins.register_pythontask_plugin(PyTorch, PyTorchFunctionTask)


def spawn_helper(fn: bytes, kwargs) -> Any:
"""Help to spawn worker processes.

The purpose of this function is to 1) be pickleable so that it can be used with
the multiprocessing start method `spawn` and 2) to call a cloudpickle-serialized
function passed to it. This function itself doesn't have to be pickleable. Without
such a helper task functions, which are not pickleable, couldn't be used with the
start method `spawn`.

Args:
fn (bytes): Cloudpickle-serialized target function to be executed in the worker process.

Returns:
The return value of the received target function.
"""
fn = cloudpickle.loads(fn)
return_val = fn(**kwargs)
return return_val


class PytorchElasticFunctionTask(PythonFunctionTask[Elastic]):
"""
Plugin for distributed training with torch elastic/torchrun (see
https://pytorch.org/docs/stable/elastic/run.html).
"""

_ELASTIC_TASK_TYPE = "pytorch"
_ELASTIC_TASK_TYPE_STANDALONE = "python-task"

def __init__(self, task_config: Elastic, task_function: Callable, **kwargs):
task_type = self._ELASTIC_TASK_TYPE_STANDALONE if task_config.replicas == 1 else self._ELASTIC_TASK_TYPE

super(PytorchElasticFunctionTask, self).__init__(
task_config=task_config,
task_type=task_type,
task_function=task_function,
**kwargs,
)
self.min_replicas = self.task_config.min_replicas or self.task_config.replicas
self.max_replicas = self.task_config.max_replicas or self.task_config.replicas

if not (self.min_replicas <= self.task_config.replicas <= self.max_replicas):
raise ValueError("Replica config violates `min_replicas <= replicas <= max_replicas`.")

"""
c10d is the backend recommended by torch elastic.
https://pytorch.org/docs/stable/elastic/run.html#note-on-rendezvous-backend

For c10d, no backend server has to be deployed.
https://pytorch.org/docs/stable/elastic/run.html#deployment
Instead, the workers will use the master's address as the rendezvous point.
"""
self.rdzv_backend = "c10d"

def execute(self, **kwargs) -> Any:
"""
This helper method will be invoked to execute the task.


Returns:
The result of rank zero.
"""
try:
from torch.distributed import run
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
except ImportError:
raise ImportError("PyTorch is not installed. Please install `flytekitplugins-kfpytorch['elastic']`.")

if isinstance(self.task_config.nproc_per_node, str):
nproc = run.determine_local_world_size(self.task_config.nproc_per_node)
else:
nproc = self.task_config.nproc_per_node

config = LaunchConfig(
run_id=flytekit.current_context().execution_id.name,
min_nodes=self.min_replicas,
max_nodes=self.max_replicas,
nproc_per_node=nproc,
rdzv_backend=self.rdzv_backend, # rdzv settings
rdzv_endpoint=os.environ.get("PET_RDZV_ENDPOINT", "localhost:0"),
max_restarts=self.task_config.max_restarts,
monitor_interval=self.task_config.monitor_interval,
start_method=self.task_config.start_method,
)

if self.task_config.start_method == "spawn":
"""
We use cloudpickle to serialize the non-pickleable task function.
The torch elastic launcher then launches the spawn_helper function (which is pickleable)
instead of the task function. This helper function, in the child-process, then deserializes
the task function, again with cloudpickle, and executes it.
"""
launcher_target_func = spawn_helper

dumped_target_function = cloudpickle.dumps(self._task_function)
launcher_args = (dumped_target_function, kwargs)
elif self.task_config.start_method == "fork":
"""
The torch elastic launcher doesn't support passing kwargs to the target function,
only args. Flyte only works with kwargs. Thus, we create a closure which already has
the task kwargs bound. We tell the torch elastic launcher to start this function in
the child processes.
"""

def fn_partial():
"""Closure of the task function with kwargs already bound."""
return self._task_function(**kwargs)

launcher_target_func = fn_partial
launcher_args = ()

else:
raise Exception("Bad start method")

out = elastic_launch(
config=config,
entrypoint=launcher_target_func,
)(*launcher_args)

# `out` is a dictionary of rank (not local rank) -> result
# Rank 0 returns the result of the task function
if 0 in out:
return out[0]
else:
raise IgnoreOutputs()

def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]]:
if self.task_config.replicas == 1:
"""
Single-node torch elastic training is executed in a normal k8s pod so that this
works without the Kubeflow training operator.
"""
return super().get_custom(settings)
else:
elastic_config = ElasticConfig(
rdzv_backend=self.rdzv_backend,
min_replicas=self.min_replicas,
max_replicas=self.max_replicas,
nproc_per_node=self.task_config.nproc_per_node,
max_restarts=self.task_config.max_restarts,
)
job = DistributedPyTorchTrainingTask(
workers=self.task_config.replicas,
elastic_config=elastic_config,
)
return MessageToDict(job)


# Register the PytorchElastic Plugin into the flytekit core plugin system
TaskPlugins.register_pythontask_plugin(Elastic, PytorchElasticFunctionTask)
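As background for the `spawn_helper` pattern above, here is a minimal standalone sketch (not part of the PR) of why the task function is round-tripped through cloudpickle when the `spawn` start method is used: the child process can unpickle the module-level helper with the standard pickle machinery, while the actual target function, which may be a closure or otherwise not picklable, travels as cloudpickle bytes.

```python
import multiprocessing as mp

import cloudpickle


def spawn_helper(fn_bytes: bytes, kwargs: dict):
    """Runs in the child process: deserialize the target function and call it."""
    fn = cloudpickle.loads(fn_bytes)
    return fn(**kwargs)


def main():
    offset = 10
    # A closure like this cannot be pickled by the stdlib pickle used by `spawn`,
    # but cloudpickle serializes it by value.
    target = lambda x: x + offset

    ctx = mp.get_context("spawn")
    with ctx.Pool(1) as pool:
        result = pool.apply(spawn_helper, (cloudpickle.dumps(target), {"x": 5}))
    print(result)  # 15


if __name__ == "__main__":
    main()
```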
5 changes: 4 additions & 1 deletion plugins/flytekit-kf-pytorch/setup.py
@@ -4,7 +4,7 @@

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = ["flytekit>=1.3.0b2,<2.0.0"]
plugin_requires = ["cloudpickle", "flytekit>=1.3.0,<2.0.0", "flyteidl>=1.3.19"]

__version__ = "0.0.0+develop"

@@ -17,6 +17,9 @@
namespace_packages=["flytekitplugins"],
packages=[f"flytekitplugins.{PLUGIN_NAME}"],
install_requires=plugin_requires,
extras_require={
"elastic": ["torch>=1.9.0"],
},
license="apache2",
python_requires=">=3.8",
classifiers=[
73 changes: 73 additions & 0 deletions plugins/flytekit-kf-pytorch/tests/test_elastic_task.py
@@ -0,0 +1,73 @@
import os
from dataclasses import dataclass

import pytest
import torch
import torch.distributed as dist
from dataclasses_json import dataclass_json
from flytekitplugins.kfpytorch.task import Elastic

from flytekit import task, workflow


@dataclass_json
@dataclass
class Config:
lr: float = 1e-5
bs: int = 64
name: str = "foo"


def dist_communicate() -> int:
"""Communicate between distributed workers."""
rank = torch.distributed.get_rank()
world_size = dist.get_world_size()
tensor = torch.tensor([5], dtype=torch.int64) + 2 * rank + world_size
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)

return tensor.item()


def train(config: Config) -> tuple[str, Config, torch.nn.Module, int]:
"""Mock training a model using torch-elastic for test purposes."""
dist.init_process_group(backend="gloo")

local_rank = os.environ["LOCAL_RANK"]

out_model = torch.nn.Linear(1000, int(local_rank) + 1)
config.name = "elastic-test"

distributed_result = dist_communicate()

return f"result from local rank {local_rank}", config, out_model, distributed_result


@pytest.mark.parametrize("start_method", ["spawn", "fork"])
def test_end_to_end(start_method: str) -> None:
"""Test that the workflow with elastic task runs end to end."""
world_size = 2

train_task = task(train, task_config=Elastic(replicas=1, nproc_per_node=world_size, start_method=start_method))

@workflow
def wf(config: Config = Config()) -> tuple[str, Config, torch.nn.Module, int]:
return train_task(config=config)

r, cfg, m, distributed_result = wf()
assert "result from local rank 0" in r
assert cfg.name == "elastic-test"
assert m.in_features == 1000
assert m.out_features == 1
"""
The distributed result is calculated by the workers of the elastic train
task by performing a `dist.all_reduce` operation. The correct result can
only be obtained if the distributed process group is initialized correctly.
"""
assert distributed_result == sum([5 + 2 * rank + world_size for rank in range(world_size)])


def test_bad_replica_config() -> None:
"""Test that bad replica config is caught."""

with pytest.raises(ValueError):
task(train, task_config=Elastic(replicas=1, min_replicas=2))
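For reference, a quick standalone check (not part of the PR) of the expected all-reduce value asserted in `test_end_to_end`, with world_size = 2:

```python
world_size = 2
# Each rank contributes 5 + 2 * rank + world_size; dist.all_reduce with SUM adds them up.
expected = sum(5 + 2 * rank + world_size for rank in range(world_size))
print(expected)  # (5 + 0 + 2) + (5 + 2 + 2) = 16
```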