[NewComm] No.3 compatible upgrade for global_scatter op #57161

Merged 3 commits on Sep 13, 2023

Changes from all commits
125 changes: 96 additions & 29 deletions paddle/fluid/operators/collective/global_scatter_op.cu.cc
@@ -13,12 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/global_scatter_op.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"

#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/fluid/framework/convert_utils.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
Contributor: The headers newly added in global_scatter_op.h should be moved here as well. NCCL is not built in every configuration, so they need to sit inside the #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) guard.

Contributor Author: Do you mean moving the #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) block from the header into the .cu.cc file?

Contributor: The header already has this guard (line 18). Put the headers newly included in global_scatter_op.h inside that guard block. The PR-CI-PY3 failure is most likely caused by this.

Contributor Author: The header already has this block, so wouldn't adding it to the .cu.cc as well be redundant? Looking at other ops, some put it in the header and some in the .cu.cc; could we unify this?

Contributor: Let's put them all in the .cu.cc, and make sure the includes are wrapped in the conditional macro. In some configurations NCCL and the related code are not compiled at all, which would break CI.
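A minimal sketch of the include layout the reviewers converge on, assuming all of the newly added comm-context headers are only needed in NCCL/RCCL builds (an illustration of the guard arrangement, not the exact merged file):

// global_scatter_op.cu.cc
#include "paddle/fluid/operators/collective/global_scatter_op.h"

#include "paddle/fluid/framework/convert_utils.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
// Everything below only exists when NCCL/RCCL is built, so it stays behind
// the guard and non-NCCL CI configurations never see these headers.
#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif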

#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif
#include "paddle/fluid/framework/convert_utils.h"

namespace paddle {
namespace operators {
@@ -78,15 +84,48 @@ struct GlobalScatterFunctor<phi::GPUContext, T> {
ring_id));

auto place = ctx.GetPlace();
auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
gpuStream_t stream = nullptr;
platform::NCCLComm* comm = nullptr;
phi::distributed::NCCLCommContext* comm_ctx = nullptr;
int nranks = 0;

const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();

if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)),
true,
platform::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(ring_id)));
comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(ring_id)));
PADDLE_ENFORCE_NE(comm_ctx,
nullptr,
platform::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));

stream = comm_ctx->GetStream();
nranks = comm_ctx->GetSize();
VLOG(3) << "new comm_context_manager has ring_id " << ring_id;
} else { // old comm_context
comm = platform::NCCLCommContext::Instance().Get(ring_id, place);

stream = comm->stream();
nranks = comm->nranks();
VLOG(3) << "old NCCLCommContext has ring_id " << ring_id;
}

if (ctx.Attr<bool>("use_calc_stream")) {
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
} else {
stream = comm->stream();
}
int nranks = comm->nranks();

auto in_feat = x->dims()[1];
auto n_expert = local_count->dims()[0] / nranks;
int64_t fwd_count = 0;
@@ -103,34 +142,62 @@ struct GlobalScatterFunctor<phi::GPUContext, T> {
}

auto recv_ptr = 0;
auto send_buf = x->data<T>();
auto recv_buf = out->mutable_data<T>(out_dims, place);
out->mutable_data<T>(out_dims, place);
Contributor: This line can be removed; otherwise PR-CI-APPROVE will not pass.

Contributor Author: As I understand it, out needs device memory allocated here, so this cannot be removed. Looking at the PR-CI-APPROVE error, could phi::DeviceContext::Alloc() be used instead?

Contributor: Yes, that also works.

Contributor Author (@BeingGod, Sep 13, 2023): The phi::DeviceContext::Alloc() mentioned by PR-CI-APPROVAL should apply to phi kernels, but global_scatter is still on the old operator system, so mutable_data is still needed here. Requesting an exemption to merge.
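A minimal sketch of the two allocation styles discussed above; the phi-style call is hypothetical for this op, since global_scatter still lives in the legacy fluid operator system:

// Legacy fluid-style allocation, as kept in this PR (exemption requested):
out->mutable_data<T>(out_dims, place);

// phi-style alternative suggested by PR-CI-APPROVAL, applicable once the op
// is migrated to the phi kernel system:
// auto& dev_ctx = ctx.cuda_device_context();
// out->Resize(out_dims);
// dev_ctx.Alloc<T>(out);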


for (auto i = 0; i < n_expert; ++i) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
for (auto j = 0; j < nranks; ++j) {
int idx = i + j * n_expert;
if (cpu_local_count_data[idx]) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclSend(send_buf + expert_ptr[idx] * in_feat,
cpu_local_count_data[idx] * in_feat,
dtype,
j,
comm->comm(),
stream));
if (comm_ctx) {
for (auto i = 0; i < n_expert; ++i) {
comm_ctx->GroupStart();
for (auto j = 0; j < nranks; ++j) {
int idx = i + j * n_expert;
if (cpu_local_count_data[idx]) {
auto send_buf = distributed::GetPartialTensor(
*x,
expert_ptr[idx] * in_feat,
cpu_local_count_data[idx] * in_feat);

comm_ctx->Send(
send_buf, cpu_local_count_data[idx] * in_feat, j, stream);
}
if (cpu_global_count_data[idx]) {
auto recv_buf = distributed::GetPartialTensor(
*out, recv_ptr * in_feat, cpu_global_count_data[idx] * in_feat);
comm_ctx->Recv(
&recv_buf, cpu_global_count_data[idx] * in_feat, j, stream);
recv_ptr += cpu_global_count_data[idx];
}
}
if (cpu_global_count_data[idx]) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclRecv(recv_buf + recv_ptr * in_feat,
cpu_global_count_data[idx] * in_feat,
dtype,
j,
comm->comm(),
stream));
recv_ptr += cpu_global_count_data[idx];
comm_ctx->GroupEnd();
}
} else {
auto send_buf = x->data<T>();
auto recv_buf = out->data<T>();

for (auto i = 0; i < n_expert; ++i) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
for (auto j = 0; j < nranks; ++j) {
int idx = i + j * n_expert;
if (cpu_local_count_data[idx]) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
send_buf + expert_ptr[idx] * in_feat,
cpu_local_count_data[idx] * in_feat,
dtype,
j,
comm->comm(),
stream));
}
if (cpu_global_count_data[idx]) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
recv_buf + recv_ptr * in_feat,
cpu_global_count_data[idx] * in_feat,
dtype,
j,
comm->comm(),
stream));
recv_ptr += cpu_global_count_data[idx];
}
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
}

#else
35 changes: 33 additions & 2 deletions test/collective/collective_global_scatter.py
@@ -56,14 +56,41 @@ def get_model(self, main_prog, startup_program, rank, indata=None):
)
return [output]

def get_model_new_comm(self, main_prog, startup_program, rank, indata=None):
with base.program_guard(main_prog, startup_program):
seed = os.getpid()
np.random.seed(seed)
in_feat = 2
n_expert = 2
world_size = 2
tot_expert = n_expert * world_size
local_input_buf = paddle.static.data(
name="local_input_buf", shape=[-1, in_feat], dtype="float32"
)
local_expert_count = paddle.static.data(
name="local_expert_count", shape=[tot_expert], dtype="int64"
)
global_expert_count = []
paddle.distributed.alltoall(
paddle.split(local_expert_count, 2, axis=0), global_expert_count
)
global_expert_count = paddle.concat(global_expert_count, axis=0)
output = moe_utils.global_scatter(
local_input_buf, local_expert_count, global_expert_count
)
return [output]

def run_trainer(self, args):
train_prog = base.Program()
startup_prog = base.Program()
endpoints = args["endpoints"].split(",")
rank = args["trainerid"]
current_endpoint = args["currentendpoint"]
nranks = 2
paddle.distributed.init_parallel_env()
if args["dynamic_static_unified_comm"]:
paddle.distributed.collective._init_parallel_env(args["backend"])
else:
paddle.distributed.init_parallel_env()
if args['backend'] == 'nccl':
device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
place = base.CUDAPlace(
@@ -87,7 +114,11 @@ def run_trainer(self, args):
"float32"
)
if args['static_mode']:
result = self.get_model(train_prog, startup_prog, rank)
result = (
self.get_model_new_comm(train_prog, startup_prog, rank)
if args["dynamic_static_unified_comm"]
else self.get_model(train_prog, startup_prog, rank)
)
exe = base.Executor(place)
exe.run(startup_prog)
fetch_list = []
8 changes: 8 additions & 0 deletions test/collective/test_collective_global_scatter.py
@@ -38,6 +38,14 @@ def test_global_scatter_nccl_dygraph_eager(self):
eager_mode=True,
)

def test_global_scatter_nccl_new_comm(self):
self.check_with_place(
"collective_global_scatter.py",
"global_scatter",
"nccl",
need_envs={"FLAGS_dynamic_static_unified_comm": "1"},
)


if __name__ == '__main__':
unittest.main()