adjust comm init for static graph (PaddlePaddle#57169)
* fix conflicts

* fix bkcl compile

* format code style

* fix ut

* fix conflicts with develop

* update

* update
wentaoyu authored Sep 19, 2023
1 parent 8ed0ea8 commit a9956d9
Showing 12 changed files with 90 additions and 13 deletions.
7 changes: 7 additions & 0 deletions paddle/fluid/distributed/fleet_executor/dist_model.cc
@@ -299,6 +299,10 @@ void DistModel::InsertCommOp(std::string tmp_var_name,
ss << ep << ", ";
}
VLOG(3) << ss.str();
std::string endpoints_str = config_.current_endpoint;
for (const auto &peer : peer_endpoints) {
endpoints_str += "," + peer;
}
if (config_.place == "GPU") {
framework::VarDesc *new_var = block->Var(tmp_var_name);
new_var->SetType(framework::proto::VarType::RAW);
@@ -319,6 +323,7 @@ void DistModel::InsertCommOp(std::string tmp_var_name,
comm_init_op->SetAttr("rank", rank);
comm_init_op->SetAttr("nranks", nranks);
comm_init_op->SetAttr("ring_id", ring_id);
comm_init_op->SetAttr("endpoints", endpoints_str);
comm_init_op->SetAttr("op_role",
static_cast<int>(framework::OpRole::kForward));
comm_init_op->CheckAttrs();
@@ -342,6 +347,7 @@ void DistModel::InsertCommOp(std::string tmp_var_name,
comm_init_op->SetAttr("rank", rank);
comm_init_op->SetAttr("nranks", nranks);
comm_init_op->SetAttr("ring_id", ring_id);
comm_init_op->SetAttr("endpoints", endpoints_str);
comm_init_op->SetAttr("op_role",
static_cast<int>(framework::OpRole::kForward));
comm_init_op->CheckAttrs();
@@ -365,6 +371,7 @@ void DistModel::InsertCommOp(std::string tmp_var_name,
comm_init_op->SetAttr("rank", rank);
comm_init_op->SetAttr("nranks", nranks);
comm_init_op->SetAttr("ring_id", ring_id);
comm_init_op->SetAttr("endpoints", endpoints_str);
comm_init_op->SetAttr("op_role",
static_cast<int>(framework::OpRole::kForward));
comm_init_op->CheckAttrs();
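
For reference, the InsertCommOp changes in this file build the endpoints string the same way the Python-side changes below do with ",".join(endpoints): the current endpoint first, then each peer, comma-separated. A small illustrative Python equivalent (the endpoint values here are made up):

current_endpoint = "127.0.0.1:6170"                      # illustrative value
peer_endpoints = ["127.0.0.1:6171", "127.0.0.1:6172"]    # illustrative values

# Same result as the C++ loop above: current endpoint first, then the peers.
endpoints_str = ",".join([current_endpoint] + peer_endpoints)
print(endpoints_str)  # -> 127.0.0.1:6170,127.0.0.1:6171,127.0.0.1:6172
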
7 changes: 7 additions & 0 deletions paddle/fluid/inference/api/analysis_predictor.cc
@@ -838,6 +838,10 @@ void AnalysisPredictor::InsertCommOp(
ss << ep << ", ";
}
VLOG(3) << ss.str();
std::string endpoints_str = config_.dist_config().current_endpoint();
for (const auto &peer : peer_endpoints) {
endpoints_str += "," + peer;
}
if (config_.use_gpu()) {
framework::VarDesc *new_var = block->Var(tmp_var_name);
new_var->SetType(framework::proto::VarType::RAW);
@@ -859,6 +863,7 @@ void AnalysisPredictor::InsertCommOp(
comm_init_op->SetAttr("rank", rank);
comm_init_op->SetAttr("nranks", nranks);
comm_init_op->SetAttr("ring_id", ring_id);
comm_init_op->SetAttr("endpoints", endpoints_str);
comm_init_op->SetAttr("op_role",
static_cast<int>(framework::OpRole::kForward));
comm_init_op->CheckAttrs();
@@ -883,6 +888,7 @@ void AnalysisPredictor::InsertCommOp(
comm_init_op->SetAttr("rank", rank);
comm_init_op->SetAttr("nranks", nranks);
comm_init_op->SetAttr("ring_id", ring_id);
comm_init_op->SetAttr("endpoints", endpoints_str);
comm_init_op->SetAttr("op_role",
static_cast<int>(framework::OpRole::kForward));
comm_init_op->CheckAttrs();
@@ -907,6 +913,7 @@ void AnalysisPredictor::InsertCommOp(
comm_init_op->SetAttr("rank", rank);
comm_init_op->SetAttr("nranks", nranks);
comm_init_op->SetAttr("ring_id", ring_id);
comm_init_op->SetAttr("endpoints", endpoints_str);
comm_init_op->SetAttr("op_role",
static_cast<int>(framework::OpRole::kForward));
comm_init_op->CheckAttrs();
1 change: 1 addition & 0 deletions paddle/fluid/inference/utils/CMakeLists.txt
@@ -19,6 +19,7 @@ cc_test_old(
DEPS
infer_io_utils
fleet_executor
parallel_executor
python)

if(WITH_ONNXRUNTIME AND WIN32)
34 changes: 32 additions & 2 deletions paddle/fluid/operators/collective/c_comm_init_op.cc
@@ -29,6 +29,15 @@ limitations under the License. */
#include "paddle/fluid/platform/collective_helper.h"
#endif

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#endif

#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/store/store_utils.h"
#include "paddle/phi/core/distributed/store/tcp_store.h"

namespace paddle {
namespace framework {
class Scope;
@@ -95,8 +104,6 @@ class CCommInitOp : public framework::OperatorBase {
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::InvalidArgument("Input con not be empty."));

UniqueId* comm_id = var->GetMutable<UniqueId>();

int nranks = Attr<int>("nranks");
int rid = Attr<int>("ring_id");

@@ -105,6 +112,25 @@ class CCommInitOp : public framework::OperatorBase {
device_id = Attr<int>("device_id");
}
int rank_id = Attr<int>("rank");
#endif
#if defined(PADDLE_WITH_NCCL)
const char* dynamic_static_unified_comm =
getenv("FLAGS_dynamic_static_unified_comm");
if (dynamic_static_unified_comm &&
std::string(dynamic_static_unified_comm) == "1") {
VLOG(3) << "#### use new comm lab ####";
auto store = phi::distributed::CreateOrGetGlobalTCPStore();
phi::distributed::CommContextManager::SetDeviceId(device_id);
std::string endpoints = Attr<std::string>("endpoints");
phi::distributed::CommContextManager::CreateNCCLCommContext(
store, std::to_string(rid), rank_id, nranks, endpoints);
return;
}
#endif
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL)
VLOG(3) << "#### use old comm lab ####";
UniqueId* comm_id = var->GetMutable<UniqueId>();
CommContext::Instance().CreateComm(
comm_id, nranks, rank_id, device_id, rid);
#endif
@@ -131,6 +157,10 @@ Initialize collective communication context within this trainer
.SetDefault(-1);
AddAttr<int>("ring_id", "(int default 0) user specified ring id")
.SetDefault(0);
AddAttr<std::string>("endpoints",
"['trainer1_ip:port', 'trainer2_ip:port', ...] "
"list of other trainer endpoints")
.SetDefault("");
}
};

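With the branch above, c_comm_init either falls through to the legacy UniqueId-based CreateComm or, when FLAGS_dynamic_static_unified_comm is "1", builds an NCCL comm context from the global TCP store using the new endpoints attribute. Below is a minimal, hedged sketch of appending the op from the Python side with that attribute, mirroring the hapi/model.py and distributed_fused_lamb.py changes later in this commit; the variable names and endpoint values are illustrative, and the id variable is assumed to be filled by a preceding c_gen_nccl_id op (not shown).

import paddle
from paddle.framework import core

paddle.enable_static()

# Illustrative values only; in real use these come from the launcher/env.
endpoints = ["127.0.0.1:6170", "127.0.0.1:6171"]
rank, nranks, ring_id = 0, len(endpoints), 0

startup_prog = paddle.static.Program()
block = startup_prog.global_block()

# Assumed to be produced by a preceding c_gen_nccl_id op (not shown here).
nccl_id_var = block.create_var(
    name="nccl_id_0", persistable=True, type=core.VarDesc.VarType.RAW
)

block.append_op(
    type="c_comm_init",
    inputs={"X": nccl_id_var},
    outputs={},
    attrs={
        "nranks": nranks,
        "rank": rank,
        "ring_id": ring_id,
        # New attribute from this commit: comma-separated endpoint list used
        # by the CommContextManager path when
        # FLAGS_dynamic_static_unified_comm=1.
        "endpoints": ",".join(endpoints),
    },
)
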
21 changes: 13 additions & 8 deletions paddle/fluid/operators/collective/c_gen_nccl_id_op.cc
@@ -66,18 +66,23 @@ class CGenNCCLIdOp : public framework::OperatorBase {
};

std::string endpoint = Attr<std::string>("endpoint");
int server_fd = platform::SocketServer::GetInstance(endpoint).socket();

std::vector<ncclUniqueId> nccl_ids;
nccl_ids.resize(1);

if (rank == 0) {
GenNCCLID(&nccl_ids);
std::vector<std::string> endpoint_list =
Attr<std::vector<std::string>>("other_endpoints");
platform::SendBroadCastCommID(endpoint_list, &nccl_ids, ring_id);
} else {
platform::RecvBroadCastCommID(server_fd, endpoint, &nccl_ids, ring_id);
const char* dynamic_static_unified_comm =
getenv("FLAGS_dynamic_static_unified_comm");
if (!dynamic_static_unified_comm ||
std::string(dynamic_static_unified_comm) != "1") {
int server_fd = platform::SocketServer::GetInstance(endpoint).socket();
if (rank == 0) {
GenNCCLID(&nccl_ids);
std::vector<std::string> endpoint_list =
Attr<std::vector<std::string>>("other_endpoints");
platform::SendBroadCastCommID(endpoint_list, &nccl_ids, ring_id);
} else {
platform::RecvBroadCastCommID(server_fd, endpoint, &nccl_ids, ring_id);
}
}

CopyNCCLIDToVar(nccl_ids, func, scope);
@@ -286,6 +286,8 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::GPUContext, T> {
comm->comm(),
stream));
}

// step 4, obtain exp(logit)
eigen_softmax.device(*dev_ctx.eigen_device()) = eigen_softmax.exp();

// step 5, obtain sum_exp_logits
8 changes: 7 additions & 1 deletion python/paddle/distributed/fleet/meta_optimizers/common.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import paddle
from paddle.framework import core
@@ -91,13 +92,16 @@ def _init_communicator(
):
# if current_endpoint is None, it means just for sync,
# no group is created.
endpoints_str = ",".join(endpoints)
if current_endpoint:
nranks = len(endpoints)
other_endpoints = endpoints[:]
other_endpoints.remove(current_endpoint)

if rank == 0 and wait_port:
wait_server_ready(other_endpoints)
use_new_comm = os.getenv("FLAGS_dynamic_static_unified_comm", "0")
if use_new_comm not in [1, "1", "True", "true"]:
wait_server_ready(other_endpoints)

def _add_sync_by_allreduce(block):
sync_var = block.create_var(
@@ -168,6 +172,7 @@ def _add_sync_by_allreduce(block):
'nranks': nranks,
'rank': rank,
'ring_id': ring_id,
'endpoints': endpoints_str,
OP_ROLE_KEY: OpRole.Forward,
},
)
@@ -192,6 +197,7 @@ def _add_sync_by_allreduce(block):
'nranks': nranks,
'rank': rank,
'ring_id': ring_id,
'endpoints': endpoints_str,
OP_ROLE_KEY: OpRole.Forward,
},
)
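As the gate above shows, the new initialization is opted into through an environment flag rather than an API change. A hedged usage sketch of the switch (the flag name comes from this commit; the check below only mirrors the one added in _init_communicator):

import os

# "1" selects the unified, TCP-store-based initialization wired into
# c_comm_init by this commit; unset or "0" keeps the legacy
# c_gen_nccl_id + broadcast path.
os.environ["FLAGS_dynamic_static_unified_comm"] = "1"

use_new_comm = os.getenv("FLAGS_dynamic_static_unified_comm", "0")
if use_new_comm not in [1, "1", "True", "true"]:
    # Legacy path: rank 0 still needs the peers' socket servers to be up
    # before broadcasting the NCCL unique id.
    print("would call wait_server_ready(other_endpoints)")
else:
    # New path: rendezvous is handled by the global TCP store instead.
    print("skipping wait_server_ready; using CommContextManager")
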
@@ -716,7 +716,9 @@ def minimize_impl(
self._recreate_not_persist_param_as_var()

self._dump_program_for_debug()
self._wait()
use_new_comm = os.getenv("FLAGS_dynamic_static_unified_comm", "0")
if use_new_comm not in ["1", "True", "true"]:
self._wait()
return optimize_ops, params_grads

def _init_pair_comm(self, pair, ring_id):
4 changes: 4 additions & 0 deletions python/paddle/distributed/ps/utils/collective_transpiler.py
@@ -125,6 +125,7 @@ def _init_communicator(
wait_port,
has_multitrainer=False,
):
endpoints_str = ",".join(endpoints)
nranks = len(endpoints)
other_endpoints = endpoints[:]
other_endpoints.remove(current_endpoint)
@@ -161,6 +162,7 @@ def _init_communicator(
'nranks': nranks,
'rank': rank,
'ring_id': ring_id,
'endpoints': endpoints_str,
self.op_role_key: OpRole.Forward,
},
)
@@ -190,6 +192,7 @@ def _init_communicator(
'nranks': nranks,
'rank': rank,
'ring_id': ring_id,
'endpoints': endpoints_str,
self.op_role_key: OpRole.Forward,
},
)
@@ -234,6 +237,7 @@ def _init_communicator(
'nranks': nranks,
'rank': rank,
'ring_id': ring_id,
'endpoints': endpoints_str,
self.op_role_key: OpRole.Forward,
},
)
4 changes: 4 additions & 0 deletions python/paddle/distributed/transpiler/collective.py
@@ -123,6 +123,7 @@ def _init_communicator(
wait_port,
has_multitrainer=False,
):
endpoints_str = ",".join(endpoints)
nranks = len(endpoints)
other_endpoints = endpoints[:]
other_endpoints.remove(current_endpoint)
@@ -158,6 +159,7 @@ def _init_communicator(
'nranks': nranks,
'rank': rank,
'ring_id': ring_id,
'endpoints': endpoints_str,
self.op_role_key: OpRole.Forward,
},
)
@@ -198,6 +200,7 @@ def _init_communicator(
'nranks': nranks,
'rank': rank,
'ring_id': ring_id,
'endpoints': endpoints_str,
self.op_role_key: OpRole.Forward,
},
)
@@ -229,6 +232,7 @@ def _init_communicator(
'nranks': nranks,
'rank': rank,
'ring_id': ring_id,
'endpoints': endpoints_str,
self.op_role_key: OpRole.Forward,
},
)
4 changes: 4 additions & 0 deletions python/paddle/hapi/model.py
@@ -122,6 +122,7 @@ def init_communicator(
):
if nranks < 2:
return
endpoints_str = ",".join(endpoints)
other_endpoints = endpoints[:]
other_endpoints.remove(current_endpoint)
block = program.global_block()
@@ -153,6 +154,7 @@ def init_communicator(
'nranks': nranks,
'rank': rank,
'ring_id': 0,
'endpoints': endpoints_str,
},
)
elif core.is_compiled_with_xpu():
@@ -181,6 +183,7 @@ def init_communicator(
'nranks': nranks,
'rank': rank,
'ring_id': 0,
'endpoints': endpoints_str,
},
)
elif (
@@ -212,6 +215,7 @@ def init_communicator(
'nranks': nranks,
'rank': rank,
'ring_id': 0,
'endpoints': endpoints_str,
},
)

7 changes: 6 additions & 1 deletion python/paddle/incubate/optimizer/distributed_fused_lamb.py
@@ -77,7 +77,12 @@ def init_communicator(block, rank, ranks, ring_id):
type='c_comm_init',
inputs={'X': comm_id_var},
outputs={},
attrs={'nranks': len(ranks), 'rank': local_rank, 'ring_id': ring_id},
attrs={
'nranks': len(ranks),
'rank': local_rank,
'ring_id': ring_id,
'endpoints': ','.join(eps),
},
)
tmp_var = block.create_var(name=unique_name.generate('tmp'))
block.append_op(
