[CustomDevice] use CommContextManager to create xccl comm #57957

Merged
merged 1 commit on Oct 10, 2023
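At a glance, this PR replaces lookups through the per-device-type platform::XCCLCommContext singleton with the process-wide phi::distributed::CommContextManager keyed by the ring id. Below is a minimal sketch of the retrieval pattern used in the kernels in this diff, assuming rid holds the kernel's ring_id attribute as read in those kernels; the surrounding kernel boilerplate is omitted:

#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/xccl_comm_context.h"

// Fetch the communicator registered under the ring id (new pattern), instead of
// the old platform::XCCLCommContext::Instance(device_type).Get(rid, place) call.
auto* comm = reinterpret_cast<phi::distributed::XCCLCommContext*>(
    phi::distributed::CommContextManager::GetInstance().Get(std::to_string(rid)));
int nranks = comm->GetSize();          // was comm->nranks()
int rank = comm->GetRank();            // was comm->rank()
auto stream = comm->GetStream();       // was comm->stream()
auto xccl_comm = comm->GetXcclComm();  // was comm->comm()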
@@ -494,7 +494,7 @@ void ProcessGroupCustom::CreateXCCLEnvCache(const Place& place,
<< ", place: " << place_key;

phi::distributed::CommContextManager::CreateXCCLCommContext(
store_, std::to_string(gid_), place.GetDeviceType(), rank_, size_);
store_, std::to_string(gid_), place, rank_, size_);

auto* calc_ctx = static_cast<phi::CustomContext*>(
platform::DeviceContextPool::Instance().Get(place));
15 changes: 11 additions & 4 deletions paddle/fluid/operators/collective/c_comm_init_op.cc
@@ -63,8 +63,6 @@ class CCommInitOp : public framework::OperatorBase {
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::InvalidArgument("Input con not be empty."));

phi::ccl::CCLRootId* comm_id = var->GetMutable<phi::ccl::CCLRootId>();

int nranks = Attr<int>("nranks");
int rid = Attr<int>("ring_id");

@@ -73,8 +71,17 @@
device_id = Attr<int>("device_id");
}
int rank_id = Attr<int>("rank");
platform::XCCLCommContext::Instance(place.GetDeviceType())
.CreateComm(comm_id, nranks, rank_id, device_id, rid);
auto store = phi::distributed::CreateOrGetGlobalTCPStore();
if (!phi::distributed::CommContextManager::GetInstance().Has(
std::to_string(rid))) {
phi::distributed::CommContextManager::CreateXCCLCommContext(
store,
std::to_string(rid),
phi::CustomPlace(place.GetDeviceType(), device_id),
rank_id,
nranks,
"c_comm_init_op");
}

Review comment (on the CreateXCCLCommContext call above):
Please check whether this needs to stay compatible with both the old and the new communication libraries; in the NCCL path, the FLAGS_dynamic_static_unified_comm flag is used to switch between them.
The check at lines 124-127 has already been updated on the develop branch, and this PR does not report a conflict, so please merge develop manually to avoid overwriting that change. (The current flag check has a problem.)

Contributor Author:
Custom devices no longer need to be compatible with the old communication library.
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with custom device."));
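For context on the review thread above: the reviewer points to FLAGS_dynamic_static_unified_comm, which the NCCL build uses to switch between the old and new communication libraries. The following is a hypothetical sketch of what a flag-gated version of this custom-device branch could look like, assembled only from calls already visible in this diff (the old XCCLCommContext::CreateComm path and the new CommContextManager path); the author's reply above makes clear this is not needed for custom devices, so it is illustration only:

// Hypothetical flag-gated variant of CCommInitOp::RunImpl for custom devices;
// FLAGS_dynamic_static_unified_comm is the flag named in the review comment.
if (FLAGS_dynamic_static_unified_comm) {
  // New communication library: register through CommContextManager.
  auto store = phi::distributed::CreateOrGetGlobalTCPStore();
  if (!phi::distributed::CommContextManager::GetInstance().Has(
          std::to_string(rid))) {
    phi::distributed::CommContextManager::CreateXCCLCommContext(
        store,
        std::to_string(rid),
        phi::CustomPlace(place.GetDeviceType(), device_id),
        rank_id,
        nranks,
        "c_comm_init_op");
  }
} else {
  // Old communication library: the path this PR removes.
  phi::ccl::CCLRootId* comm_id = var->GetMutable<phi::ccl::CCLRootId>();
  platform::XCCLCommContext::Instance(place.GetDeviceType())
      .CreateComm(comm_id, nranks, rank_id, device_id, rid);
}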
29 changes: 1 addition & 28 deletions paddle/fluid/operators/collective/c_gen_xccl_id_op.cc
@@ -52,34 +52,7 @@ class CGenXCCLIdOp : public framework::OperatorBase {
: OperatorBase(type, inputs, outputs, attrs) {}

void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
int rank = Attr<int>("rank");
int ring_id = Attr<int>("ring_id");

std::function<std::string(size_t)> func = [&](size_t i) -> std::string {
return Output("Out");
};

std::string endpoint = Attr<std::string>("endpoint");
int server_fd = platform::SocketServer::GetInstance(endpoint).socket();

std::vector<phi::ccl::CCLRootId> xccl_ids;
xccl_ids.resize(1);

if (rank == 0) {
for (size_t i = 0; i < xccl_ids.size(); ++i) {
phi::DeviceManager::CCLGetUniqueId(dev_place.GetDeviceType(),
&xccl_ids[i]);
}
std::vector<std::string> endpoint_list =
Attr<std::vector<std::string>>("other_endpoints");
platform::SendBroadCastCommID(endpoint_list, &xccl_ids, ring_id);
} else {
platform::RecvBroadCastCommID(server_fd, endpoint, &xccl_ids, ring_id);
}

CopyXCCLIDToVar(xccl_ids, func, scope);
}
const platform::Place& dev_place) const override {}
};

#else
115 changes: 74 additions & 41 deletions paddle/fluid/operators/custom_device_common_op_registry.cc
@@ -22,6 +22,8 @@ limitations under the License. */
#include "paddle/phi/api/backward/backward_api.h"
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/backends/device_manager.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/xccl_comm_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"

@@ -99,13 +101,14 @@ class CConcatOpCustomDeviceKernel : public framework::OpKernel<T> {
auto task = pg->AllGather(in_tensor, out_tensor);
task->Wait();
} else {
auto comm = platform::XCCLCommContext::Instance(place.GetDeviceType())
.Get(rid, place);
auto comm = reinterpret_cast<phi::distributed::XCCLCommContext*>(
phi::distributed::CommContextManager::GetInstance().Get(
std::to_string(rid)));
PADDLE_ENFORCE_EQ(
nranks,
comm->nranks(),
comm->GetSize(),
platform::errors::InvalidArgument(
"nranks: %s should equal to %s", nranks, comm->nranks()));
"nranks: %s should equal to %s", nranks, comm->GetSize()));

int64_t send_numel = x->numel();
const T* send_buff = x->data<T>();
@@ -118,7 +121,7 @@ class CConcatOpCustomDeviceKernel : public framework::OpKernel<T> {
recv_buff,
send_numel,
phi::ccl::ToCCLDataType(x->dtype()),
comm->comm(),
comm->GetXcclComm(),
stream);
}
std::vector<phi::DenseTensor> inputs;
@@ -600,25 +603,25 @@ class CAllReduceOpCustomDeviceKernel : public framework::OpKernel<T> {
return;
}

auto comm =
paddle::platform::XCCLCommContext::Instance(place.GetDeviceType())
.Get(rid, place);
auto comm = reinterpret_cast<phi::distributed::XCCLCommContext*>(
phi::distributed::CommContextManager::GetInstance().Get(
std::to_string(rid)));

std::shared_ptr<phi::stream::Stream> stream;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = paddle::platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<paddle::platform::CustomDeviceContext*>(dev_ctx)
->GetStream();
} else {
stream = comm->stream();
stream = comm->GetStream();
}
phi::DeviceManager::CCLAllReduce(place.GetDeviceType(),
const_cast<void*>(sendbuff),
recvbuff,
numel,
dtype,
red_type,
comm->comm(),
comm->GetXcclComm(),
*stream);
}
};
@@ -634,22 +637,30 @@ class CBroadcastOpCustomDeviceKernel : public framework::OpKernel<T> {
int root = ctx.Attr<int>("root");
int rid = ctx.Attr<int>("ring_id");

auto stream = static_cast<const phi::CustomContext&>(ctx.device_context())
.GetStream();
auto comm = reinterpret_cast<phi::distributed::XCCLCommContext*>(
phi::distributed::CommContextManager::GetInstance().Get(
std::to_string(rid)));

std::shared_ptr<phi::stream::Stream> stream;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = paddle::platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<paddle::platform::CustomDeviceContext*>(dev_ctx)
->GetStream();
} else {
stream = comm->GetStream();
}

int numel = x->numel();
auto dtype = phi::ccl::ToCCLDataType(x->dtype());
auto comm = platform::XCCLCommContext::Instance(place.GetDeviceType())
.Get(rid, place);
if (root == comm->rank()) {
if (root == comm->GetRank()) {
phi::DeviceManager::CCLBroadcast(place.GetDeviceType(),
const_cast<void*>(x->data()),
numel,
dtype,
root,
comm->comm(),
comm->GetXcclComm(),
*stream);
VLOG(3) << "rank " << comm->rank() << " invoke Bcast. sent "
VLOG(3) << "rank " << comm->GetRank() << " invoke Bcast. sent "
<< x->numel();
if (out != x) {
framework::TensorCopy(
@@ -664,9 +675,9 @@ class CBroadcastOpCustomDeviceKernel : public framework::OpKernel<T> {
numel,
dtype,
root,
comm->comm(),
comm->GetXcclComm(),
*stream);
VLOG(3) << "rank " << comm->rank() << " invoke Bcast. received "
VLOG(3) << "rank " << comm->GetRank() << " invoke Bcast. received "
<< phi::product(out->dims());
}
out->set_lod(x->lod());
@@ -684,16 +695,27 @@ class BarrierOpCustomDeviceKernel : public framework::OpKernel<T> {
const void* sendbuff = in->data();
void* recvbuff = ctx.device_context().Alloc<T>(out);
int rid = ctx.Attr<int>("ring_id");
auto comm = platform::XCCLCommContext::Instance(place.GetDeviceType())
.Get(rid, place);

auto comm = reinterpret_cast<phi::distributed::XCCLCommContext*>(
phi::distributed::CommContextManager::GetInstance().Get(
std::to_string(rid)));

std::shared_ptr<phi::stream::Stream> stream;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = paddle::platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<paddle::platform::CustomDeviceContext*>(dev_ctx)
->GetStream();
} else {
stream = comm->GetStream();
}
phi::DeviceManager::CCLAllReduce(place.GetDeviceType(),
const_cast<void*>(sendbuff),
recvbuff,
numel,
phi::ccl::ToCCLDataType(in->dtype()),
phi::ccl::CCLReduceOp::SUM,
comm->comm(),
*(comm->stream()));
comm->GetXcclComm(),
*stream);
}
};

@@ -993,16 +1015,22 @@ class GlobalScatterOpCustomDeviceKernel : public framework::OpKernel<T> {
}
}
} else {
auto comm = platform::XCCLCommContext::Instance(place.GetDeviceType())
.Get(rid, place);
auto comm = reinterpret_cast<phi::distributed::XCCLCommContext*>(
phi::distributed::CommContextManager::GetInstance().Get(
std::to_string(rid)));

std::shared_ptr<phi::stream::Stream> stream;
if (ctx.Attr<bool>("use_calc_stream")) {
stream = dev_ctx.GetStream();
auto dev_ctx =
paddle::platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<paddle::platform::CustomDeviceContext*>(dev_ctx)
->GetStream();
} else {
stream = comm->stream();
stream = comm->GetStream();
}
int nranks = comm->nranks();
int rank = comm->rank();

int nranks = comm->GetSize();
int rank = comm->GetRank();
auto in_feat = x->dims()[1];
auto n_expert = local_count->dims()[0] / nranks;
int64_t fwd_count = 0;
@@ -1033,7 +1061,7 @@ class GlobalScatterOpCustomDeviceKernel : public framework::OpKernel<T> {
cpu_global_count_data[idx] * in_feat,
phi::ccl::ToCCLDataType(x->dtype()),
j,
comm->comm(),
comm->GetXcclComm(),
*stream);
recv_ptr += cpu_global_count_data[idx];
}
@@ -1049,7 +1077,7 @@ class GlobalScatterOpCustomDeviceKernel : public framework::OpKernel<T> {
cpu_local_count_data[idx] * in_feat,
phi::ccl::ToCCLDataType(x->dtype()),
j,
comm->comm(),
comm->GetXcclComm(),
*stream);
}
}
@@ -1072,7 +1100,7 @@ class GlobalScatterOpCustomDeviceKernel : public framework::OpKernel<T> {
cpu_global_count_data[idx] * in_feat,
phi::ccl::ToCCLDataType(x->dtype()),
j,
comm->comm(),
comm->GetXcclComm(),
*stream);
recv_ptr += cpu_global_count_data[idx];
}
@@ -1199,16 +1227,21 @@ class GlobalGatherOpCustomDeviceKernel : public framework::OpKernel<T> {
}
}
} else {
auto comm = platform::XCCLCommContext::Instance(place.GetDeviceType())
.Get(rid, place);
auto comm = reinterpret_cast<phi::distributed::XCCLCommContext*>(
phi::distributed::CommContextManager::GetInstance().Get(
std::to_string(rid)));

std::shared_ptr<phi::stream::Stream> stream;
if (ctx.Attr<bool>("use_calc_stream")) {
stream = dev_ctx.GetStream();
auto dev_ctx =
paddle::platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<paddle::platform::CustomDeviceContext*>(dev_ctx)
->GetStream();
} else {
stream = comm->stream();
stream = comm->GetStream();
}
int nranks = comm->nranks();
int rank = comm->rank();
int nranks = comm->GetSize();
int rank = comm->GetRank();
auto in_feat = x->dims()[1];
auto n_expert = local_count->dims()[0] / nranks;

@@ -1238,7 +1271,7 @@ class GlobalGatherOpCustomDeviceKernel : public framework::OpKernel<T> {
cpu_local_count_data[idx] * in_feat,
phi::ccl::ToCCLDataType(x->dtype()),
j,
comm->comm(),
comm->GetXcclComm(),
*stream);
}
}
@@ -1253,7 +1286,7 @@ class GlobalGatherOpCustomDeviceKernel : public framework::OpKernel<T> {
cpu_global_count_data[idx] * in_feat,
phi::ccl::ToCCLDataType(x->dtype()),
j,
comm->comm(),
comm->GetXcclComm(),
*stream);
} else {
phi::DeviceManager::GetDeviceWithPlace(place)->MemoryCopyD2D(
@@ -1274,7 +1307,7 @@ class GlobalGatherOpCustomDeviceKernel : public framework::OpKernel<T> {
cpu_local_count_data[idx] * in_feat,
phi::ccl::ToCCLDataType(x->dtype()),
j,
comm->comm(),
comm->GetXcclComm(),
*stream);
}
}
6 changes: 1 addition & 5 deletions paddle/phi/core/distributed/auto_parallel/reshard_utils.cc
@@ -95,11 +95,7 @@ CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx,
} else if (phi::CustomContext::classof(&dev_ctx)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
CommContextManager::CreateXCCLCommContext(
store,
unique_comm_key,
dev_ctx.GetPlace().GetDeviceType(),
rank,
world_size);
store, unique_comm_key, dev_ctx.GetPlace(), rank, world_size);
#endif
} else {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
12 changes: 8 additions & 4 deletions paddle/phi/core/distributed/comm_context_manager.cc
@@ -130,15 +130,19 @@ void CommContextManager::CreateGlooCommContext(
void CommContextManager::CreateXCCLCommContext(
const std::shared_ptr<Store>& store,
const std::string& unique_comm_key,
const std::string& device_type,
const phi::Place& place,
int rank,
int size) {
int size,
const std::string& hash_key) {
phi::ccl::CCLRootId xccl_root_id;
if (rank == 0) {
phi::DeviceManager::CCLGetUniqueId(device_type, &xccl_root_id);
phi::DeviceManager::CCLGetUniqueId(place.GetDeviceType(), &xccl_root_id);
}

std::string unique_key = "XCCLCommContext/" + unique_comm_key;
if (!hash_key.empty()) {
unique_key += "/" + hash_key;
}
if (rank == 0) {
store->set(unique_key, xccl_root_id);
} else {
Expand All @@ -148,7 +152,7 @@ void CommContextManager::CreateXCCLCommContext(
<< ", unique_comm_key: " << unique_comm_key << ", xccl uniqueid: "
<< phi::ccl::SerializeXCCLUniqueId(xccl_root_id);
auto xccl_comm_context =
std::make_unique<XCCLCommContext>(device_type, rank, size, xccl_root_id);
std::make_unique<XCCLCommContext>(place, rank, size, xccl_root_id);
auto& comm_context_manager = CommContextManager::GetInstance();
comm_context_manager.SetStore(store);
comm_context_manager.Emplace(unique_comm_key, std::move(xccl_comm_context));
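
CreateXCCLCommContext now takes a phi::Place rather than a device-type string, plus an optional hash_key that is appended to the store key; here is a sketch of the two call styles appearing in this PR (treating the empty-string default for hash_key as an assumption implied by the hash_key.empty() check above):

// ProcessGroupCustom::CreateXCCLEnvCache: no hash_key, so the store key is
// just "XCCLCommContext/" + gid.
phi::distributed::CommContextManager::CreateXCCLCommContext(
    store_, std::to_string(gid_), place, rank_, size_);

// c_comm_init_op: a hash_key suffix keeps its store entry distinct.
phi::distributed::CommContextManager::CreateXCCLCommContext(
    store,
    std::to_string(rid),
    phi::CustomPlace(place.GetDeviceType(), device_id),
    rank_id,
    nranks,
    "c_comm_init_op");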