[Feature] Build parser to support distributed training (#30658)
void-main authored Jan 25, 2021
1 parent 5b77b25 commit 904cc44
Showing 3 changed files with 79 additions and 15 deletions.
python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py
@@ -18,11 +18,17 @@
import paddle.fluid.core as core
import numpy as np
from . import ascend_parser
from paddle.distributed import fleet
import hccl.manage.api as hccl
from collections import namedtuple

HcomGroupConfig = namedtuple('HcomGroupConfig', ['name', 'nranks', 'rank_ids'])

class AscendIRParser(object):
    def __init__(self):
        self.graph_idx = 0
        self.hcom_endpoints = {}
        self.groups_to_create = []

    def _construct_input_map(self, input_varlist):
        ret_map = {}
@@ -38,8 +44,37 @@ def _construct_input_map(self, input_varlist):
            ret_map[var.name] = ge_input
        return ge_in_operator, ret_map

    def _endpoint_to_world_rank_id(self, endpoint):
        world_endpoints = fleet.worker_endpoints()
        assert endpoint in world_endpoints, "endpoint (%s) not in worker_endpoints (%s)" % (
            endpoint, world_endpoints)
        return world_endpoints.index(endpoint)

    def parse_op(self, op):
        if op.type == 'c_gen_nccl_id':
            endpoint = op.attr("endpoint")
            other_endpoints = op.attr("other_endpoints")
            rank = op.attr("rank")

            nccl_id = op.output_arg_names[0]

            # c_gen_nccl_id splits the ring's endpoints into the local endpoint
            # and other_endpoints; recombine them so world rank ids can be derived.
            self.hcom_endpoints[nccl_id] = other_endpoints[:]
            self.hcom_endpoints[nccl_id].insert(rank, endpoint)

            print("nccl_id (%s) registered endpoints %s" %
                  (nccl_id, self.hcom_endpoints[nccl_id]))
        elif op.type == 'c_comm_init':
            nccl_id = op.input_arg_names[0]
            nranks = op.attr("nranks")
            assert nranks == len(self.hcom_endpoints[nccl_id]), \
                "nranks doesn't match endpoint count"
            rank = op.attr("rank")
            ring_id = op.attr("ring_id")

            group_name = "hcom_group_" + str(ring_id)
            global_rank_ids = [
                self._endpoint_to_world_rank_id(endpoint)
                for endpoint in self.hcom_endpoints[nccl_id]
            ]
            self.groups_to_create.append(
                HcomGroupConfig(
                    name=group_name, nranks=nranks, rank_ids=global_rank_ids))
            print("append to create group: %s, with rank_ids: %s" %
                  (group_name, global_rank_ids))
        elif op.type in ascend_parser.registerd_op:
            print("Op[%s] has been registered, begin to parse it" % (op.type))
            op_parser = self.parser_factory.create_parse(
                ascend_parser.registerd_op[op.type])
            op_parser.apply(op)
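For concreteness, here is how this endpoint bookkeeping plays out for a hypothetical two-trainer ring inside a four-trainer world (a sketch; the endpoint values below are made up for illustration, while in the real flow they come from the op attributes and fleet.worker_endpoints()):

world_endpoints = ["127.0.0.1:6170", "127.0.0.1:6171",
                   "127.0.0.1:6172", "127.0.0.1:6173"]  # fleet.worker_endpoints()
endpoint = "127.0.0.1:6172"           # this trainer's "endpoint" attr
other_endpoints = ["127.0.0.1:6170"]  # the "other_endpoints" attr
rank = 1                              # this trainer's rank inside the ring

# Recombine the ring's endpoint list, then map each endpoint to its index
# in the world list -- exactly what parse_op stores via HcomGroupConfig.
ring_endpoints = other_endpoints[:]
ring_endpoints.insert(rank, endpoint)  # ['127.0.0.1:6170', '127.0.0.1:6172']
global_rank_ids = [world_endpoints.index(ep) for ep in ring_endpoints]  # [0, 2]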
@@ -137,7 +172,9 @@ def minimize(self,
                 parameter_list=None,
                 no_grad_set=None,
                 auto_dp=False):
        minimized = None
        if self.inner_opt:
            minimized = self.inner_opt.minimize(
                loss, startup_program=startup_program)

        self.ascend_instance = core.AscendInstance()

@@ -172,6 +209,10 @@ def minimize(self,
        startup_graph, main_graph = self.parser.parse_program(
            startup_program, main_block.program, input_varlist, self.fetch_list)

        for cfg in self.parser.groups_to_create:
            hccl.create_group(cfg.name, cfg.nranks, cfg.rank_ids)
            print("create group (%s), nranks: %d, rank_ids: %s" %
                  (cfg.name, cfg.nranks, cfg.rank_ids))

        self.ascend_instance.add_ascend_subgraph(0, startup_graph)
        self.ascend_instance.add_ascend_subgraph(1, main_graph)
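In other words, parse_program fills parser.groups_to_create as a side effect of walking the c_gen_nccl_id/c_comm_init ops, and minimize drains that list through hccl before the subgraphs run. A minimal sketch of the hand-off, using the same calls as above:

parser = AscendIRParser()
startup_graph, main_graph = parser.parse_program(
    startup_program, main_program, input_varlist, fetch_list)
for cfg in parser.groups_to_create:
    # e.g. HcomGroupConfig(name='hcom_group_0', nranks=2, rank_ids=[0, 1])
    hccl.create_group(cfg.name, cfg.nranks, cfg.rank_ids)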

python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py
@@ -1,21 +1,21 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.fluid.framework as framework
from paddle.fluid.optimizer import Optimizer
import paddle.fluid.core as core
import numpy as np
from paddle.distributed import fleet

registerd_op = {
    "elementwise_add": "AddParser",
@@ -555,7 +555,8 @@ def __init__(self, graph, var2geop, reduction):
    def _apply(self):
        x = self._get_ge_input(self.op.input_arg_names[0])
        reduction = self.reduction
        ring_id = self.op.attr("ring_id")
        group = "hcom_group_" + str(ring_id)
        fusion = None  # self.op.attr("fusion")
        fusion_id = None  # self.op.attr("fusion_id")
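This replaces the previous hard-coded "hccl_world_group": the group name is now derived from the op's ring_id with the same formula parse_op uses when it registers the group, so the collective runs on the group created for its own ring. A sketch of the shared naming convention (the ring_id value is illustrative):

ring_id = 2                           # from op.attr("ring_id")
group = "hcom_group_" + str(ring_id)  # 'hcom_group_2' on both sides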

@@ -658,12 +659,13 @@ def _apply(self):
"shape", shape).set_attr_int32("dtype", dtype)
return [receive], [[0]]


class ScaleParser(AscendParserBase):
    def __init__(self, graph, var2geop):
        super(ScaleParser, self).__init__(graph, var2geop)
        self.parser_name = "scale"

    def _apply(self):
        x = self._get_ge_input(self.op.input_arg_names[0])
        scale = self.op.attr("scale")  # self.get_ge_input(self.op.input_arg_names[1])
        bias = self.op.attr("bias")
@@ -672,9 +674,9 @@ def _apply(self):
            scale_value = core.GEOperatorFactory.create_operator("scale" + self._accumulated_op_id(), "Power").set_input("x", x).set_attr_float("power", 1.0).set_attr_float("scale", scale).set_attr_float("shift", bias)
        else:
            x_add_bias = core.GEOperatorFactory.create_operator("adds" + self._accumulated_op_id(), "Adds").set_input("x", x).set_attr_float("value", bias)  # set_input("x2", bias)
            scale_value = core.GEOperatorFactory.create_operator("scale" + self._accumulated_op_id(), "Power").set_input("x", x_add_bias).set_attr_float("power", 1.0).set_attr_float("scale", scale).set_attr_float("shift", 0.0)
        # tensor_zeros = core.GEOperatorFactory.create_operator("zeroslike" + self.getid(), "ZerosLike").set_input("x", x)
        # bias_ = self.create_ge_tensor([1], 5, bias)
        # const_bias = core.GEOperatorFactory.create_operator("const" + self.getid(), "Const").set_attr_tensor("value", tensor_bias)
        return [scale_value], [[0]]
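The two branches mirror Paddle's scale semantics: the bias is applied either after or before scaling (the branch condition sits outside this hunk, so which attribute drives it is an assumption here). A NumPy sketch of both cases:

import numpy as np

x, scale, bias = np.array([1.0, 2.0]), 3.0, 0.5
after = scale * x + bias     # first branch: Power(power=1.0, scale=scale, shift=bias)
before = scale * (x + bias)  # second branch: Adds(bias) then Power(..., shift=0.0)
# after -> [3.5, 6.5], before -> [4.5, 7.5]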

@@ -695,5 +697,7 @@ def _apply(self):
        tensor = self._create_ge_tensor([len(shape)], 2, shape)
        const_shape = core.GEOperatorFactory.create_operator("shape" + self._accumulated_op_id(), "Const").set_attr_tensor("value", tensor)
        reshape = core.GEOperatorFactory.create_operator("reshape" + self._accumulated_op_id(), "Reshape").set_input("x", data_x1_shape).set_input("shape", const_shape).set_attr_int32("axis", axis)

        return [reshape, reshape], [[0], [1]]


27 changes: 23 additions & 4 deletions python/paddle/fluid/tests/unittests/ascend_group.py
@@ -21,6 +21,11 @@
import paddle
from paddle.fluid.layer_helper import LayerHelper
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_optimizers.ascend import ascend_parser, ascend_optimizer
from collections import namedtuple

Block = namedtuple('Block', ['program'])
Loss = namedtuple('Loss', ['block'])

paddle.enable_static()

@@ -63,10 +68,6 @@ def init_communicator(startup_program, main_program, current_endpoint, endpoints
            'ring_id': ring_id,
            OP_ROLE_KEY: OpRole.Forward,
        })

    with fluid.program_guard(main_program):
        op_type = "c_allreduce_sum"
@@ -79,6 +80,9 @@
            attrs={'ring_id': ring_id,
                   'use_calc_stream': True})

    print("startup program:", startup_program)
    print("main program:", main_program)

def train(world_endpoints, world_device_ids, local_device_ids, local_rank):
    startup_programs = []
    main_programs = []
@@ -89,6 +93,7 @@ def train(world_endpoints, world_device_ids, local_device_ids, local_rank):
    groups[0] = [trainer_endpoints[0], trainer_endpoints[1]]
    groups[1] = [trainer_endpoints[2], trainer_endpoints[3]]
    groups[2] = [trainer_endpoints[0], trainer_endpoints[2]]
    print("groups:", groups)

    for i in range(len(trainer_endpoints)):
        startup_programs.append(fluid.Program())
@@ -105,6 +110,20 @@ def train(world_endpoints, world_device_ids, local_device_ids, local_rank):
    print(startup_programs[local_rank])
    print(main_programs[local_rank])

print("local rank: ", local_rank)
print("local startup program: ", startup_programs[local_rank])

startup_program = startup_programs[local_rank]
main_program = main_programs[local_rank]
loss = Loss(Block(main_program))
optimizer = ascend_optimizer.AscendOptimizer(None, fetch_list=[])
optimizer.minimize(loss, startup_program, auto_dp=True)

exe = paddle.static.Executor(paddle.CPUPlace())
#exe.run(startup_program)
exe.run(main_program)
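The Loss and Block namedtuples defined at the top of this test exist only to satisfy AscendOptimizer.minimize, which reads loss.block.program; any object with that shape works:

loss = Loss(Block(main_program))
assert loss.block.program is main_program  # all minimize() needs from the fake loss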


worker_endpoints = fleet.worker_endpoints()
world_device_ids = fleet.world_device_ids()
local_device_ids = fleet.local_device_ids()
