diff --git a/python/paddle/distributed/run/__init__.py b/python/paddle/distributed/run/__init__.py new file mode 100644 index 0000000000000..f25ddb794cc4d --- /dev/null +++ b/python/paddle/distributed/run/__init__.py @@ -0,0 +1,86 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .job.container import Container +from .job.pod import Pod +from .job.job import Job +from . import plugins + +#__all__ = [Container, Pod, Job] +''' +Paddle distribution training entry ``python -m paddle.distributed.run``. + +Help + +# for arg usage and explanation, try the following command +# python -m paddle.distributed.run -h + +Collective Mode + +Case 1: 1 node + +use all visible devices +# python -m paddle.distributed.run train.py + +use specified devices +# python -m paddle.distributed.run --devices=0,1,2,3 train.py + +Case 2: multi-node, auto detect ip/port + +# python -m paddle.distributed.run --np 2 train.py +# auto print following command +# python -m paddle.distributed.run --master 10.0.0.1:13538 --np 2 demo.py +# then copy and paste above command to other nodes + +Case 3: multi-node, specified master/rendezvous server + +# python -m paddle.distributed.run --np 2 --master 10.0.0.1:2379 train.py +# the master ip must be one of the node and the port must available + +Parameter Server Mode + +Case 1.1: 1 node, 1 ps, 1 trainer + +# python -m paddle.distributed.run --mode ps train.py +# python -m paddle.distributed.run --server_num=1 --trainer_num=1 train.py + +Case 1.2: 1 node, 2 ps, 2 trainer + +# python -m paddle.distributed.run --server_num=2 --trainer_num=2 train.py + +Case 2: 2 node, 2 ps, 2 trainer per node + +# python -m paddle.distributed.run --server_num=2 --trainer_num=2 --np 2 train.py +# auto print following command +# python -m paddle.distributed.run --master 10.0.0.1:13538 --server_num=2 --trainer_num=2 --np 2 train.py +# then copy and paste above command to other nodes + +Case 3: multi-node, specified master/rendezvous server + +# python -m paddle.distributed.run --master 10.0.0.1:13538 --server_num=2 --trainer_num=2 --np 2 train.py +# the master ip must be one of the node and the port must available + +Case 4: specified servers and trainers in each node + +python -m paddle.distributed.run --servers 127.0.0.1:8900,127.0.0.1:8901 --trainers 127.0.0.1:8902,127.0.0.1:8903 train.py + + +Elastic Mode + +# run following command in 3 node to run immediately, or in 2 node to run after elastic_timeout +# python -m paddle.distributed.run --master etcd://10.0.0.1:2379 --np 2:3 train.py + +# once the peer number changes between 2:3, the strategy holds + +''' diff --git a/python/paddle/distributed/run/__main__.py b/python/paddle/distributed/run/__main__.py new file mode 100644 index 0000000000000..e32df59a32808 --- /dev/null +++ b/python/paddle/distributed/run/__main__.py @@ -0,0 +1,28 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .context import Context +from . import controllers + +# initialize the context to run +ctx = Context() + +# initialize the selected controller +c = controllers.init(ctx) + +# run the pods +c.run() + +# manager or just wait pod +c.finalize() diff --git a/python/paddle/distributed/run/context/__init__.py b/python/paddle/distributed/run/context/__init__.py new file mode 100644 index 0000000000000..86dff0f1f8056 --- /dev/null +++ b/python/paddle/distributed/run/context/__init__.py @@ -0,0 +1,219 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from argparse import ArgumentParser, REMAINDER +import os, copy + +from paddle.distributed.run import plugins + +from .node import Node +from .status import Status + +import logging + + +class Context(object): + def __init__(self, enable_plugin=True): + os.environ.pop('http_proxy', None) + os.environ.pop('https_proxy', None) + + self.args = self.parse_args() + self.envs = self.fetch_envs() + self.logger = self.get_logger() + + self.node = Node() + self.status = Status() + + self.set_env_in_args() + + # design for event queue, later + self.events = [] + + if enable_plugin: + self._enable_plugin() + + def get_envs(self): + return self.envs.copy() + + def _enable_plugin(self): + for pl in plugins.enabled_plugins: + pl(self) + + def parse_args(self): + parser = ArgumentParser() + + base_group = parser.add_argument_group("Base Parameters") + + base_group.add_argument( + "--master", + type=str, + default=None, + help="the master/rendezvous server, ip:port") + + base_group.add_argument( + "--rank", type=int, default=-1, help="the peer rank") + + base_group.add_argument( + "--log", type=str, default="INFO", help="log level. Default INFO") + + base_group.add_argument( + "--np", + type=str, + default="1", + help="the number of peers, i.e. pod/node number") + + base_group.add_argument( + "--nproc_per_node", + type=int, + default=None, + help="the number of processes in a pod") + + base_group.add_argument( + "--log_dir", + type=str, + default="log", + help="the path for each process's log. Default ./log") + base_group.add_argument( + "--mode", + type=str, + default="collective", + help="run mode of the job, collective/ps/ps-heter") + + base_group.add_argument( + "--id", + type=str, + default="default", + help="unique id of the job. Default default") + + base_group.add_argument( + "--devices", + type=str, + default=None, + help="accelerate devices. as --gpus,npus,xps") + + base_group.add_argument( + "--host", type=str, default=None, help="host ip") + + base_group.add_argument( + "training_script", + type=str, + help="the full path of py script," + "followed by arguments for the " + "training script") + + base_group.add_argument('training_script_args', nargs=REMAINDER) + + ps_group = parser.add_argument_group("Parameter-Server Parameters") + # for parameter server + ps_group.add_argument( + "--servers", + type=str, + default='', + help="servers endpoints full list") + ps_group.add_argument( + "--trainers", + type=str, + default='', + help="trainers endpoints full list") + + ps_group.add_argument( + "--trainer_num", type=int, default=None, help="number of trainers") + ps_group.add_argument( + "--server_num", type=int, default=None, help="number of servers") + ps_group.add_argument( + "--gloo_port", type=int, default=6767, help="gloo http port") + ps_group.add_argument( + "--with_gloo", type=str, default="0", help="use gloo or not") + + # parameter elastic mode + elastic_group = parser.add_argument_group("Elastic Parameters") + elastic_group.add_argument( + "--max_restart", + type=int, + default=3, + help="the times can restart. Default 3") + + elastic_group.add_argument( + "--elastic_level", + type=int, + default=-1, + help="elastic level: -1 disable, 0 failed exit, peers hold, 1 internal restart" + ) + + elastic_group.add_argument( + "--elastic_timeout", + type=int, + default=30, + help="seconds to wait before elastic perform training") + return parser.parse_args() + + def _valide_env(self, key): + if key in ['POD_IP']: + return True + if key.endswith('_VISIBLE_DEVICES'): + return True + if key.startswith('PADDLE_'): + return True + + return False + + def fetch_envs(self): + ge = os.environ.copy() + + black_env_list = ['http_proxy', 'https_proxy'] + for key in black_env_list: + ge.pop(key, None) + + return ge + ''' + # use black list instead white list + return {k: ge[k] for k in ge if self._valide_env(k)} + ''' + + def get_logger(self, level=logging.INFO): + logger = logging.getLogger("PADDLERUN") + logger.setLevel(self.args.log.upper() or level) + formatter = logging.Formatter( + fmt='%(name)s %(levelname)s %(asctime)s %(message)s') + ch = logging.StreamHandler() + ch.setFormatter(formatter) + logger.addHandler(ch) + return logger + + def set_env_in_args(self): + env_args = { + 'POD_IP': 'host', + 'PADDLE_MASTER': 'master', + 'PADDLE_DEVICES': 'devices', + 'PADDLE_NP': 'np', + 'PADDLE_MODE': 'mode', + 'PADDLE_LOG': 'log', + 'PADDLE_NPROC_PER_NODE': 'nproc_per_node', + 'PADDLE_JOB_ID': 'id', + 'PADDLE_RANK': 'rank', + 'PADDLE_LOG_DIR': 'log_dir', + 'PADDLE_MAX_RESTlRT': 'max_restart', + 'PADDLE_ELASTIC_LEVEL': 'elastic_level', + 'PADDLE_ELASTIC_TIMEOUT': 'elastic_timeout', + 'PADDLE_SERVER_NUM': 'server_num', + 'PADDLE_TRAINER_NUM': 'trainer_num', + 'PADDLE_SERVERS_ENDPOINTS': 'servers', + 'PADDLE_TRAINERS_ENDPOINTS': 'trainers', + 'PADDLE_GLOO_PORT': 'gloo_port', + 'PADDLE_WITH_GLOO': 'with_gloo', + } + + for k, v in env_args.items(): + if k in self.envs: + setattr(self.args, v, self.envs[k]) diff --git a/python/paddle/distributed/run/context/device.py b/python/paddle/distributed/run/context/device.py new file mode 100644 index 0000000000000..d8bbd851ccf83 --- /dev/null +++ b/python/paddle/distributed/run/context/device.py @@ -0,0 +1,88 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + + +class DeviceType: + CPU = 'cpu' + GPU = 'gpu' + XPU = 'xpu' + NPU = 'npu' + + +class Device(object): + def __init__(self, dtype=None, count=1, memory="", labels=""): + self.dtype = dtype + self.count = count + self.memory = memory + self.labels = labels + + def __str__(self): + return ",".join(self.labels) + + @classmethod + def parse_device(self): + dev = Device() + visible_devices = None + if 'CUDA_VISIBLE_DEVICES' in os.environ or 'NVIDIA_VISIBLE_DEVICES' in os.environ: + dev.dtype = DeviceType.GPU + visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") or os.getenv( + "NVIDIA_VISIBLE_DEVICES") + elif 'XPU_VISIBLE_DEVICES' in os.environ: + dev.dtype = DeviceType.XPU + visible_devices = os.getenv("XPU_VISIBLE_DEVICES") + elif 'ASCEND_VISIBLE_DEVICES' in os.environ: + dev.dtype = DeviceType.NPU + visible_devices = os.getenv("ASCEND_VISIBLE_DEVICES") + + if visible_devices and visible_devices != 'all': + dev.labels = visible_devices.split(',') + dev.count = len(dev.labels) + else: + return self.detect_device() + + return dev + + @classmethod + def detect_device(self): + import paddle.fluid as fluid + + dev = Device() + num = 0 + visible_devices = None + if fluid.core.is_compiled_with_cuda(): + dev.dtype = DeviceType.GPU + num = fluid.core.get_cuda_device_count() + visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") or os.getenv( + "NVIDIA_VISIBLE_DEVICES") + elif fluid.core.is_compiled_with_xpu(): + dev.dtype = DeviceType.XPU + num = fluid.core.get_xpu_device_count() + visible_devices = os.getenv("XPU_VISIBLE_DEVICES") + elif fluid.core.is_compiled_with_npu(): + dev.dtype = DeviceType.NPU + num = fluid.core.get_npu_device_count() + visible_devices = os.getenv("ASCEND_VISIBLE_DEVICES") + + if num == 0: + dev.dtype = DeviceType.CPU + elif visible_devices is None or visible_devices == "all" or visible_devices == "": + dev.labels = [str(x) for x in range(0, num)] + dev.count = num + else: + dev.labels = visible_devices.split(',') + dev.count = len(dev.labels) + + return dev diff --git a/python/paddle/distributed/run/context/event.py b/python/paddle/distributed/run/context/event.py new file mode 100644 index 0000000000000..23e8e7a501400 --- /dev/null +++ b/python/paddle/distributed/run/context/event.py @@ -0,0 +1,20 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class Event(object): + def __init__(self, kind="status", message="", fatal=False): + self.kind = kind + self.message = message + self.fatal = fatal diff --git a/python/paddle/distributed/run/context/node.py b/python/paddle/distributed/run/context/node.py new file mode 100644 index 0000000000000..1ece4db0fbbee --- /dev/null +++ b/python/paddle/distributed/run/context/node.py @@ -0,0 +1,64 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .device import Device + +import socket +import struct +from contextlib import closing + + +class Node(object): + def __init__(self): + # self.device = Device.detect_device() + self.device = Device.parse_device() + self.ip = self.get_host_ip() + self.free_ports = [] + + def get_host_ip(self): + try: + self.hostname = socket.gethostname() + self.ip = socket.gethostbyname(socket.getfqdn(self.hostname)) + return self.ip + except: + return '127.0.0.1' + + def get_free_ports(self, n=1): + free_ports = [self.get_free_port() for i in range(n)] + self.free_ports += free_ports + return free_ports + + def get_ports_occupied(self): + return self.free_ports + + @classmethod + def get_free_port(self): + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, + struct.pack('ii', 1, 0)) + s.bind(('', 0)) + return s.getsockname()[1] + + @classmethod + def is_server_ready(self, ip, port): + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: + #sock.settimeout(0.01) + #sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if hasattr(socket, 'SO_REUSEPORT'): + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + result = sock.connect_ex((ip, int(port))) + if result == 0: + return True + else: + return False diff --git a/python/paddle/distributed/run/context/resource.py b/python/paddle/distributed/run/context/resource.py new file mode 100644 index 0000000000000..faffed704c1f0 --- /dev/null +++ b/python/paddle/distributed/run/context/resource.py @@ -0,0 +1,18 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class Resource(object): + def __init__(self): + self.devices = [] diff --git a/python/paddle/distributed/run/context/status.py b/python/paddle/distributed/run/context/status.py new file mode 100644 index 0000000000000..cfbf3623ec22e --- /dev/null +++ b/python/paddle/distributed/run/context/status.py @@ -0,0 +1,58 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class Status(object): + UNINIT = "uninit" + READY = "ready" + RUNNING = "running" + FAILED = "failed" + TERMINATING = "terminating" + RESTARTING = "restarting" + UNKNOWN = "unknown" + COMPLETED = "completed" + DONE = "done" # should exit whatever status + + def __init__(self): + self._current_status = None + + def current(self): + return self._current_status + + def is_running(self): + return self._current_status == self.RUNNING + + def is_restarting(self): + return self._current_status == self.RESTARTING + + def is_done(self): + if self._current_status in [self.DONE, self.COMPLETED, self.FAILED]: + return True + else: + return False + + def run(self): + self._current_status = self.RUNNING + + def fail(self): + self._current_status = self.FAILED + + def complete(self): + self._current_status = self.COMPLETED + + def restart(self): + self._current_status = self.RESTARTING + + def done(self): + self._current_status = self.DONE diff --git a/python/paddle/distributed/run/controllers/__init__.py b/python/paddle/distributed/run/controllers/__init__.py new file mode 100644 index 0000000000000..e5557151ad548 --- /dev/null +++ b/python/paddle/distributed/run/controllers/__init__.py @@ -0,0 +1,32 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = ["init"] + +from .collective import CollectiveController +from .collective import CollectiveElasticController +from .ps import PSController + +# the order is extremely important +_controllers = [ + CollectiveElasticController, + PSController, + CollectiveController, +] + + +def init(ctx): + for c in _controllers: + if c.enable(ctx): + return c(ctx) diff --git a/python/paddle/distributed/run/controllers/collective.py b/python/paddle/distributed/run/controllers/collective.py new file mode 100644 index 0000000000000..c4feb54428a07 --- /dev/null +++ b/python/paddle/distributed/run/controllers/collective.py @@ -0,0 +1,185 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .controller import Controller + +import json +import os +import six +import time + + +class CollectiveController(Controller): + @classmethod + def enable(cls, ctx): + if ctx: + ctx.logger.debug("{} enabled".format(cls.__name__)) + return True + else: + return False + + def build_pod(self): + self.pod.replicas = self.pod_replicas() + + # rank will be reset when restart + self.pod.rank = self.ctx.args.rank + + port = self.ctx.node.get_free_port() + + # compatible + endpoints = [ + "{}:{}".format(self.ctx.node.ip, p) + for p in self.ctx.node.get_free_ports(self.pod.replicas) + ] + + data = json.dumps({ + 'name': self.pod.name, + 'rank': self.pod.rank, + 'replicas': self.pod.replicas, + 'dtype': self.ctx.node.device.dtype, + 'candidate': '{}:{}'.format(self.ctx.node.ip, port), + 'endpoints': ",".join(endpoints), + }) + + peer_list, rank = self.master.sync_peers( + '/{}/info'.format(self.job.id), self.pod.name, data, + self.job.replicas, self.pod.rank) + self.pod.rank = rank + + if len(peer_list) < 1: + return False + + peer_list = [json.loads(i) for i in peer_list] + + self.ctx.logger.debug("sync peers done {}".format(peer_list)) + self.save_pod_log(peer_list) + + global_size = sum([i['replicas'] for i in peer_list]) + rank_offset = sum([i['replicas'] for i in peer_list[:rank]]) + ''' + The new designed collective need nothing but a master endpoint + ''' + collective_master = peer_list[0]['candidate'] + + job_endpoints = [i['endpoints'] for i in peer_list] + + self.pod.reset() + for i in range(self.pod.replicas): + e = { + "PADDLE_MASTER": collective_master, + "PADDLE_GLOBAL_SIZE": "{}".format(global_size), + "PADDLE_LOCAL_SIZE": "{}".format(self.pod.replicas), + "PADDLE_GLOBAL_RANK": "{}".format(i + rank_offset), + "PADDLE_LOCAL_RANK": "{}".format(i), + ## compatible env + "PADDLE_TRAINER_ENDPOINTS": ",".join(job_endpoints), + "PADDLE_CURRENT_ENDPOINT": endpoints[i], + "PADDLE_TRAINER_ID": "{}".format(i + rank_offset), + "PADDLE_TRAINERS_NUM": "{}".format(global_size), + "PADDLE_RANK_IN_NODE": str(i), + } + self.add_container(envs=e, log_tag=i) + + return True + + +class CollectiveElasticController(CollectiveController): + @classmethod + def enable(cls, ctx): + if ctx.args.master and ctx.args.master.startswith("etcd://"): + ctx.logger.debug("{} enabled".format(cls.__name__)) + return True + else: + return False + + def register(self): + if self.job.id == 'default': + self.ctx.logger.warning( + 'Using default job name may cause conflict, add --id in args') + + self.master.register_heartbeat(self.job.id, self.pod.name) + + def watch(self) -> bool: + ''' + watch self and peer status, return true to exit + ''' + while not self.ctx.status.is_done(): + # self status + status = self.pod.watch(timeout=2) + self.ctx.logger.debug("Pod status {}, Ctx status {}".format( + status, self.ctx.status.current())) + + # completed + if status == self.ctx.status.COMPLETED: + self.master.set_status(status) + self.ctx.status.complete() + self.ctx.logger.info("Pod complete {}".format(status)) + return True + + # self failure + elif status == self.ctx.status.FAILED: + self.master.set_status(status) + self.master.restart_peer() + self.ctx.logger.info("Pod failed {}".format(status)) + self.pod.stop() + + if self.ctx.args.elastic_level <= 0: + return True + else: + return False + + # peer failure + if self.ctx.status.is_restarting() and self.master.get_status( + ) != self.ctx.status.COMPLETED: + self.pod.stop() + return False + + #peers = self.master.fetch_peer_alive() + #print("peers {}".format(peers)) + + def run(self): + + timeout = self.ctx.args.elastic_timeout if self.job.elastic else self.ctx.args.elastic_timeout * 10 + self.register() + + while self.pod.restart <= self.ctx.args.max_restart: + + self.build_job() + + ok, replicas = self.master.wait_peer_ready( + self.job.replicas_min, self.job.replicas_max, timeout) + if ok: + self.job.replicas = replicas + else: + self.ctx.logger.warnning("peer not ready {}".format(self.job)) + break + + self.ctx.logger.debug("Run {}".format(self.job)) + + if not self.build_pod(): + continue + + self.master.set_status(self.ctx.status.RUNNING) + self.ctx.status.run() + + assert len(self.pod.containers) > 0, "No container in the pod" + self.ctx.logger.debug("Run {}".format(self.pod)) + self.ctx.logger.debug("Run {}".format(self.pod.containers[0])) + + self.pod.deploy() + + if self.watch(): + break + + self.ctx.logger.debug("Job done {}".format(self.job)) diff --git a/python/paddle/distributed/run/controllers/controller.py b/python/paddle/distributed/run/controllers/controller.py new file mode 100644 index 0000000000000..2d904cf2a2cca --- /dev/null +++ b/python/paddle/distributed/run/controllers/controller.py @@ -0,0 +1,192 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import os +import signal + +from paddle.distributed.run.job import Job +from paddle.distributed.run.job import Pod +from paddle.distributed.run.job import Container + +from .master import Master + +import time + + +class ControleMode: + COLLECTIVE = "collective" + PS = "ps" + + +class ControllerBase(object): + def __init__(self, ctx): + signal.signal(signal.SIGTERM, self.signal_handler) + signal.signal(signal.SIGABRT, self.signal_handler) + signal.signal(signal.SIGINT, self.signal_handler) + + self.ctx = ctx + self.master = Master.factory(self.ctx) + + self.job = Job(np=self.ctx.args.np, + mode=self.ctx.args.mode, + id=self.ctx.args.id) + self.pod = Pod() + + self.join_server = None + + def run(self): + self.build_job() + self.build_pod() + + if len(self.pod.containers) < 1: + self.ctx.logger.error("No container in the pod {}".format(self.pod)) + return + + self.ctx.logger.info("Run {}".format(self.pod)) + self.ctx.logger.debug(self.pod.containers[0]) + + self.pod.deploy() + + self.watch() + + def watch(self) -> bool: + status = self.pod.watch() + + if status == self.ctx.status.COMPLETED: + self.ctx.logger.info("Pod {}".format(status)) + elif status == self.ctx.status.FAILED: + self.ctx.logger.info("Pod {}".format(status)) + self.ctx.logger.error("Container failed !!!\n{}".format( + self.pod.failed_container())) + self.pod.tail() + self.pod.stop() + + def stop(self, sigint=None): + self.ctx.logger.debug("Controller stop") + self.master.stop() + self.pod.stop(sigint) + + def finalize(self): + self.pod.join() + self.master.stop() + + self.ctx.logger.info("Exit code {}".format(self.pod.exit_code)) + sys.exit(self.pod.exit_code) + + def signal_handler(self, sigint, frame): + self.ctx.logger.info("Terminating with signal {}".format(sigint)) + + if hasattr(self, 'sigint'): + time.sleep(5) + sys.exit(sigint) + + self.sigint = sigint + self.ctx.status.done() + self.stop(sigint) + time.sleep(1) + self.ctx.logger.debug("Exit with signal {}".format(sigint)) + sys.exit(sigint) + + +class Controller(ControllerBase): + ''' + Controller API for customization + ''' + + def build_job(self): + ''' + build job fill the job info. + ''' + self.ctx.logger.info(self.job) + + def build_pod(self) -> bool: + ''' + build pod includes creating containers etc. + + Return True if succeed + ''' + raise NotImplementedError + + def _get_entrypoint(self): + entrypoint = [sys.executable, "-u", self.ctx.args.training_script] + entrypoint.extend(self.ctx.args.training_script_args) + return entrypoint + + def _get_out_err_file(self, out=None, err=None): + if out and self.ctx.args.log_dir != "": + out = os.path.join(self.ctx.args.log_dir, out) + if err and self.ctx.args.log_dir != "": + err = os.path.join(self.ctx.args.log_dir, err) + return out, (err or out) + + def new_container(self, + entrypoint=None, + envs={}, + use_ctx_env=True, + out=None, + err=None): + c = Container( + entrypoint=(entrypoint or self._get_entrypoint()), + env=(self.ctx.get_envs() if use_ctx_env else {}), ) + c.outfile, c.errfile = self._get_out_err_file(out, err) + c.update_env(envs) + return c + + def add_container(self, + container=None, + entrypoint=None, + envs={}, + log_tag=None, + is_init=False): + if not is_init and log_tag is not None: + log_file = "{}.{}.{}.log".format(self.job.id, self.pod.name, + log_tag) + else: + log_file = None + + if not container: + container = self.new_container( + entrypoint=entrypoint, envs=envs, out=log_file, err=log_file) + + if is_init: + self.pod.add_init_container(container) + else: + self.pod.add_container(container) + + def pod_replicas(self): + ''' + how many process/container should be run in pod + ''' + + if self.ctx.args.nproc_per_node: + return int(self.ctx.args.nproc_per_node) + else: + return self.ctx.node.device.count + + def save_pod_log(self, info): + ''' + save_pod_log append *info* to the log file of pod.name + ''' + if not self.ctx.args.log_dir: + return + + f = os.path.join(self.ctx.args.log_dir, + '{}.{}.log'.format(self.job.id, self.pod.name)) + try: + os.makedirs(os.path.dirname(f), exist_ok=True) + with open(f, 'a+') as fd: + fd.write(str(info)) + except Exception as e: + self.ctx.logger.error("save log failed because {}".format(e)) diff --git a/python/paddle/distributed/run/controllers/master.py b/python/paddle/distributed/run/controllers/master.py new file mode 100644 index 0000000000000..257ba3bad8da3 --- /dev/null +++ b/python/paddle/distributed/run/controllers/master.py @@ -0,0 +1,289 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.distributed.run.utils.kv_client import KVClient +from paddle.distributed.run.utils.kv_server import KVServer + +import time +import sys +import six +import threading +import copy +import random + +ETCD_PROTOCAL = 'etcd://' + + +class Master(object): + ''' + Master is a distributed store design to exchange info among nodes + ''' + + MAIN = "main" + STANDBY = "standby" + PATICIPANT = "participant" + + def __init__(self, ctx): + self.ctx = ctx + self.server = None + self.initialized = False + self.endpoint = None + + def stop(self): + raise NotImplementedError + + def sync_peers(self, prefix, key, value, size, rank=-1) -> (list, int): + raise NotImplementedError + + @classmethod + def factory(cls, ctx): + if ctx.args.master and ctx.args.master.startswith(ETCD_PROTOCAL): + return ETCDMaster(ctx) + else: + return HTTPMaster(ctx) + + +class HTTPMaster(Master): + def lazy_init(self): + if self.initialized: + return + + self.role = Master.PATICIPANT + + if self.ctx.args.master: + self.endpoint = self.ctx.args.master + ip, port = self.endpoint.split(':') + if ip in ['127.0.0.1', self.ctx.node.ip]: + time.sleep(2 * random.random()) + while not self.ctx.node.is_server_ready(ip, int(port)): + try: + self.server = KVServer(int(port)) + self.role = Master.MAIN + break + except Exception as e: + self.ctx.logger.warning("start master failed {}".format( + e)) + time.sleep(0.1) + continue + else: + port = self.ctx.node.get_free_port() + self.endpoint = "{}:{}".format(self.ctx.node.ip, port) + self.server = KVServer(port) + self.role = Master.MAIN + + print("Copy the following command to other nodes to run.") + cmd = [ + sys.executable.split('/')[-1], "-m", "paddle.distributed.run" + ] + cmd.extend(["--master", self.endpoint]) + cmd.extend(sys.argv[1:]) + print("-" * 80) + print(" ".join(cmd)) + print("-" * 80) + + if self.ctx.args.rank >= 0: + self.ctx.logger.warning( + "--rank set in the command may not compatible in auto mode") + + if '127.0.0.1' in self.endpoint: + self.endpoint = self.endpoint.replace('127.0.0.1', self.ctx.node.ip) + self.client = KVClient(self.endpoint) + + self.initialized = True + + self._start_server() + + def _start_server(self): + if self.server and not self.server.started: + self.server.start() + self.ctx.logger.debug("KV server start at {}".format(self.endpoint)) + + def _stop_server(self): + if self.server and not self.server.stopped: + self.server.stop() + self.ctx.logger.debug("KV server stopped") + + def stop(self): + self._stop_server() + + def sync_peers(self, prefix, key, value, size, rank=-1) -> (list, int): + if size < 2: + return [value], 0 + + self.lazy_init() + + while not self.ctx.status.is_done(): + if self.client.wait_server_ready(timeout=5): + break + else: + self.ctx.logger.warning("master not ready") + time.sleep(0.1) + + # 'aaaaaa' make suer main pod (master server) as rank 0 + ky = 'aaaaaa' if rank < 0 and self.role == Master.MAIN else key + k = "{}/{}/{}".format(prefix, ky, rank) + + while not self.ctx.status.is_done(): + if not self.client.put(k, value): + self.ctx.logger.warning("put value failed") + time.sleep(0.1) + continue + + rjson = self.client.get_prefix(prefix) + self.ctx.logger.debug("sync peers {}".format(rjson)) + if rjson and len(rjson) == size: + if rank < 0: + keys = list(rjson.keys()) + keys.sort() + ret = [rjson[k] for k in keys] + idx = ret.index(value) + return ret, idx + else: + ret = [None] * size + for k, v in rjson.items(): + ret[int(k.split('/')[-1])] = v + return ret, rank + else: + time.sleep(0.5) + return [], 0 + + +class ETCDMaster(Master): + def __init__(self, ctx): + super().__init__(ctx) + + if self.ctx.args.master: + # etcd://localhost:2379 + self.endpoint = self.ctx.args.master.strip("etcd://") + + import etcd3 + + host, port = self.endpoint.split(':') + self.client = etcd3.client(host=host, port=port) + + def sync_peers(self, prefix, key, value, size, rank=-1) -> (list, int): + ''' + sync_peers gather all value for key under scope prefix + result always be sorted either by rank or alphabet of pod.name + ''' + path = "{}/{}/{}".format(prefix, key, rank) + + self.client.delete_prefix(prefix) + + self.ctx.logger.debug("sync path {} value {}".format(path, value)) + + while not self.ctx.status.is_done(): + self.client.put(path, six.b(value)) + + result = [i for i in self.client.get_prefix(prefix)] + result = copy.deepcopy(result) + self.ctx.logger.debug("sync peers {}".format(result)) + + if len(result) == size: + if rank < 0: + keys = [six.ensure_str(i[1].key) for i in result] + sorted_keys = [six.ensure_str(i[1].key) for i in result] + sorted_keys.sort() + values = [six.ensure_str(i[0]) for i in result] + ret = [values[keys.index(k)] for k in sorted_keys] + idx = ret.index(value) + return ret, idx + else: + ret = [None] * size + for v, k in result: + ii = int(six.ensure_str(k.key).split('/')[-1]) + if ii < 0: + self.ctx.logger.error( + "rank {} error in sync".format(ii)) + ret[ii] = six.ensure_str(v) + return ret, rank + else: + time.sleep(0.5) + + def register_heartbeat(self, job_id, pod_id, ttl=10): + if hasattr(self, 'heartbeat_prefix'): + self.ctx.logger.warning("Heartbeat already done") + return + + self.job_prefix = '/paddle/{}'.format(job_id) + self.heartbeat_prefix = '{}/heartbeat'.format(self.job_prefix) + + lease = self.client.lease(ttl) + + #self.client.delete_prefix(self.job_prefix) + + beat_path = "{}/{}".format(self.heartbeat_prefix, pod_id) + self.client.put(beat_path, six.b(pod_id), lease=lease) + + def _beat_watch(event): + self.ctx.status.restart() + + beat_watch = self.client.add_watch_prefix_callback( + self.heartbeat_prefix, _beat_watch) + + def _heartbeat(): + while not self.ctx.status.is_done(): + try: + lease.refresh() + if pod_id not in self.fetch_peer_alive(): + self.client.put(beat_path, six.b(pod_id), lease=lease) + self.ctx.logger.debug("Heartbeat register again") + except Exception as e: + self.ctx.logger.error("Heartbeat error {}".format(e)) + time.sleep(ttl / 2) + self.ctx.logger.debug("Heartbeat done") + self.client.cancel_watch(beat_watch) + + self.beat_thread = threading.Thread( + name='heartbeat', target=_heartbeat, daemon=True) + self.beat_thread.start() + + def fetch_peer_alive(self): + peer_alive = [ + six.ensure_str(i[0]) + for i in self.client.get_prefix(self.heartbeat_prefix) + ] + self.ctx.logger.debug("peer alive {}".format(peer_alive)) + return peer_alive + + def wait_peer_ready(self, replicas_min, replicas_max, timeout): + end = time.time() + timeout + while not self.ctx.status.is_done() and time.time() < end: + if len(self.fetch_peer_alive()) == replicas_max: + return (True, replicas_max) + else: + time.sleep(0.5) + + np = len(self.fetch_peer_alive()) + if np >= replicas_min and np <= replicas_max: + return (True, np) + else: + return (False, np) + + def restart_peer(self): + self.client.delete_prefix(self.heartbeat_prefix) + + def set_status(self, status): + assert self.client.put( + self.job_prefix, six.b(status), + lease=self.client.lease(600)), "set status failed {}".format(status) + + def get_status(self): + return six.ensure_str(self.client.get(self.job_prefix)[0] or '') + + def stop(self): + if hasattr(self, 'beat_thread'): + self.ctx.status.done() + # TODO(kuizhiqing) thread should exit + #self.beat_thread.join() diff --git a/python/paddle/distributed/run/controllers/ps.py b/python/paddle/distributed/run/controllers/ps.py new file mode 100644 index 0000000000000..cc43c336cf186 --- /dev/null +++ b/python/paddle/distributed/run/controllers/ps.py @@ -0,0 +1,221 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .controller import Controller, ControleMode + +import json +import os, shutil + + +class PSController(Controller): + @classmethod + def enable(cls, ctx): + if ctx.args.mode == ControleMode.PS or ctx.args.server_num or len( + ctx.args.servers) > 0: + ctx.logger.debug("{} enabled".format(cls.__name__)) + ctx.args.mode = ControleMode.PS + return True + else: + return False + + def build_pod(self): + if self.ctx.args.servers and self.ctx.args.trainers: + self._build_pod_with_args() + else: + self._build_pod_with_master() + + def _build_pod_with_args(self): + if '127.0.0.1' in self.ctx.args.servers: + host = '127.0.0.1' + else: + host = self.ctx.node.ip + + server_endpoints = [s for s in self.ctx.args.servers.split(",")] + trainer_endpoints = [s for s in self.ctx.args.trainers.split(",")] + servers = [ + s for s in self.ctx.args.servers.split(",") if s.startswith(host) + ] + trainers = [ + s for s in self.ctx.args.trainers.split(",") if s.startswith(host) + ] + server_num = len(servers) + trainer_num = len(trainers) + + self.pod.replicas = server_num + trainer_num + + self.save_pod_log([server_endpoints, trainer_endpoints]) + + import tempfile + gloo_rendezvous_dir = tempfile.mkdtemp() + if os.path.exists(gloo_rendezvous_dir): + shutil.rmtree(gloo_rendezvous_dir) + + gloo_port = self.ctx.args.gloo_port + gloo_http = "{}:{}".format(server_endpoints[0].split(":")[0], gloo_port) + + _gloo_envs = { + "PADDLE_GLOO_RENDEZVOUS": "3", + "PADDLE_GLOO_FS_PATH": gloo_rendezvous_dir, + "PADDLE_GLOO_HTTP_ENDPOINT": gloo_http, + "PADDLE_WITH_GLOO": self.ctx.args.with_gloo + } + + for i in range(server_num): + e = { + "PADDLE_PSERVERS_IP_PORT_LIST": self.ctx.args.servers, + "PADDLE_TRAINER_ENDPOINTS": self.ctx.args.trainers, + "PADDLE_PORT": servers[i].split(":")[1], + "PADDLE_ROLE": "PSERVER", + "TRAINING_ROLE": "PSERVER", + "PADDLE_TRAINERS_NUM": "{}".format(len(trainer_endpoints)), + "POD_IP": self.ctx.node.ip, + } + e.update(_gloo_envs) + log_tag = "ps.{}".format(i) + self.add_container(envs=e, log_tag=log_tag) + + trainer_rank_offset = 0 + for s in trainer_endpoints: + if s.startswith(host): + break + else: + trainer_rank_offset += 1 + + for i in range(trainer_num): + e = { + "PADDLE_PSERVERS_IP_PORT_LIST": ",".join(server_endpoints), + "PADDLE_TRAINER_ENDPOINTS": ",".join(trainer_endpoints), + "PADDLE_PORT": trainers[i].split(":")[1], + "PADDLE_ROLE": "TRAINER", + "TRAINING_ROLE": "TRAINER", + "PADDLE_TRAINER_ID": "{}".format(i + trainer_rank_offset), + "PADDLE_TRAINERS_NUM": "{}".format(len(trainer_endpoints)), + "POD_IP": self.ctx.node.ip, + } + e.update(_gloo_envs) + log_tag = "trainer.{}".format(i) + self.add_container(envs=e, log_tag=log_tag) + + def _build_pod_with_master(self): + + self.pod.rank = self.ctx.args.rank + + server_num = self.ctx.args.server_num or 1 + servers = [ + "{}:{}".format(self.ctx.node.ip, p) + for p in self.ctx.node.get_free_ports(server_num) + ] + trainer_num = self.ctx.args.trainer_num or 1 + trainers = [ + "{}:{}".format(self.ctx.node.ip, p) + for p in self.ctx.node.get_free_ports(trainer_num) + ] + + data = json.dumps({ + 'name': self.pod.name, + 'rank': self.pod.rank, + 'servers': servers, + 'trainers': trainers, + 'dtype': self.ctx.node.device.dtype, + 'gloo_port': self.ctx.node.get_free_port(), + }) + + peer_list, rank = self.master.sync_peers( + '/{}/info'.format(self.job.id), self.pod.name, data, + self.job.replicas, self.pod.rank) + + self.ctx.logger.debug("sync peers done {}".format(peer_list)) + + peer_list = [json.loads(i) for i in peer_list] + + self.save_pod_log(peer_list) + + server_endpoints = [j for i in peer_list for j in i['servers']] + trainer_endpoints = [j for i in peer_list for j in i['trainers']] + #rank_offset = sum([i['replicas'] for i in peer_list[:rank]]) + + server_rank_offset = sum([len(i['servers']) for i in peer_list[:rank]]) + trainer_rank_offset = sum( + [len(i['trainers']) for i in peer_list[:rank]]) + + self.pod.rank = rank + + self.pod.replicas = server_num + trainer_num + + import tempfile + gloo_rendezvous_dir = tempfile.mkdtemp() + if os.path.exists(gloo_rendezvous_dir): + shutil.rmtree(gloo_rendezvous_dir) + + gloo_port = peer_list[0]['gloo_port'] + gloo_http = "{}:{}".format(server_endpoints[0].split(":")[0], gloo_port) + + _gloo_envs = { + "PADDLE_GLOO_RENDEZVOUS": "3", + "PADDLE_GLOO_FS_PATH": gloo_rendezvous_dir, + "PADDLE_GLOO_HTTP_ENDPOINT": gloo_http, + "PADDLE_WITH_GLOO": self.ctx.args.with_gloo + } + + for i in range(server_num): + e = { + "PADDLE_PSERVERS_IP_PORT_LIST": ",".join(server_endpoints), + "PADDLE_TRAINER_ENDPOINTS": ",".join(trainer_endpoints), + "PADDLE_PORT": + server_endpoints[i + server_rank_offset].split(":")[1], + "PADDLE_ROLE": "PSERVER", + "TRAINING_ROLE": "PSERVER", + "PADDLE_TRAINERS_NUM": "{}".format(len(trainer_endpoints)), + "POD_IP": self.ctx.node.ip, + } + e.update(_gloo_envs) + log_tag = "ps.{}".format(i) + self.add_container(envs=e, log_tag=log_tag) + + for i in range(trainer_num): + e = { + "PADDLE_PSERVERS_IP_PORT_LIST": ",".join(server_endpoints), + "PADDLE_TRAINER_ENDPOINTS": ",".join(trainer_endpoints), + "PADDLE_PORT": + trainer_endpoints[i + trainer_rank_offset].split(":")[1], + "PADDLE_ROLE": "TRAINER", + "TRAINING_ROLE": "TRAINER", + "PADDLE_TRAINER_ID": "{}".format(i + trainer_rank_offset), + "PADDLE_TRAINERS_NUM": "{}".format(len(trainer_endpoints)), + "POD_IP": self.ctx.node.ip, + } + e.update(_gloo_envs) + log_tag = "trainer.{}".format(i) + self.add_container(envs=e, log_tag=log_tag) + ''' NEW VERSION + for i in range(server_num): + e = { + "PADDLE_PSERVER_ENDPOINTS": ",".join(server_endpoints), + "PADDLE_TRAINER_ENDPOINTS": ",".join(trainer_endpoints), + "PADDLE_ROLE": "PSERVER", + "PADDLE_RANK": "{}".format(i + server_rank_offset), + } + log_tag = "ps.{}".format(i) + self.add_container(envs=e, log_tag=log_tag) + + for i in range(trainer_num): + e = { + "PADDLE_PSERVER_ENDPOINTS": ",".join(server_endpoints), + "PADDLE_TRAINER_ENDPOINTS": ",".join(trainer_endpoints), + "PADDLE_ROLE": "TRAINER_CPU", + "PADDLE_RANK": "{}".format(i + trainer_rank_offset), + } + log_tag = "trainer.{}".format(i) + self.add_container(envs=e, log_tag=log_tag) + ''' diff --git a/python/paddle/distributed/run/job/__init__.py b/python/paddle/distributed/run/job/__init__.py new file mode 100644 index 0000000000000..66d2abbce21eb --- /dev/null +++ b/python/paddle/distributed/run/job/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .pod import Pod +from .job import Job +from .container import Container +from .status import Status + +__all__ = [ + 'Pod', + 'Job', + 'Container', + 'Status', +] diff --git a/python/paddle/distributed/run/job/container.py b/python/paddle/distributed/run/job/container.py new file mode 100644 index 0000000000000..651932d6c8837 --- /dev/null +++ b/python/paddle/distributed/run/job/container.py @@ -0,0 +1,179 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict +from paddle.distributed.run.utils.process_context import ProcessContext + +from .status import Status + +import os, copy, sys +import time + + +class Container(object): + ''' + TODO(kuizhiqing) A container can be run by process/thread or just a callable function + ''' + + def __init__(self, entrypoint=[], rank=-1, env={}): + self._entrypoint = entrypoint + self._rank = rank + self._out = None + self._err = None + self._env = env + self._proc = None + + self._retry: int = 3 + self._grace_period = 10 + + self._log_handler = None + + @property + def entrypoint(self): + return self._entrypoint + + @entrypoint.setter + def entrypoint(self, entry): + self._entrypoint = entry + + @property + def rank(self): + return self._rank + + @rank.setter + def rank(self, r): + self._rank = r + + @property + def outfile(self): + return self._out + + @outfile.setter + def outfile(self, out): + self._out = out + + @property + def errfile(self): + return self._err + + @errfile.setter + def errfile(self, err): + self._err = err + + def update_env(self, env={}, **kwargs): + env = {k: v for k, v in env.items() if isinstance(v, str)} + self._env.update(env) + + kwargs = {k: v for k, v in kwargs.items() if isinstance(v, str)} + self._env.update(kwargs) + + def _get_fd(self, pth): + if not pth: + return None + + try: + d = os.path.dirname(pth) + if not os.path.isdir(d): + os.makedirs(d, exist_ok=True) + return open(pth, 'w') + except: + return None + + def start(self, timeout=-1): + end = time.time() + timeout + + if self._proc and self._proc.alive(): + return True + + self._stdout = self._get_fd(self._out) or sys.stdout + if self._out == self._err: + self._stderr = self._stdout + elif self._err: + self._stderr = self._get_fd(self._err) or sys.stderr + + self._proc = ProcessContext( + self._entrypoint, env=self._env, out=self._stdout, err=self._stderr) + self._proc.start() + + while timeout > 0 and time.time() < end: + if self._proc.alive(): + time.sleep(0.1) + continue + if self._proc.exit_code() == 0: + return True + return False + + def terminate(self, force=False): + if self._log_handler: + self._log_handler.close() + self._log_handler = None + + if self._proc and self._proc.alive(): + return self._proc.terminate(force) + + def wait(self, timeout=None): + self._proc.wait(timeout) + + def exit_code(self): + return self._proc.exit_code() if self._proc else -1 + + def status(self): + if not self._proc: + return Status.UNINIT + if self._proc.alive(): + return Status.RUNNING + elif self._proc.exit_code() == 0: + return Status.COMPLETED + else: + return Status.FAILED + + def __str__(self): + return 'Container rank {} status {} cmd {} code {} log {} \nenv {}'.format( + self._rank, + self.status(), + self._entrypoint, + self.exit_code(), + self.errfile, + self._env, ) + + def logs(self, fn=None, offset=0, whence=1, lines=1000): + if not self._log_handler: + self._log_handler = open(self._out) + + if fn is None: + fn = sys.stdout + + self._log_handler.seek(offset, whence) + + try: + idx = 0 + for line in self._log_handler: + fn.write(line) + idx += 1 + if idx > lines: + break + finally: + return self._log_handler.tell() + + def tail(self, length=3000): + if not self._log_handler: + self._log_handler = open(self._out) + + self._log_handler.seek(0, 2) + ed = self._log_handler.tell() + + if ed > length: + self.logs(offset=ed - length, whence=0) + else: + self.logs(offset=0, whence=0) diff --git a/python/paddle/distributed/run/job/job.py b/python/paddle/distributed/run/job/job.py new file mode 100644 index 0000000000000..3469ed862576f --- /dev/null +++ b/python/paddle/distributed/run/job/job.py @@ -0,0 +1,80 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class JobMode: + COLLECTIVE = 'collective' + PS = 'ps' + HETER = 'heter' + + +class Job(object): + def __init__(self, id='default', mode=JobMode.COLLECTIVE, np="1"): + self._mode = mode + self._id = id + + self._replicas = 0 + self._replicas_min = self._replicas + self._replicas_max = self._replicas + self._elastic = False + + self.set_replicas(str(np)) + + def __str__(self): + return "Job: {}, mode {}, replicas {}[{}:{}], elastic {}".format( + self.id, self.mode, self._replicas, self._replicas_min, + self._replicas_max, self.elastic) + + @property + def mode(self): + return self._mode + + @property + def id(self): + return self._id + + @property + def elastic(self): + return self._elastic + + @property + def replicas(self): + return self._replicas + + @property + def replicas_min(self): + return self._replicas_min + + @property + def replicas_max(self): + return self._replicas_max + + @replicas.setter + def replicas(self, replicas): + self._replicas = replicas + + def set_replicas(self, np: str): + np = str(np) if np else '1' + + if ':' in np: + nps = np.split(':') + self._replicas_min, self._replicas_max = int(nps[0]), int(nps[1]) + self._replicas = self._replicas_max # default to max + + self._elastic = True + else: + self._replicas = int(np) + self._replicas_min, self._replicas_max = self._replicas, self._replicas + + self._elastic = False diff --git a/python/paddle/distributed/run/job/pod.py b/python/paddle/distributed/run/job/pod.py new file mode 100644 index 0000000000000..f7c31edce1d55 --- /dev/null +++ b/python/paddle/distributed/run/job/pod.py @@ -0,0 +1,185 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict +from .container import Container + +from .status import Status + +import random +import time + + +class PodSepc(object): + def __init__(self): + self._name = ''.join( + random.choice('abcdefghijklmnopqrstuvwxyz') for _ in range(6)) + + # by controller + self._init_containers: List[Container] = [] + self._containers: List[Container] = [] + + #self.resource: Resource = None + #self.status: Status = None + + self._rank = -1 + self._init_timeout = 120 # 2 min timeout for each init container + self._restart = -1 + self._replicas = 0 # number of containers + self._exit_code = 0 + + +class Pod(PodSepc): + def __init__(self): + super().__init__() + + def __str__(self): + return "Pod: {}, replicas {}, status {}".format(self.name, + self.replicas, + self.status()) + + def failed_container(self): + for c in self._containers: + if c.status() == Status.FAILED: + return c + return None + + @property + def name(self): + return self._name + + @property + def replicas(self): + return self._replicas + + @replicas.setter + def replicas(self, r): + self._replicas = r + + @property + def rank(self): + return self._rank + + @rank.setter + def rank(self, r): + self._rank = r + + @property + def restart(self): + return self._restart + + @property + def containers(self): + return self._containers + + def add_container(self, c): + c.rank = len(self._containers) + self._containers.append(c) + + @property + def init_containers(self): + return self._init_containers + + def add_init_container(self, c): + c.rank = len(self._init_containers) + self._init_containers.append(c) + + @property + def exit_code(self): + for c in self._containers: + if c.exit_code() != 0: + return c.exit_code() + return 0 + + def deploy(self): + for i in self._init_containers: + i.start(self._init_timeout) + + for c in self._containers: + c.start() + + self._restart += 1 + + def stop(self, sigint=0): + for c in self._containers: + force = True if sigint == 9 else False + c.terminate(force) + + def join(self): + for c in self._containers: + c.wait(None) + + def status(self): + if self.is_failed(): + return Status.FAILED + + if self.is_completed(): + return Status.COMPLETED + + return Status.READY + + def reset(self): + self._init_containers = [] + self._containers = [] + + def is_failed(self): + for c in self._containers: + if c.status() == Status.FAILED: + return True + return False + + def is_completed(self): + for c in self._containers: + if c.status() != Status.COMPLETED: + return False + return True + + def logs(self, idx=None): + if idx is None: + if self.failed_container(): + self.failed_container().logs() + else: + self._containers[0].logs() + else: + self._containers[idx].logs() + + def tail(self, idx=None): + if idx is None: + if self.failed_container(): + self.failed_container().tail() + else: + self._containers[0].tail() + else: + self._containers[idx].tail() + + def watch(self, + all_list=[Status.COMPLETED], + any_list=[Status.FAILED], + interval=1, + timeout=-1): + ''' + watch return if any container status in any_list + or all container status in all_list + ''' + end = time.time() + timeout + while timeout < 0 or time.time() < end: + for c in self._containers: + if c.status() in any_list: + return c.status() + + s = [c.status() for c in self._containers] + if len(set(s)) == 1 and s[0] in all_list: + return s[0] + + time.sleep(interval) diff --git a/python/paddle/distributed/run/job/status.py b/python/paddle/distributed/run/job/status.py new file mode 100644 index 0000000000000..ae10c5adb6cbf --- /dev/null +++ b/python/paddle/distributed/run/job/status.py @@ -0,0 +1,24 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class Status(object): + UNINIT = "uninit" + READY = "ready" + RUNNING = "running" + FAILED = "failed" + TERMINATING = "terminating" + RESTARTING = "restarting" + UNKNOWN = "unknown" + COMPLETED = "completed" diff --git a/python/paddle/distributed/run/plugins/__init__.py b/python/paddle/distributed/run/plugins/__init__.py new file mode 100644 index 0000000000000..ec91402a7aad3 --- /dev/null +++ b/python/paddle/distributed/run/plugins/__init__.py @@ -0,0 +1,50 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import six + +__all__ = [] + + +def log(ctx): + ctx.logger.info("----------- Configuration ----------------------") + for arg, value in sorted(six.iteritems(vars(ctx.args))): + ctx.logger.info("%s: %s" % (arg, value)) + ctx.logger.info("--------------------------------------------------") + + +def process_args(ctx): + # reset device by args + #argdev = ctx.args.gpus or ctx.args.xpus or ctx.args.npus + argdev = ctx.args.devices + if argdev: + ctx.node.device.labels = argdev.split(',') + ctx.node.device.count = len(ctx.node.device.labels) + ctx.logger.debug('Device reset by args {}'.format(argdev)) + + +def collective_compatible(ctx): + if 'PADDLE_TRAINER_ENDPOINTS' in ctx.envs: + ctx.master = ctx.envs['PADDLE_TRAINER_ENDPOINTS'].split(',')[0] + if 'DISTRIBUTED_TRAINER_ENDPOINTS' in ctx.envs: + ctx.master = ctx.envs['DISTRIBUTED_TRAINER_ENDPOINTS'].split(',')[0] + + +def rewrite_host_ip(ctx): + if ctx.args.host is not None and "." in ctx.args.host: + ctx.logger.warning('Host ip reset to {}'.format(ctx.args.host)) + ctx.node.ip = ctx.args.host + + +enabled_plugins = [collective_compatible, rewrite_host_ip, process_args, log] diff --git a/python/paddle/distributed/run/plugins/ip.py b/python/paddle/distributed/run/plugins/ip.py new file mode 100644 index 0000000000000..0809ed5864da9 --- /dev/null +++ b/python/paddle/distributed/run/plugins/ip.py @@ -0,0 +1,30 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import socket + + +def get_local_ip(ctx): + _, ip = _get_host_name_ip() + ctx.args.host = ip + ctx.envs["POD_IP"] = ip + + +def _get_host_name_ip(): + try: + host_name = socket.gethostname() + host_ip = socket.gethostbyname(host_name) + return host_name, host_ip + except: + return None diff --git a/python/paddle/distributed/run/utils/kv_client.py b/python/paddle/distributed/run/utils/kv_client.py new file mode 100644 index 0000000000000..e19195412268a --- /dev/null +++ b/python/paddle/distributed/run/utils/kv_client.py @@ -0,0 +1,94 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import requests +import time + + +class KVClient(object): + def __init__(self, endpoint='localhost:2379'): + self.endpoint = endpoint if endpoint.startswith( + "http://") else "http://{}".format(endpoint) + + def put(self, key, value): + key = key if key.startswith('/') else "/{}".format(key) + u = "{}{}".format(self.endpoint, key) + try: + r = requests.post(u, data=value, timeout=3) + if r.status_code == 200: + return True + else: + return False + except: + return False + + def get(self, key): + key = key if key.startswith('/') else "/{}".format(key) + u = "{}{}".format(self.endpoint, key) + try: + r = requests.get(u, timeout=3) + if r.status_code == 200: + ret = r.json() + return ret.get(key, '') + else: + return "error" + except: + return "" + + def get_prefix(self, key): + key = key if key.startswith('/') else "/{}".format(key) + u = "{}{}".format(self.endpoint, key) + try: + r = requests.get(u, timeout=3) + if r.status_code == 200: + return r.json() + except: + return "" + + def delete(self, key): + key = key if key.startswith('/') else "/{}".format(key) + u = "{}{}".format(self.endpoint, key) + try: + r = requests.delete(u, timeout=3) + if r.status_code == 200: + return True + else: + return False + except: + return False + + def wait_server_ready(self, timeout=3): + end = time.time() + timeout + while time.time() < end: + if self.get("/healthy") == "ok": + return True + + +if __name__ == '__main__': + cli = PKVClient("http://localhost:8090") + data = {"/workers/1": "rank1", "/workers/2": "rank2"} + for k, v in data.items(): + cli.put(k, v) + x = cli.get_prefix("/workers") + print(x) + for k, v in data.items(): + assert x[k] == v + + cli.put("key", "value") + print(cli.get("key")) + assert cli.get("key") == "value" + cli.delete("key") + print(cli.get("/key")) + print(cli.get("/healthy")) + assert cli.get("/healthy") == "ok" diff --git a/python/paddle/distributed/run/utils/kv_server.py b/python/paddle/distributed/run/utils/kv_server.py new file mode 100644 index 0000000000000..2d7ae15f13d63 --- /dev/null +++ b/python/paddle/distributed/run/utils/kv_server.py @@ -0,0 +1,121 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from http.server import HTTPServer +import http.server as SimpleHTTPServer + +from multiprocessing import Process + +import threading +import json + + +class KVHandler(SimpleHTTPServer.SimpleHTTPRequestHandler): + def do_GET(self): + with self.server.kv_lock: + ret = {} + for k, v in self.server.kv.items(): + if k.startswith(self.path): + ret[k] = v.decode(encoding="utf-8") + if ret: + self.output(200, json.dumps(ret).encode("utf-8")) + else: + self.output(404) + + def do_PUT(self): + self.do_POST() + + def do_POST(self): + content_length = int(self.headers['Content-Length'] or 0) + try: + value = self.rfile.read(content_length) + with self.server.kv_lock: + self.server.kv[self.path] = value + self.output(200) + return + except: + self.output(500) + + def do_DELETE(self): + with self.server.kv_lock: + if self.path in self.server.kv: + del self.server.kv[self.path] + self.output(200) + else: + self.output(404) + + def output(self, code, value=''): + self.send_response(code) + self.send_header("Content-Length", len(value)) + self.send_header("Content-Type", "application/json; charset=utf8") + self.end_headers() + if value: + self.wfile.write(value) + + def log_message(self, format, *args): + return + + +class KVServer(HTTPServer, object): + def __init__(self, port): + super(KVServer, self).__init__(('', port), KVHandler) + self.kv_lock = threading.Lock() + self.kv = {'/healthy': b'ok'} + self.port = port + self.stopped = False + self.started = False + + def start(self): + self.listen_thread = threading.Thread(target=self.serve_forever) + self.listen_thread.start() + self.started = True + + def stop(self): + self.shutdown() + self.listen_thread.join() + self.server_close() + self.stopped = True + + +class PKVServer(): + def __init__(self, port): + self._server = KVServer(port) + + def start(self): + self.proc = Process(target=self._server.start) + self.proc.daemon = True + self.proc.start() + + def stop(self): + self._server.stop() + self.proc.join() + + @property + def started(self): + return self._server.started + + @property + def stopped(self): + return self._server.stopped + + +if __name__ == '__main__': + #kv = PKVServer(8090) + kv = KVServer(8090) + kv.start() + import time + + #print("serve at 8090 for 600 s") + + time.sleep(600) diff --git a/python/paddle/distributed/run/utils/process_context.py b/python/paddle/distributed/run/utils/process_context.py new file mode 100644 index 0000000000000..4d6fa8de794ff --- /dev/null +++ b/python/paddle/distributed/run/utils/process_context.py @@ -0,0 +1,83 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import subprocess +import os, sys, signal, time + + +class ProcessContext(object): + def __init__(self, + cmd, + env=os.environ, + out=sys.stdout, + err=sys.stderr, + group=True, + preexec_fn=None): + self._cmd = cmd + self._env = env + self._preexec_fn = preexec_fn + self._stdout = out + self._stderr = err + self._group = group if os.name != 'nt' else False + self._proc = None + self._code = None + + def _start(self): + pre_fn = os.setsid if self._group else None + self._proc = subprocess.Popen( + self._cmd, + env=self._env, + stdout=self._stdout, + stderr=self._stderr, + preexec_fn=self._preexec_fn or pre_fn) + + def _close_std(self): + try: + if not self._stdout.isatty(): + self._stdout.close() + + if not self._stderr.isatty(): + self._stderr.close() + except: + pass + + def alive(self): + return self._proc and self._proc.poll() is None + + def exit_code(self): + return self._proc.poll() if self._proc else None + + def start(self): + self._start() + + def terminate(self, force=False, max_retry=3): + for i in range(max_retry): + if self.alive(): + if self._group: + os.killpg(os.getpgid(self._proc.pid), signal.SIGTERM) + else: + self._proc.terminate() + time.sleep(0.2) + else: + break + + if force and self.alive(): + self._proc.kill() + + self._close_std() + + return self.alive() + + def wait(self, timeout=None): + self._proc.wait(timeout) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index f8102ec408043..1645273a0cdd4 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -944,6 +944,7 @@ if (WITH_DISTRIBUTE AND NOT APPLE) endif() # setting timeout value as 15S +set_tests_properties(test_run PROPERTIES TIMEOUT 200) set_tests_properties(test_sync_batch_norm_op PROPERTIES TIMEOUT 120) set_tests_properties(test_cross_op PROPERTIES TIMEOUT 120) set_tests_properties(test_imperative_lod_tensor_to_selected_rows PROPERTIES TIMEOUT 200) diff --git a/python/paddle/fluid/tests/unittests/test_run.py b/python/paddle/fluid/tests/unittests/test_run.py new file mode 100644 index 0000000000000..8fe5fb9bb9455 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_run.py @@ -0,0 +1,174 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import subprocess +import sys, os +import json +import shutil + +import random + +from os import listdir +from os.path import isfile, join + +pyname = 'train.py' +colpyfile = '''# train.py for unitest +import os +env = os.environ.copy() +assert "PADDLE_MASTER" in env +assert "PADDLE_GLOBAL_SIZE" in env +assert "PADDLE_LOCAL_SIZE" in env +assert "PADDLE_GLOBAL_RANK" in env +assert "PADDLE_LOCAL_RANK" in env +''' + +pspyfile = '''# train.py for unitest +import os +env = os.environ.copy() +assert "PADDLE_PSERVERS_IP_PORT_LIST" in env +assert "PADDLE_TRAINER_ENDPOINTS" in env +#assert "PADDLE_PSERVER_ENDPOINTS" in env +#assert "PADDLE_TRAINER_ENDPOINTS" in env +#assert "PADDLE_ROLE" in env +#assert "PADDLE_RANK" in env +''' + + +def write_file(name, ct): + with open(name, "w") as f: + f.write(ct) + + +def get_files(pth, prefix): + return [ + f for f in listdir(pth) if isfile(join(pth, f)) and f.startswith(prefix) + ] + + +class Collective_Test(unittest.TestCase): + def setUp(self): + write_file(pyname, colpyfile) + + def pdrun(self, args, env=None): + cmd = [sys.executable.split('/')[-1], "-m", "paddle.distributed.run"] + if args: + cmd.extend(args.split(" ")) + cmd.extend([pyname]) + proc = subprocess.Popen(cmd, env) + return proc + + ''' + def test_collective_1(self): + args = "--id test1" + p = self.pdrun(args) + p.wait() + self.assertTrue(p.poll() == 0) + + ''' + + def test_collective_2(self): + if os.path.exists('./log'): + shutil.rmtree('./log') + + args = "--id test2 --devices 0,1,2" + p = self.pdrun(args) + p.wait() + self.assertTrue(p.poll() == 0) + + c = get_files('log', 'test2') + self.assertTrue(len(c) == 4) + + def test_collective_3(self): + if os.path.exists('./log'): + shutil.rmtree('./log') + + port = random.randrange(6000, 8000) + args = "--id test3 --devices 0,1 --master 127.0.0.1:{} --np 2".format( + port) + p1 = self.pdrun(args) + p2 = self.pdrun(args) + p1.wait() + p2.wait() + self.assertTrue(p1.poll() == 0) + self.assertTrue(p2.poll() == 0) + + c = get_files('log', 'test3') + self.assertTrue(len(c) == 6) + + +class PS_Test(unittest.TestCase): + def setUp(self): + write_file(pyname, pspyfile) + + def pdrun(self, args, env=None): + cmd = [sys.executable.split('/')[-1], "-m", "paddle.distributed.run"] + if args: + cmd.extend(args.split(" ")) + cmd.extend([pyname]) + proc = subprocess.Popen(cmd, env) + return proc + + ''' + def test_ps_1(self): + args = "--mode ps" + p = self.pdrun(args) + p.wait() + self.assertTrue(p.poll() == 0) + + def test_ps_2(self): + if os.path.exists('./log'): + shutil.rmtree('./log') + + args = "--id ps2 --server_num=2 --trainer_num=2" + p = self.pdrun(args) + p.wait() + self.assertTrue(p.poll() == 0) + + c = get_files('log', 'ps2') + self.assertTrue(len(c) == 5) + ''' + + def test_ps_3(self): + if os.path.exists('./log'): + shutil.rmtree('./log') + + port = random.randrange(6000, 8000) + args = "--id ps3 --master 127.0.0.1:{} --np 2 --server_num=1 --trainer_num=1".format( + port) + p1 = self.pdrun(args) + p2 = self.pdrun(args) + p1.wait() + p2.wait() + self.assertTrue(p1.poll() == 0) + self.assertTrue(p2.poll() == 0) + + c = get_files('log', 'ps3') + self.assertTrue(len(c) == 6) + + def test_ps_4(self): + if os.path.exists('./log'): + shutil.rmtree('./log') + + args = "--id ps4 --servers 127.0.0.1:8900,127.0.0.1:8901 --trainers 127.0.0.1:8902,127.0.0.1:8903" + p1 = self.pdrun(args) + p1.wait() + self.assertTrue(p1.poll() == 0) + + c = get_files('log', 'ps4') + self.assertTrue(len(c) == 5) + + +if __name__ == '__main__': + unittest.main()