Skip to content

Commit

Permalink
fleet support elastic scale up/down (#36684)
Browse files Browse the repository at this point in the history
* fleet support elastic train

* fleet support elastic train

* support elastic

* add unittest

* fix unitest bug

* fix unittest bug

* fix unittest bug

* fix unittest coverage

* fix unittest coverage

* fix unittest coverage

* fix unittest coverage

* fix unittest coverage

* fix elastic bug

* fix ci fail

* fix ci fail

* fix elastic bug

* fix elastic bug

* fix joint debugging bug

* fix joint debugging bug

* fix windows ci failed

* fix windows ci failed
  • Loading branch information
xymyeah authored Nov 11, 2021
1 parent 9a9345f commit 6af531b
Show file tree
Hide file tree
Showing 7 changed files with 761 additions and 97 deletions.
7 changes: 5 additions & 2 deletions python/paddle/distributed/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,18 @@ def close(self):
parser.add_argument(
"--elastic_server", type=str, help="etcd server host:port")
parser.add_argument("--job_id", type=str, help="job unique id")
parser.add_argument("--np", type=int, help="job pod/node number")
parser.add_argument(
"--np",
type=str,
help="job pod/node number, need to be 'MIN' or 'MIN:MAX' format")
parser.add_argument("action", type=str, help="action to take")

args = parser.parse_args()

server = args.elastic_server or os.getenv('PADDLE_ELASTIC_SERVER')
name = args.job_id or os.getenv('PADDLE_ELASTIC_JOB_ID')

np = args.np or int(os.getenv('PADDLE_ELASTIC_NP', 0))
np = int(args.np.split(":")[0]) or int(os.getenv('PADDLE_ELASTIC_NP', 0))

cmd = Command(server, name)

Expand Down
8 changes: 6 additions & 2 deletions python/paddle/distributed/fleet/elastic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,19 @@ def enable_elastic(args, distribute_mode):
if not args.job_id and not os.getenv('PADDLE_ELASTIC_JOB_ID'):
return False

if not args.np and not int(os.getenv('PADDLE_ELASTIC_NP', 0)):
if not args.np and not os.getenv('PADDLE_ELASTIC_NP'):
return False

return True


def launch_elastic(args, distribute_mode):

elastic = ElasticManager(args)
server = args.elastic_server or os.getenv('PADDLE_ELASTIC_SERVER')
srv, port = server.split(':')
import etcd3
etcd_client = etcd3.client(host=srv, port=port)
elastic = ElasticManager(args, etcd_client)

signal.signal(signal.SIGTERM, elastic.signal_handler)
signal.signal(signal.SIGABRT, elastic.signal_handler)
Expand Down
1 change: 1 addition & 0 deletions python/paddle/distributed/fleet/elastic/collective.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
from paddle.distributed.fleet import launch_utils
from paddle.distributed.fleet import cloud_utils
from paddle.distributed.fleet import ascend_utils
Expand Down
Loading

0 comments on commit 6af531b

Please sign in to comment.