fleet support elastic scale up/down #36684

Merged: 22 commits, Nov 11, 2021
7 changes: 5 additions & 2 deletions python/paddle/distributed/elastic.py
@@ -50,15 +50,18 @@ def close(self):
parser.add_argument(
"--elastic_server", type=str, help="etcd server host:port")
parser.add_argument("--job_id", type=str, help="job unique id")
parser.add_argument("--np", type=int, help="job pod/node number")
parser.add_argument(
"--np",
type=str,
help="job pod/node number, need to be 'MIN' or 'MIN:MAX' format")
parser.add_argument("action", type=str, help="action to take")

args = parser.parse_args()

server = args.elastic_server or os.getenv('PADDLE_ELASTIC_SERVER')
name = args.job_id or os.getenv('PADDLE_ELASTIC_JOB_ID')

np = args.np or int(os.getenv('PADDLE_ELASTIC_NP', 0))
np = int(args.np.split(":")[0]) or int(os.getenv('PADDLE_ELASTIC_NP', 0))

cmd = Command(server, name)

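For reference, the elastic command now parses the minimum pod count out of the new string form of --np. A minimal sketch of that behaviour (the helper name resolve_np is illustrative, not part of the diff; only the split-on-":" logic comes from the change above):

import os

def resolve_np(np_arg):
    # Mirrors: np = int(args.np.split(":")[0]) or int(os.getenv('PADDLE_ELASTIC_NP', 0))
    return int(np_arg.split(":")[0]) or int(os.getenv('PADDLE_ELASTIC_NP', 0))

assert resolve_np("4") == 4      # fixed-size job: --np 4
assert resolve_np("2:4") == 2    # elastic job: --np 2:4, the command only needs MIN
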
8 changes: 6 additions & 2 deletions python/paddle/distributed/fleet/elastic/__init__.py
@@ -33,15 +33,19 @@ def enable_elastic(args, distribute_mode):
if not args.job_id and not os.getenv('PADDLE_ELASTIC_JOB_ID'):
return False

if not args.np and not int(os.getenv('PADDLE_ELASTIC_NP', 0)):
if not args.np and not os.getenv('PADDLE_ELASTIC_NP'):
return False

return True


def launch_elastic(args, distribute_mode):

elastic = ElasticManager(args)
server = args.elastic_server or os.getenv('PADDLE_ELASTIC_SERVER')
srv, port = server.split(':')
import etcd3
etcd_client = etcd3.client(host=srv, port=port)
elastic = ElasticManager(args, etcd_client)

signal.signal(signal.SIGTERM, elastic.signal_handler)
signal.signal(signal.SIGABRT, elastic.signal_handler)
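For reference, a self-contained sketch of the client construction that launch_elastic now performs before handing the client to ElasticManager (assumes the etcd3 package and a reachable etcd endpoint; the helper name is illustrative, not part of the diff):

import os
import etcd3

def build_etcd_client(args):
    # Same resolution order as above: command-line flag first, then environment.
    server = args.elastic_server or os.getenv('PADDLE_ELASTIC_SERVER')  # e.g. "127.0.0.1:2379"
    srv, port = server.split(':')
    return etcd3.client(host=srv, port=port)

# elastic = ElasticManager(args, build_etcd_client(args))
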
201 changes: 168 additions & 33 deletions python/paddle/distributed/fleet/elastic/manager.py
@@ -24,6 +24,14 @@
logger = logging.getLogger("ELASTIC")

ELASTIC_EXIT_CODE = 101
# unit: seconds
ELASTIC_TIMEOUT = 60


# 1: Fault tolerance, 2: Elastic
class ElasticLevel:
FAULT_TOLERANCE = 1
ELASTIC = 2


class ElasticStatus:
@@ -101,21 +109,34 @@ def watch(self):


class ElasticManager(object):
def __init__(self, args):
def __init__(self, args, etcd_client):
Review comment (Member): we should use elastic_client later.

Reply (Contributor, author): Currently etcd is the only backend; later we could consider providing an RPC-based kv-store as the default client.

self.args = args
server = args.elastic_server or os.getenv('PADDLE_ELASTIC_SERVER')
name = args.job_id or os.getenv('PADDLE_ELASTIC_JOB_ID')
np = args.np or int(os.getenv('PADDLE_ELASTIC_NP', 0))
self.min_np, self.max_np = self._parse_np(args.np)
np = self.min_np
host = args.host or os.getenv('POD_IP')
scale = args.scale or int(os.getenv('PADDLE_ELASTIC_SCALE', 0))
force = args.force or os.getenv('PADDLE_ELASTIC_FORCE')

self.endpoints = os.getenv('DISTRIBUTED_TRAINER_ENDPOINTS', '')
self.trainers = os.getenv('PADDLE_TRAINERS', '')
self.lastest_trainers = self.trainers
logger.info(
f"trainers={self.trainers}, lastest_trainers={self.lastest_trainers}"
)

# auto correct the value of elastic_level
# 1: Fault tolerant, 2: Elastic
self.elastic_level = int(
os.getenv('PADDLE_ELASTIC_FAULT_TOLERANC_LEVEL', 1))
os.getenv('PADDLE_ELASTIC_FAULT_TOLERANC_LEVEL',
ElasticLevel.FAULT_TOLERANCE))
if self.min_np == self.max_np or \
(self.min_np > 0 and self.max_np == 0):
self.elastic_level = ElasticLevel.FAULT_TOLERANCE
if self.min_np > 0 and self.max_np > self.min_np:
self.elastic_level = ElasticLevel.ELASTIC

# compatible with kubernetes service discovery
if not server and os.getenv(
@@ -144,10 +165,7 @@ def __init__(self, args):
else:
self.enable = True

import etcd3

srv, port = server.split(':')
self.etcd = etcd3.client(host=srv, port=port)
self.etcd = etcd_client
self.host = host if host else self._get_host()

# etcd data
@@ -222,6 +240,7 @@ def endpoints_call_back(event):
self.watches = [host_watch, np_watch, endpoints_watch]

self.launcher = None
self.elastic_startup_time = None
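
A condensed, standalone restatement of how the constructor chooses the elastic level from the parsed np bounds; this is an illustration, not part of the diff, and the constants stand in for the ElasticLevel class added above:

FAULT_TOLERANCE, ELASTIC = 1, 2

def pick_level(min_np, max_np, env_default=FAULT_TOLERANCE):
    # env_default stands in for the PADDLE_ELASTIC_FAULT_TOLERANC_LEVEL variable read above.
    level = env_default
    if min_np == max_np or (min_np > 0 and max_np == 0):
        level = FAULT_TOLERANCE
    if min_np > 0 and max_np > min_np:
        level = ELASTIC
    return level

assert pick_level(4, 0) == FAULT_TOLERANCE   # --np "4"
assert pick_level(2, 4) == ELASTIC           # --np "2:4"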

def exit(self, completed=False):
logger.info('manager exit completed {}'.format(completed))
@@ -243,6 +262,29 @@ def exit(self, completed=False):
if len(hosts) == 0:
self.etcd.delete_prefix(self.prefix)

def _parse_np(self, np: str):
"""
np format is "MIN" or "MIN:MAX"
"""
np_str = np or os.getenv('PADDLE_ELASTIC_NP', "0")
np_dict = np_str.split(":")
min_np = max_np = 0
if len(np_dict) == 1:
# Fault tolerant
min_np = int(np_dict[0])
min_np = 1 if min_np <= 0 else min_np
elif len(np_dict) == 2:
# Elastic
min_np = int(np_dict[0])
max_np = int(np_dict[1])
min_np = 1 if min_np <= 0 else min_np
max_np = min_np if min_np > max_np else max_np
else:
raise ValueError(
f'the np={np} needs to be in "MIN" or "MIN:MAX" format')

return min_np, max_np
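
For illustration, a standalone restatement of the parsing rules above and their outcomes (the real method additionally falls back to the PADDLE_ELASTIC_NP environment variable when --np is not given):

def parse_np(np_str):
    parts = np_str.split(":")
    min_np = max_np = 0
    if len(parts) == 1:                      # "MIN"      -> fault tolerant
        min_np = max(int(parts[0]), 1)
    elif len(parts) == 2:                    # "MIN:MAX"  -> elastic
        min_np = max(int(parts[0]), 1)
        max_np = max(int(parts[1]), min_np)
    else:
        raise ValueError(f'the np={np_str} needs to be in "MIN" or "MIN:MAX" format')
    return min_np, max_np

assert parse_np("4") == (4, 0)     # single value: max_np stays 0
assert parse_np("2:4") == (2, 4)
assert parse_np("0:4") == (1, 4)   # MIN is clamped to at least 1
assert parse_np("4:2") == (4, 4)   # MAX is raised to MIN when it is smaller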

def _get_host(self):
try:
return socket.gethostbyname(socket.getfqdn(socket.gethostname()))
@@ -255,40 +297,126 @@ def _completed(self):

return int(self.etcd.get(self.prefix)[0]) == 1

def _match(self):
def _match(self, host_list: list=None):

self.hosts = [
six.ensure_str(i[0]) for i in self.etcd.get_prefix(self.node_prefix)
]
if len(self.hosts) == self.np:
return True
if host_list:
self.hosts = host_list
else:
return False
self.hosts = [
six.ensure_str(i[0])
for i in self.etcd.get_prefix(self.node_prefix)
]

if self.elastic_level == ElasticLevel.FAULT_TOLERANCE:
if len(self.hosts) == self.np:
return True
else:
return False

if self.elastic_level == ElasticLevel.ELASTIC:
# FIXME(xym) add freeze status
if not self.elastic_startup_time:
self.elastic_startup_time = time.time()
hosts_num = len(self.hosts)
if hosts_num >= self.min_np and hosts_num <= self.max_np:
interval_time = time.time() - self.elastic_startup_time
if interval_time <= ELASTIC_TIMEOUT:
print(
f"current interval_time={interval_time} hosts_num={hosts_num} reached the min_np={self.min_np}, ELASTIC_TIMEOUT={ELASTIC_TIMEOUT}"
)
logger.info(
f"current interval_time={interval_time} hosts_num={hosts_num} reached the min_np={self.min_np}, wait for timeout"
)
return False

self.elastic_startup_time = time.time()
return True
else:
self.elastic_startup_time = time.time()
return False
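
The elastic branch above acts as a settle window: once the registered host count is within [min_np, max_np], the manager still waits up to ELASTIC_TIMEOUT seconds before reporting a match, so that nodes joining or leaving in quick succession are handled in one pass. A rough standalone sketch of that decision (class and method names are illustrative, not part of the diff):

import time

ELASTIC_TIMEOUT = 60  # seconds, as defined at the top of this file

class SettleWindow:
    def __init__(self, min_np, max_np):
        self.min_np, self.max_np = min_np, max_np
        self.start = None

    def ready(self, hosts_num):
        if self.start is None:
            self.start = time.time()
        if self.min_np <= hosts_num <= self.max_np:
            if time.time() - self.start <= ELASTIC_TIMEOUT:
                return False            # in range, but still inside the settle window
            self.start = time.time()    # window elapsed: reset it and report a match
            return True
        self.start = time.time()        # out of range: restart the window
        return False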

def _update_hosts(self):
assert len(self.hosts) != 0, 'hosts empty'
rank = int(os.getenv('PADDLE_TRAINER_ID', -1))
if self.elastic_level == ElasticLevel.FAULT_TOLERANCE:
self.lastest_trainers = self.trainers
if self.host in self.endpoints:
os.environ['DISTRIBUTED_TRAINER_ENDPOINTS'] = self.endpoints
os.environ['PADDLE_TRAINERS'] = self.trainers
logger.info("update env DISTRIBUTED_TRAINER_ENDPOINTS {} ".
format(self.endpoints))
logger.info("update env PADDLE_TRAINERS {} ".format(
self.trainers))
return

if self.host in self.endpoints:
os.environ['DISTRIBUTED_TRAINER_ENDPOINTS'] = self.endpoints
os.environ['PADDLE_TRAINERS'] = self.trainers
logger.info("update env DISTRIBUTED_TRAINER_ENDPOINTS {} ".format(
self.endpoints))
logger.info("update env PADDLE_TRAINERS {} ".format(self.trainers))
return
# fault tolerance
idx = self.hosts.index(self.host)

rank = int(os.getenv('PADDLE_TRAINER_ID', -1))
idx = self.hosts.index(self.host)
# swap if self.host not in the right position
if rank >= 0:
self.hosts[idx] = self.hosts[rank]
self.hosts[rank] = self.host
else:
os.environ['PADDLE_TRAINER_ID'] = '{}'.format(idx)

# swap if self.host not in the right position
if rank >= 0:
self.hosts[idx] = self.hosts[rank]
self.hosts[rank] = self.host
hosts = ','.join(self.hosts)
self.args.ips = hosts
os.environ['PADDLE_TRAINERS'] = hosts
else:
os.environ['PADDLE_TRAINER_ID'] = '{}'.format(idx)

hosts = ','.join(self.hosts)
self.args.ips = hosts
os.environ['PADDLE_TRAINERS'] = hosts
# elastic, scale up/down
trainers = self.lastest_trainers.split(",")
if len(self.hosts) > len(trainers):
# scale up
logger.info(
f"elastic scale up, hosts={self.hosts}, trainers={trainers}")

for curr_host in self.hosts:
if curr_host not in trainers:
trainers.append(curr_host)
if rank < 0:
os.environ['PADDLE_TRAINER_ID'] = '{}'.format(
trainers.index(self.host))
hosts = ','.join(trainers)
self.args.ips = hosts
os.environ['PADDLE_TRAINERS'] = hosts
self.lastest_trainers = hosts
else:
# scale down
logger.info(
f"elastic scale down, hosts={self.hosts}, trainers={trainers}"
)

# If the removed node sat at the front of the rank list, keep rank movement
# to a minimum.
# e.g.:
#   original trainers: 10.10.10.0,10.10.10.1,10.10.10.2,10.10.10.3
#   10.10.10.0 is removed
#   new trainers:      10.10.10.3,10.10.10.1,10.10.10.2
# Here the ranks of 10.10.10.1 and 10.10.10.2 stay unchanged, while 10.10.10.3 takes rank 0.
hosts_dict = dict()
unsorted_host = []
for id, host in enumerate(self.hosts):
idx = trainers.index(host)
if idx <= len(self.hosts) - 1:
hosts_dict[idx] = host
else:
unsorted_host.append(host)

idle_index = 0
sorted_hosts = []
for idx in range(len(self.hosts)):
if not hosts_dict.get(idx):
hosts_dict[idx] = unsorted_host[idle_index]
idle_index += 1

sorted_hosts.append(hosts_dict.get(idx))

logger.info(f"elastic scale down, sorted_hosts={sorted_hosts}")
hosts = ','.join(sorted_hosts)
self.args.ips = hosts
os.environ['PADDLE_TRAINERS'] = hosts
os.environ['PADDLE_TRAINER_ID'] = '{}'.format(
sorted_hosts.index(self.host))
self.lastest_trainers = hosts
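
A standalone sketch of the scale-down reordering implemented above: surviving hosts keep their old rank whenever that rank still exists, and the remaining hosts fill the freed slots (the helper name is illustrative, not part of the diff):

def reorder_hosts(new_hosts, old_trainers):
    ranked, unranked = {}, []
    for host in new_hosts:
        old_rank = old_trainers.index(host)
        if old_rank <= len(new_hosts) - 1:
            ranked[old_rank] = host      # old rank still valid after shrinking
        else:
            unranked.append(host)        # old rank no longer exists, reassign below
    result, fill = [], 0
    for rank in range(len(new_hosts)):
        if rank not in ranked:
            ranked[rank] = unranked[fill]    # drop into the lowest free rank
            fill += 1
        result.append(ranked[rank])
    return result

# Example from the comment in _update_hosts: 10.10.10.0 leaves a four-node job.
old = ["10.10.10.0", "10.10.10.1", "10.10.10.2", "10.10.10.3"]
new = ["10.10.10.3", "10.10.10.1", "10.10.10.2"]
assert reorder_hosts(new, old) == ["10.10.10.3", "10.10.10.1", "10.10.10.2"]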

def wait(self):
if not self.enable:
Expand Down Expand Up @@ -335,20 +463,27 @@ def watch(self):
completed = True if ret == 0 else False
self.exit(completed=completed)
if completed:
#logger.info(":watch, job completed")
return ElasticStatus.COMPLETED
if self.elastic_level == 1:
if self.elastic_level == ElasticLevel.FAULT_TOLERANCE or \
self.elastic_level == ElasticLevel.ELASTIC:
#logger.info(":watch, job restart")
return ElasticStatus.RESTART
else:
#logger.info(":watch, job error")
return ElasticStatus.ERROR

if not self._completed() and (not self._match() or self.need_sync):
self.launcher.stop()
#logger.info(":watch, job hold")
return ElasticStatus.HOLD

time.sleep(2)

if self.launcher:
self.launcher.stop()

#logger.info(":watch, job exit")
return ElasticStatus.EXIT
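
For orientation, the return values of watch() now map out roughly as follows; this is a simplified sketch of the control flow above, not a drop-in replacement, and the constants stand in for ElasticStatus:

COMPLETED, RESTART, ERROR, HOLD = "completed", "restart", "error", "hold"

def watch_outcome(launcher_exit_code, job_completed, matched, need_sync):
    if launcher_exit_code is not None:      # the local launcher process finished
        if launcher_exit_code == 0:
            return COMPLETED
        # with both FAULT_TOLERANCE and ELASTIC levels the job is relaunched
        return RESTART
    if not job_completed and (not matched or need_sync):
        return HOLD                         # stop local trainers and wait for peers
    return None                             # keep polling; the real loop sleeps 2s per
                                            # iteration and returns EXIT once it is stopped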

def signal_handler(self, sigint, frame):
48 changes: 48 additions & 0 deletions python/paddle/fluid/tests/unittests/test_fleet_elastic_init.py
@@ -0,0 +1,48 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import os
import time
import unittest
import argparse
from warnings import catch_warnings

from paddle.distributed.fleet.elastic import enable_elastic, launch_elastic
from paddle.distributed.fleet.launch_utils import DistributeMode


class TestElasticInit(unittest.TestCase):
def setUp(self):
class Argument:
elastic_server = "127.0.0.1:2379"
job_id = "test_job_id_123"
np = "2:4"

self.args = Argument()

def test_enable_elastic(self):
result = enable_elastic(self.args, DistributeMode.COLLECTIVE)
self.assertEqual(result, True)

def test_launch_elastic(self):
try:
launch_elastic(self.args, DistributeMode.COLLECTIVE)
except Exception as e:
pass


if __name__ == "__main__":
unittest.main()