[NewComm] No.3 compatible upgrade for global_scatter op #57161
Changes from all commits
@@ -13,12 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/global_scatter_op.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"

#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/fluid/framework/convert_utils.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif
#include "paddle/fluid/framework/convert_utils.h"

namespace paddle {
namespace operators {
@@ -78,15 +84,48 @@ struct GlobalScatterFunctor<phi::GPUContext, T> {
                          ring_id));

    auto place = ctx.GetPlace();
    auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
    gpuStream_t stream = nullptr;
    platform::NCCLComm* comm = nullptr;
    phi::distributed::NCCLCommContext* comm_ctx = nullptr;
    int nranks = 0;

    const auto& comm_context_manager =
        phi::distributed::CommContextManager::GetInstance();

    if (FLAGS_dynamic_static_unified_comm) {
      PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)),
                        true,
                        platform::errors::InvalidArgument(
                            "You choose to use new communication library by "
                            "setting environment "
                            "variable FLAGS_dynamic_static_unified_comm True. "
                            "But ring_id(%d) is "
                            "not found in comm_context_manager.",
                            std::to_string(ring_id)));
      comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
          comm_context_manager.Get(std::to_string(ring_id)));
      PADDLE_ENFORCE_NE(comm_ctx,
                        nullptr,
                        platform::errors::Unavailable(
                            "NCCLCommContext is nullptr, collective op should "
                            "has ring_id attr."));

      stream = comm_ctx->GetStream();
      nranks = comm_ctx->GetSize();
      VLOG(3) << "new comm_context_manager has ring_id " << ring_id;
    } else {  // old comm_context
      comm = platform::NCCLCommContext::Instance().Get(ring_id, place);

      stream = comm->stream();
      nranks = comm->nranks();
      VLOG(3) << "old NCCLCommContext has ring_id " << ring_id;
    }

    if (ctx.Attr<bool>("use_calc_stream")) {
      // should ExecutionContext for calc stream.
      stream = ctx.cuda_device_context().stream();
    } else {
      stream = comm->stream();
    }
    int nranks = comm->nranks();

    auto in_feat = x->dims()[1];
    auto n_expert = local_count->dims()[0] / nranks;
    int64_t fwd_count = 0;
@@ -103,34 +142,62 @@ struct GlobalScatterFunctor<phi::GPUContext, T> {
    }

    auto recv_ptr = 0;
    auto send_buf = x->data<T>();
    auto recv_buf = out->mutable_data<T>(out_dims, place);
    out->mutable_data<T>(out_dims, place);
Review thread on the out->mutable_data<T>(out_dims, place); line:

Comment: This can be deleted, otherwise PR-CI-APPROVE won't pass.
Reply: My understanding is that out needs device memory allocated here, so it can't be deleted. Looking at the PR-CI-APPROVE error, could we instead use …
Reply: Yes, that also works.
Reply: What PR-CI-APPROVAL mentions is …
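The reviewer's exact replacement is cut off above. As a hedged sketch of the kind of change the approval check usually asks for (avoiding mutable_data in new code), the allocation could go through the phi-style Alloc API instead; the Resize call and the use of ctx.cuda_device_context() are assumptions based on the surrounding code, not the reviewer's confirmed suggestion:

    // Hypothetical alternative to out->mutable_data<T>(out_dims, place):
    // set the output shape first, then allocate through the device context.
    out->Resize(out_dims);
    ctx.cuda_device_context().Alloc<T>(out);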
    for (auto i = 0; i < n_expert; ++i) {
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
      for (auto j = 0; j < nranks; ++j) {
        int idx = i + j * n_expert;
        if (cpu_local_count_data[idx]) {
          PADDLE_ENFORCE_GPU_SUCCESS(
              platform::dynload::ncclSend(send_buf + expert_ptr[idx] * in_feat,
                                          cpu_local_count_data[idx] * in_feat,
                                          dtype,
                                          j,
                                          comm->comm(),
                                          stream));
    if (comm_ctx) {
      for (auto i = 0; i < n_expert; ++i) {
        comm_ctx->GroupStart();
        for (auto j = 0; j < nranks; ++j) {
          int idx = i + j * n_expert;
          if (cpu_local_count_data[idx]) {
            auto send_buf = distributed::GetPartialTensor(
                *x,
                expert_ptr[idx] * in_feat,
                cpu_local_count_data[idx] * in_feat);

            comm_ctx->Send(
                send_buf, cpu_local_count_data[idx] * in_feat, j, stream);
          }
          if (cpu_global_count_data[idx]) {
            auto recv_buf = distributed::GetPartialTensor(
                *out, recv_ptr * in_feat, cpu_global_count_data[idx] * in_feat);
            comm_ctx->Recv(
                &recv_buf, cpu_global_count_data[idx] * in_feat, j, stream);
            recv_ptr += cpu_global_count_data[idx];
          }
        }
        if (cpu_global_count_data[idx]) {
          PADDLE_ENFORCE_GPU_SUCCESS(
              platform::dynload::ncclRecv(recv_buf + recv_ptr * in_feat,
                                          cpu_global_count_data[idx] * in_feat,
                                          dtype,
                                          j,
                                          comm->comm(),
                                          stream));
          recv_ptr += cpu_global_count_data[idx];
        comm_ctx->GroupEnd();
        }
    } else {
      auto send_buf = x->data<T>();
      auto recv_buf = out->data<T>();

      for (auto i = 0; i < n_expert; ++i) {
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
        for (auto j = 0; j < nranks; ++j) {
          int idx = i + j * n_expert;
          if (cpu_local_count_data[idx]) {
            PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
                send_buf + expert_ptr[idx] * in_feat,
                cpu_local_count_data[idx] * in_feat,
                dtype,
                j,
                comm->comm(),
                stream));
          }
          if (cpu_global_count_data[idx]) {
            PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
                recv_buf + recv_ptr * in_feat,
                cpu_global_count_data[idx] * in_feat,
                dtype,
                j,
                comm->comm(),
                stream));
            recv_ptr += cpu_global_count_data[idx];
          }
        }
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
      }
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
    }

#else
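Because the diff view above interleaves removed and added rows, the control flow of the upgraded functor can be hard to follow. The fragment below restates the communicator selection on its own as a reading aid, using only the names already present in the hunks (CommContextManager, NCCLCommContext, FLAGS_dynamic_static_unified_comm); it is a sketch, not the exact committed code.

    // Sketch of the comm selection introduced by this PR; variables such as
    // ring_id, place, and ctx come from the surrounding functor.
    gpuStream_t stream = nullptr;
    platform::NCCLComm* comm = nullptr;
    phi::distributed::NCCLCommContext* comm_ctx = nullptr;
    int nranks = 0;

    const auto& mgr = phi::distributed::CommContextManager::GetInstance();
    if (FLAGS_dynamic_static_unified_comm) {
      // New path: the unified comm library owns the communicator.
      comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
          mgr.Get(std::to_string(ring_id)));
      stream = comm_ctx->GetStream();
      nranks = comm_ctx->GetSize();
    } else {
      // Legacy path: look the communicator up in the old singleton.
      comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
      stream = comm->stream();
      nranks = comm->nranks();
    }
    if (ctx.Attr<bool>("use_calc_stream")) {
      // Optionally run the collective on the compute stream instead.
      stream = ctx.cuda_device_context().stream();
    }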
Review thread on the include changes:

Comment: Move the header files newly added in global_scatter_op.h here instead, all together. NCCL is not compiled in every configuration, so they need to go inside the
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
macro guard.

Reply: Do you mean moving the
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
macro block from the header file into the .cu.cc file?

Reply: The header already has that macro guard (line 18); put the headers newly included by global_scatter_op.h inside it. That is most likely why PR-CI-PY3 is failing.

Reply: The header file already has this block, so wouldn't adding it to the .cu.cc file as well be redundant? Looking at some other ops, some put it in the header and some in the .cu.cc. Could we unify this?

Reply: Let's put them all in the .cu.cc file, and make sure the conditional macro decides whether they are included; in some configurations NCCL and the related code are not compiled at all, which would make CI fail.
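A sketch of the layout the reviewers appear to be converging on: keep the new includes in the .cu.cc file and wrap the NCCL-dependent ones in the existing macro guard. Only header names already shown in the diff are used; their exact grouping under the guard is an assumption, not the final committed arrangement.

    // Hypothetical include section for the .cu.cc file discussed above.
    #include "paddle/fluid/operators/collective/global_scatter_op.h"
    #include "paddle/fluid/framework/convert_utils.h"

    // Headers that depend on NCCL/RCCL stay behind the guard so builds
    // without NCCL still compile.
    #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
    #include "paddle/fluid/distributed/collective/utils.h"
    #include "paddle/fluid/platform/collective_helper.h"
    #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
    #include "paddle/phi/core/distributed/comm_context_manager.h"
    #include "paddle/phi/core/distributed/nccl_comm_context.h"
    #include "paddle/phi/core/flags.h"
    PHI_DECLARE_bool(dynamic_static_unified_comm);
    #endif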