diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu
index fdec898edbe91..a672f5ac99aa8 100644
--- a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu
+++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu
@@ -21,6 +21,7 @@
 #include "paddle/phi/common/memory_utils.h"
 #include "paddle/phi/core/cuda_stream.h"
 #include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/distributed/comm_context_manager.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/utils/data_type.h"
@@ -28,6 +29,14 @@
 #include "paddle/phi/kernels/funcs/tensor_to_string.h"
 #include "paddle/utils/optional.h"
 
+#include "paddle/fluid/distributed/collective/utils.h"
+
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+#include "paddle/phi/core/distributed/nccl_comm_context.h"
+#include "paddle/phi/core/flags.h"
+PHI_DECLARE_bool(dynamic_static_unified_comm);
+#endif
+
 #ifdef __NVCC__
 #include "cub/cub.cuh"
 #include "math.h"  // NOLINT
@@ -48,6 +57,19 @@ using MasterT = typename phi::dtype::MPTypeTrait<T>::Type;
 using phi::funcs::FlattenToString;
 using phi::funcs::ToVector;
 
+static void CheckCommContextHasRingId(
+    const distributed::CommContextManager &comm_context_manager, int ring_id) {
+  PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)),
+                    true,
+                    paddle::platform::errors::InvalidArgument(
+                        "You chose to use the new communication library "
+                        "by setting environment variable "
+                        "FLAGS_dynamic_static_unified_comm to True, "
+                        "but ring_id(%d) is "
+                        "not found in comm_context_manager.",
+                        ring_id));
+}
+
 template <typename T>
 static void FillZeroWithPtr(T *x, size_t n, gpuStream_t stream) {
   static_assert(!std::is_same<T, void>::value, "T cannot be void.");
@@ -875,24 +897,68 @@ static void MultiTensorUpdateLambParamAndBetaPows(
 }
 
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-static bool CreatePreMulScaleOpIfSupported(ncclDataType_t dtype,
-                                           ncclComm_t comm,
-                                           const void *scale,
-                                           ncclRedOp_t *op) {
+static bool CreatePreMulScaleOpIfSupported(
+    ncclDataType_t dtype,
+    ncclComm_t comm,
+    const void *scale,
+    ncclRedOp_t *op,
+    distributed::NCCLCommContext *comm_ctx = nullptr) {
 #if NCCL_VERSION_CODE >= 21100
-  int ver;
-  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclGetVersion(&ver));
-  if (ver >= 21100) {
-    VLOG(10) << "ncclRedOpCreatePreMulSum is supported.";
-    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclRedOpCreatePreMulSum(
-        op, const_cast<void *>(scale), dtype, ncclScalarDevice, comm));
-    return true;
+  if (FLAGS_dynamic_static_unified_comm) {
+    PADDLE_ENFORCE_NOT_NULL(
+        comm_ctx,
+        phi::errors::InvalidArgument(
+            "You chose to use the new communication library "
+            "by setting environment variable "
+            "FLAGS_dynamic_static_unified_comm to True, "
+            "but the comm_ctx parameter should not be nullptr."));
+    int ver = comm_ctx->GetNcclVersion();
+    if (ver >= 21100) {
+      VLOG(10) << "ncclRedOpCreatePreMulSum is supported.";
+      comm_ctx->RedOpCreatePreMulSum(
+          op, const_cast<void *>(scale), dtype, ncclScalarDevice);
+      return true;
+    }
+  } else {
+    int ver;
+    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclGetVersion(&ver));
+    if (ver >= 21100) {
+      VLOG(10) << "ncclRedOpCreatePreMulSum is supported.";
+      PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclRedOpCreatePreMulSum(
+          op, const_cast<void *>(scale), dtype, ncclScalarDevice, comm));
+      return true;
+    }
   }
 #endif
   VLOG(10) << "ncclRedOpCreatePreMulSum is not supported.";
   return false;
 }
 
+static void DestroyOpIfSupported(
+    ncclRedOp_t op,
+    ncclComm_t comm,
+    distributed::NCCLCommContext *comm_ctx = nullptr) {
+#if NCCL_VERSION_CODE >= 21100
+  VLOG(10) << "ncclRedOpDestroy starts";
+
+  if (FLAGS_dynamic_static_unified_comm) {
+    PADDLE_ENFORCE_NOT_NULL(
+        comm_ctx,
+        phi::errors::InvalidArgument(
+            "You chose to use the new communication library "
+            "by setting environment variable "
+            "FLAGS_dynamic_static_unified_comm to True, "
+            "but the comm_ctx parameter should not be nullptr."));
+    comm_ctx->RedOpDestroy(op);
+  } else {
+    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclRedOpDestroy(op, comm));
+  }
+  VLOG(10) << "ncclRedOpDestroy ends";
+#else
+  VLOG(10) << "ncclRedOpDestroy is not supported.";
+#endif
+}
+
 template <typename T1, typename T2>
 static void LaunchScaleKernel(const phi::GPUContext &dev_ctx,
                               const T1 *x,
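Note: the two helpers above funnel every premul-sum create/destroy through a single flag-aware path. A minimal sketch of the intended calling pattern, with the actual collective call elided (`scale` must be a device pointer, matching `ncclScalarDevice`, and must stay valid until the reduction is issued):

```cpp
// Sketch: create a premul-sum op when NCCL >= 2.11 supports it, fall back
// to plain ncclSum otherwise, and destroy the op only if it was created.
ncclRedOp_t op = ncclSum;
bool should_destroy_op =
    scale && CreatePreMulScaleOpIfSupported(dtype, comm, scale, &op, comm_ctx);
// ... issue ncclReduceScatter / ncclAllReduce with `op` ...
if (should_destroy_op) {
  // Destroy through the same comm / comm_ctx that created the op.
  DestroyOpIfSupported(op, comm, comm_ctx);
}
```

This is exactly the pattern `NCCLSumWithScaleBase` follows in the next hunks.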
" + "But parameter of comm_ctx should not be nullptr.")); + int ver = comm_ctx->GetNcclVersion(); + if (ver >= 21100) { + VLOG(10) << "ncclRedOpCreatePreMulSum is supported."; + comm_ctx->RedOpCreatePreMulSum( + op, const_cast(scale), dtype, ncclScalarDevice); + return true; + } + } else { + int ver; + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclGetVersion(&ver)); + if (ver >= 21100) { + VLOG(10) << "ncclRedOpCreatePreMulSum is supported."; + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclRedOpCreatePreMulSum( + op, const_cast(scale), dtype, ncclScalarDevice, comm)); + return true; + } } #endif VLOG(10) << "ncclRedOpCreatePreMulSum is not supported."; return false; } +static void DestoryOpIfSupported( + ncclRedOp_t op, + ncclComm_t comm, + distributed::NCCLCommContext *comm_ctx = nullptr) { +#if NCCL_VERSION_CODE >= 21100 + VLOG(10) << "ncclRedOpDestroy starts"; + + if (FLAGS_dynamic_static_unified_comm) { + PADDLE_ENFORCE_NOT_NULL( + comm_ctx, + phi::errors::InvalidArgument( + "You choose to use new communication library by " + "setting environment " + "variable FLAGS_dynamic_static_unified_comm True. " + "But parameter of comm_ctx should not be nullptr.")); + comm_ctx->RedOpDestroy(op); + } else { + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclRedOpDestroy(op, comm)); + } + VLOG(10) << "ncclRedOpDestroy ends"; + +#endif + VLOG(10) << "ncclRedOpDestroy is not supported."; +} + template static void LaunchScaleKernel(const phi::GPUContext &dev_ctx, const T1 *x, @@ -922,7 +988,18 @@ static void NCCLSumWithScaleBase(const T *sendbuff, ncclComm_t comm, gpuStream_t stream, const phi::GPUContext &dev_ctx, + distributed::NCCLCommContext *comm_ctx, const T *scale = nullptr) { + if (FLAGS_dynamic_static_unified_comm) { + PADDLE_ENFORCE_NOT_NULL( + comm_ctx, + phi::errors::InvalidArgument( + "You choose to use new communication library by " + "setting environment " + "variable FLAGS_dynamic_static_unified_comm True. " + "But parameter of comm_ctx should not be nullptr.")); + } + static_assert( std::is_same::value || std::is_same::value, "T must be either float32 or float16."); @@ -943,8 +1020,8 @@ static void NCCLSumWithScaleBase(const T *sendbuff, ncclRedOp_t op = ncclSum; ncclDataType_t dtype = std::is_same::value ? ncclFloat32 : ncclFloat16; - bool should_destroy_op = - scale && CreatePreMulScaleOpIfSupported(dtype, comm, scale, &op); + bool should_destroy_op = scale && CreatePreMulScaleOpIfSupported( + dtype, comm, scale, &op, comm_ctx); memory_utils::Buffer buffer(dev_ctx.GetPlace()); if (scale && !should_destroy_op) { T *new_sendbuff = buffer.Alloc(numel); @@ -952,21 +1029,44 @@ static void NCCLSumWithScaleBase(const T *sendbuff, sendbuff = new_sendbuff; } - if (UseReduceScatter) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclReduceScatter( - sendbuff, recvbuff, recvcount, dtype, op, comm, stream)); + if (comm_ctx) { + // Here assume comm_ctx->GetNcclComm() have higher priority than comm + if (UseReduceScatter) { + // TODO(BeingGod): NCCLCommContext::ReduceScatter only accept DenseTensor, + // but sendbuff or recvbuff maybe allocated by Buffer. + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::ncclReduceScatter(sendbuff, + recvbuff, + recvcount, + dtype, + op, + comm_ctx->GetNcclComm(), + stream)); + } else { + // TODO(BeingGod): NCCLCommContext::AllReduce only accept DenseTensor, + // but sendbuff or recvbuff maybe allocated by Buffer. 
@@ -1643,26 +1759,71 @@ void DistributedFusedLambKernel(
   int64_t global_rank = 0, local_rank = 0;
   ncclComm_t global_comm = nullptr, local_comm = nullptr,
              external_comm = nullptr;
-  if (nranks > 1) {
-    auto *nccl_comm_handle =
-        paddle::platform::NCCLCommContext::Instance().Get(ring_ids[0], place);
-    global_comm = nccl_comm_handle->comm();
-    global_rank = nccl_comm_handle->rank();
+  paddle::platform::NCCLComm *nccl_comm_handle = nullptr,
+                             *local_nccl_comm_handle = nullptr;
+  distributed::NCCLCommContext *comm_ctx = nullptr, *local_comm_ctx = nullptr,
+                               *external_comm_ctx = nullptr;
+
+  const auto &comm_context_manager =
+      phi::distributed::CommContextManager::GetInstance();
+
+  if (FLAGS_dynamic_static_unified_comm) {
+    CheckCommContextHasRingId(comm_context_manager, ring_ids[0]);
+
+    comm_ctx = static_cast<distributed::NCCLCommContext *>(
+        comm_context_manager.Get(std::to_string(ring_ids[0])));
+    PADDLE_ENFORCE_NE(comm_ctx,
+                      nullptr,
+                      paddle::platform::errors::Unavailable(
+                          "NCCLCommContext is nullptr, collective op should "
+                          "have a ring_id attr."));
+
+    global_comm = comm_ctx->GetNcclComm();
+    global_rank = comm_ctx->GetRank();
     if (local_shard) {
-      auto *local_nccl_comm_handle =
-          paddle::platform::NCCLCommContext::Instance().Get(ring_ids[1], place);
-      local_comm = local_nccl_comm_handle->comm();
-      local_rank = local_nccl_comm_handle->rank();
+      CheckCommContextHasRingId(comm_context_manager, ring_ids[1]);
+
+      local_comm_ctx = static_cast<distributed::NCCLCommContext *>(
+          comm_context_manager.Get(std::to_string(ring_ids[1])));
+      local_comm = local_comm_ctx->GetNcclComm();
+      local_rank = local_comm_ctx->GetRank();
       if (use_hierarchical_allreduce) {
-        external_comm = paddle::platform::NCCLCommContext::Instance()
-                            .Get(ring_ids[2], place)
-                            ->comm();
+        CheckCommContextHasRingId(comm_context_manager, ring_ids[2]);
+
+        external_comm_ctx = static_cast<distributed::NCCLCommContext *>(
+            comm_context_manager.Get(std::to_string(ring_ids[2])));
+        external_comm = external_comm_ctx->GetNcclComm();
       }
     } else {
       local_comm = global_comm;
      local_rank = global_rank;
     }
+
+    VLOG(3) << "new comm_context_manager has ring_id " << ring_ids[0];
+  } else {
+    if (nranks > 1) {
+      nccl_comm_handle = paddle::platform::NCCLCommContext::Instance().Get(
+          ring_ids[0], place);
+      global_comm = nccl_comm_handle->comm();
+      global_rank = nccl_comm_handle->rank();
+      if (local_shard) {
+        local_nccl_comm_handle =
+            paddle::platform::NCCLCommContext::Instance().Get(ring_ids[1],
+                                                              place);
+        local_comm = local_nccl_comm_handle->comm();
+        local_rank = local_nccl_comm_handle->rank();
+        if (use_hierarchical_allreduce) {
+          external_comm = paddle::platform::NCCLCommContext::Instance()
+                              .Get(ring_ids[2], place)
+                              ->comm();
+        }
+      } else {
+        local_comm = global_comm;
+        local_rank = global_rank;
+      }
+    }
   }
+
   memory_utils::Buffer grad_norm_square_buffer(place);
   auto *fp32_square_grad_norm = grad_norm_square_buffer.Alloc<float>(2);
   memory_utils::Buffer cub_tmp_buffer(place);
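Note: under `FLAGS_dynamic_static_unified_comm`, every communicator is resolved from the process-global `CommContextManager`, keyed by the stringified ring id. A condensed sketch of that lookup, using only the manager calls that appear in the hunk above:

```cpp
// Sketch: fetch the NCCL comm context registered for a ring id, or nullptr
// if that ring id was never registered with the manager.
phi::distributed::NCCLCommContext *GetCommCtxForRing(int ring_id) {
  const auto &mgr = phi::distributed::CommContextManager::GetInstance();
  const std::string key = std::to_string(ring_id);
  if (!mgr.Has(key)) {
    return nullptr;
  }
  return static_cast<phi::distributed::NCCLCommContext *>(mgr.Get(key));
}
```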
@@ -1715,7 +1876,8 @@ void DistributedFusedLambKernel(
           num_devices,
           local_comm,
           stream,
-          dev_ctx);
+          dev_ctx,
+          local_comm_ctx);
       NCCLAllReduceWithScale(
           fp32_sum_grad + local_rank * fp32_numel_each_device,
           fp32_sum_grad + local_rank * fp32_numel_each_device,
@@ -1723,7 +1885,8 @@ void DistributedFusedLambKernel(
           nranks / num_devices,
           external_comm,
           stream,
-          dev_ctx);
+          dev_ctx,
+          external_comm_ctx);
       NCCLReduceScatterWithScale(
           fp16_grad_data,
@@ -1732,7 +1895,8 @@ void DistributedFusedLambKernel(
           num_devices,
           local_comm,
           stream,
-          dev_ctx);
+          dev_ctx,
+          local_comm_ctx);
       NCCLAllReduceWithScale(
           fp16_sum_grad + local_rank * fp16_numel_each_device,
           fp16_sum_grad + local_rank * fp16_numel_each_device,
@@ -1740,7 +1904,8 @@ void DistributedFusedLambKernel(
           nranks / num_devices,
           external_comm,
           stream,
-          dev_ctx);
+          dev_ctx,
+          external_comm_ctx);
     } else {
       NCCLAllReduceWithScale(fp32_grad_data,
                              fp32_sum_grad,
@@ -1748,14 +1913,16 @@ void DistributedFusedLambKernel(
                              nranks,
                              global_comm,
                              stream,
-                             dev_ctx);
+                             dev_ctx,
+                             comm_ctx);
       NCCLAllReduceWithScale(fp16_grad_data,
                              fp16_sum_grad,
                              fp16_numel,
                              nranks,
                              global_comm,
                              stream,
-                             dev_ctx);
+                             dev_ctx,
+                             comm_ctx);
     }
     fp32_sum_grad += (local_rank * fp32_numel_each_device);
     fp16_sum_grad += (local_rank * fp16_numel_each_device);
@@ -1766,14 +1933,16 @@ void DistributedFusedLambKernel(
                                nranks,
                                global_comm,
                                stream,
-                               dev_ctx);
+                               dev_ctx,
+                               comm_ctx);
     NCCLReduceScatterWithScale(fp16_grad_data,
                                fp16_sum_grad,
                                fp16_numel_each_device,
                                nranks,
                                global_comm,
                                stream,
-                               dev_ctx);
+                               dev_ctx,
+                               comm_ctx);
   }
   // (2) Calculate the global grad norm
   GetSquareGradNorm(fp32_sum_grad,
@@ -1786,6 +1955,8 @@ void DistributedFusedLambKernel(
   VLOG(1) << "Grad square norm before all reduce: "
           << FlattenToString(fp32_square_grad_norm, 1, place);
   if (num_devices > 1) {
+    // TODO(BeingGod): NCCLCommContext::AllReduce only accepts DenseTensor,
+    // but fp32_square_grad_norm is allocated by Buffer.
     PADDLE_ENFORCE_GPU_SUCCESS(
         phi::dynload::ncclAllReduce(fp32_square_grad_norm,
                                     fp32_square_grad_norm,
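Note: the TODO above marks the remaining raw `ncclAllReduce` calls: `NCCLCommContext::AllReduce` takes `phi::DenseTensor` arguments, while `fp32_square_grad_norm` lives in an ad-hoc `Buffer` allocation. One conceivable migration path, sketched under the assumption that a `DenseTensor` can be built over an externally owned allocation (constructor details vary across Paddle versions, so this is illustrative only):

```cpp
// Hypothetical sketch: view a raw device buffer as a DenseTensor so the
// DenseTensor-based NCCLCommContext::AllReduce could be used instead.
phi::DenseTensor ViewAsDenseTensor(float *ptr,
                                   int64_t numel,
                                   const phi::Place &place) {
  // Non-owning view over memory that the Buffer continues to own.
  auto holder = std::make_shared<phi::Allocation>(
      ptr, numel * sizeof(float), place);
  phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::make_ddim({numel}));
  return phi::DenseTensor(holder, meta);
}
```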
     PADDLE_ENFORCE_GPU_SUCCESS(
         phi::dynload::ncclAllReduce(fp32_square_grad_norm,
                                     fp32_square_grad_norm,
@@ -1852,6 +2023,7 @@ void DistributedFusedLambKernel(
           local_comm,
           stream,
           dev_ctx,
+          local_comm_ctx,
           fp32_scale);
       NCCLAllReduceWithScale(
           fp32_sum_grad + local_rank * fp32_numel_each_device,
@@ -1860,8 +2032,8 @@ void DistributedFusedLambKernel(
           nranks / num_devices,
           external_comm,
           stream,
-          dev_ctx);
-
+          dev_ctx,
+          external_comm_ctx);
       NCCLReduceScatterWithScale(
           fp16_grad_data,
           fp16_sum_grad + local_rank * fp16_numel_each_device,
@@ -1870,6 +2042,7 @@ void DistributedFusedLambKernel(
           local_comm,
           stream,
           dev_ctx,
+          local_comm_ctx,
           fp16_scale);
       NCCLAllReduceWithScale(
           fp16_sum_grad + local_rank * fp16_numel_each_device,
@@ -1878,7 +2051,8 @@ void DistributedFusedLambKernel(
           nranks / num_devices,
           external_comm,
           stream,
-          dev_ctx);
+          dev_ctx,
+          external_comm_ctx);
     } else {
       NCCLAllReduceWithScale(fp32_grad_data,
                              fp32_sum_grad,
@@ -1887,6 +2061,7 @@ void DistributedFusedLambKernel(
                              global_comm,
                              stream,
                              dev_ctx,
+                             comm_ctx,
                              fp32_scale);
       NCCLAllReduceWithScale(fp16_grad_data,
                              fp16_sum_grad,
@@ -1895,6 +2070,7 @@ void DistributedFusedLambKernel(
                              global_comm,
                              stream,
                              dev_ctx,
+                             comm_ctx,
                              fp16_scale);
     }
     fp32_sum_grad += (local_rank * fp32_numel_each_device);
@@ -1907,6 +2083,7 @@ void DistributedFusedLambKernel(
                                global_comm,
                                stream,
                                dev_ctx,
+                               comm_ctx,
                                fp32_scale);
     NCCLReduceScatterWithScale(fp16_grad_data,
                                fp16_sum_grad,
@@ -1915,6 +2092,7 @@ void DistributedFusedLambKernel(
                                global_comm,
                                stream,
                                dev_ctx,
+                               comm_ctx,
                                fp16_scale);
   }
   VLOG(1) << "FP32 HasNanInf after all reduce: "
@@ -1929,6 +2107,8 @@ void DistributedFusedLambKernel(
                    stream,
                    &cub_tmp_buffer);
   if (num_devices > 1) {
+    // TODO(BeingGod): NCCLCommContext::AllReduce only accepts DenseTensor,
+    // but fp32_square_grad_norm is allocated by Buffer.
     PADDLE_ENFORCE_GPU_SUCCESS(
         phi::dynload::ncclAllReduce(fp32_square_grad_norm,
                                     fp32_square_grad_norm,
@@ -1954,7 +2134,8 @@ void DistributedFusedLambKernel(
           num_devices,
           local_comm,
           stream,
-          dev_ctx);
+          dev_ctx,
+          local_comm_ctx);
       NCCLAllReduceWithScale(
           fp32_sum_grad + local_rank * fp32_numel_each_device,
           fp32_sum_grad + local_rank * fp32_numel_each_device,
@@ -1962,7 +2143,8 @@ void DistributedFusedLambKernel(
           nranks / num_devices,
           external_comm,
           stream,
-          dev_ctx);
+          dev_ctx,
+          external_comm_ctx);
       NCCLReduceScatterWithScale(
           fp16_grad_data,
           fp16_sum_grad + local_rank * fp16_numel_each_device,
@@ -1970,7 +2152,8 @@ void DistributedFusedLambKernel(
           num_devices,
           local_comm,
           stream,
-          dev_ctx);
+          dev_ctx,
+          local_comm_ctx);
       NCCLAllReduceWithScale(
           fp16_sum_grad + local_rank * fp16_numel_each_device,
           fp16_sum_grad + local_rank * fp16_numel_each_device,
@@ -1978,7 +2161,8 @@ void DistributedFusedLambKernel(
           nranks / num_devices,
           external_comm,
           stream,
-          dev_ctx);
+          dev_ctx,
+          external_comm_ctx);
     } else {
       NCCLAllReduceWithScale(fp32_grad_data,
                              fp32_sum_grad,
@@ -1986,14 +2170,16 @@ void DistributedFusedLambKernel(
                              nranks,
                              global_comm,
                              stream,
-                             dev_ctx);
+                             dev_ctx,
+                             comm_ctx);
       NCCLAllReduceWithScale(fp16_grad_data,
                              fp16_sum_grad,
                              fp16_numel,
                              nranks,
                              global_comm,
                              stream,
-                             dev_ctx);
+                             dev_ctx,
+                             comm_ctx);
     }
     fp32_sum_grad += (local_rank * fp32_numel_each_device);
     fp16_sum_grad += (local_rank * fp16_numel_each_device);
@@ -2004,14 +2190,16 @@ void DistributedFusedLambKernel(
                                num_devices,
                                global_comm,
                                stream,
-                               dev_ctx);
+                               dev_ctx,
+                               comm_ctx);
     NCCLReduceScatterWithScale(fp16_grad_data,
                                fp16_sum_grad,
                                fp16_numel_each_device,
                                num_devices,
                                global_comm,
                                stream,
-                               dev_ctx);
+                               dev_ctx,
+                               comm_ctx);
   }
   CheckHasNanInfGrad(fp32_sum_grad,
                      fp32_numel_each_device,
@@ -2021,6 +2209,8 @@ void DistributedFusedLambKernel(
                      stream,
                      &cub_tmp_buffer);
   if (num_devices > 1) {
+    // TODO(BeingGod): NCCLCommContext::AllReduce only accepts DenseTensor,
+    // but fp32_square_grad_norm is allocated by Buffer.
     PADDLE_ENFORCE_GPU_SUCCESS(
         phi::dynload::ncclAllReduce(fp32_square_grad_norm,
                                     fp32_square_grad_norm,
@@ -2165,6 +2355,8 @@ void DistributedFusedLambKernel(
           << FlattenToString(trust_ratio_div_square_norm, param_num, place);
   if (num_devices > 1) {
     if (use_master_param_norm) {
+      // TODO(BeingGod): NCCLCommContext::AllReduce only accepts DenseTensor,
+      // but param_square_norm is allocated by Buffer.
       PADDLE_ENFORCE_GPU_SUCCESS(
           phi::dynload::ncclAllReduce(param_square_norm + fp32_global_param_num,
                                       param_square_norm + fp32_global_param_num,
@@ -2174,6 +2366,8 @@ void DistributedFusedLambKernel(
                                       local_comm,
                                       stream));
     } else {
+      // TODO(BeingGod): NCCLCommContext::AllReduce only accepts DenseTensor,
+      // but trust_ratio_div_square_norm is allocated by Buffer.
       PADDLE_ENFORCE_GPU_SUCCESS(
           phi::dynload::ncclAllReduce(trust_ratio_div_square_norm,
                                       trust_ratio_div_square_norm,
@@ -2209,13 +2403,21 @@ void DistributedFusedLambKernel(
         beta2);
     if (num_devices > 1) {
       // ncclAllGather
-      PADDLE_ENFORCE_GPU_SUCCESS(
-          phi::dynload::ncclAllGather(fp32_param_data + fp32_offset,
-                                      fp32_param_data,
-                                      fp32_numel_each_device,
-                                      ncclFloat32,
-                                      local_comm,
-                                      stream));
+      if (local_comm_ctx) {
+        auto send_buf = paddle::distributed::GetPartialTensor(
+            *fp32_param_out, fp32_offset, fp32_numel_each_device);
+        auto recv_buf = paddle::distributed::GetPartialTensor(
+            *fp32_param_out, 0, fp32_numel_each_device);
+        local_comm_ctx->AllGather(&recv_buf, send_buf, stream);
+      } else {
+        PADDLE_ENFORCE_GPU_SUCCESS(
+            phi::dynload::ncclAllGather(fp32_param_data + fp32_offset,
+                                        fp32_param_data,
+                                        fp32_numel_each_device,
+                                        ncclFloat32,
+                                        local_comm,
+                                        stream));
+      }
     }
 
     beta1_pow_data = nullptr;
@@ -2239,13 +2441,21 @@ void DistributedFusedLambKernel(
         beta2);
     if (num_devices > 1) {
       // ncclAllGather
-      PADDLE_ENFORCE_GPU_SUCCESS(
-          phi::dynload::ncclAllGather(fp16_param_data + fp16_offset,
-                                      fp16_param_data,
-                                      fp16_numel_each_device,
-                                      ncclFloat16,
-                                      local_comm,
-                                      stream));
+      if (local_comm_ctx) {
+        auto send_buf = paddle::distributed::GetPartialTensor(
+            *fp16_param_out, fp16_offset, fp16_numel_each_device);
+        auto recv_buf = paddle::distributed::GetPartialTensor(
+            *fp16_param_out, 0, fp16_numel_each_device);
+        local_comm_ctx->AllGather(&recv_buf, send_buf, stream);
+      } else {
+        PADDLE_ENFORCE_GPU_SUCCESS(
+            phi::dynload::ncclAllGather(fp16_param_data + fp16_offset,
+                                        fp16_param_data,
+                                        fp16_numel_each_device,
+                                        ncclFloat16,
+                                        local_comm,
+                                        stream));
+      }
     }
   }
   VLOG(10) << "Update Param done";
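Note: the two files below add the context-side plumbing used above: `NCCLCommContext` queries the NCCL runtime version once in its constructor, caches it, and exposes thin wrappers over the premul-sum op API. A usage sketch combining them (the 21100 checks mirror the NCCL >= 2.11.0 requirement for `ncclRedOpCreatePreMulSum`; `scale_dev_ptr` is assumed to point at device memory):

```cpp
// Sketch: version-gated premul-sum reduction through the new wrappers.
void ScaledSumSketch(phi::distributed::NCCLCommContext *ctx,
                     void *scale_dev_ptr) {
  ncclRedOp_t op = ncclSum;
  bool created = false;
#if NCCL_VERSION_CODE >= 21100
  if (ctx->GetNcclVersion() >= 21100) {  // cached at construction time
    ctx->RedOpCreatePreMulSum(
        &op, scale_dev_ptr, ncclFloat32, ncclScalarDevice);
    created = true;
  }
#endif
  // ... launch the collective with `op` on ctx->GetNcclComm() ...
  if (created) {
    ctx->RedOpDestroy(op);  // wraps ncclRedOpDestroy on the same comm
  }
}
```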
diff --git a/paddle/phi/core/distributed/nccl_comm_context.cc b/paddle/phi/core/distributed/nccl_comm_context.cc
index 90b6a4c447c92..faf29add30d91 100644
--- a/paddle/phi/core/distributed/nccl_comm_context.cc
+++ b/paddle/phi/core/distributed/nccl_comm_context.cc
@@ -33,8 +33,11 @@ NCCLCommContext::NCCLCommContext(int rank, int size, ncclUniqueId nccl_id)
     : CommContext(rank, size) {
   PADDLE_ENFORCE_GPU_SUCCESS(
       phi::dynload::ncclCommInitRank(&nccl_comm_, size_, nccl_id, rank_));
+  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclGetVersion(&nccl_version_));
 }
 
+int NCCLCommContext::GetNcclVersion() { return nccl_version_; }
+
 ncclComm_t NCCLCommContext::GetNcclComm() { return nccl_comm_; }
 
 gpuStream_t NCCLCommContext::GetStream() { return dev_ctx_->stream(); }
@@ -228,5 +231,17 @@ void NCCLCommContext::GroupStart() {
 }
 void NCCLCommContext::GroupEnd() { NCCL_CHECK(phi::dynload::ncclGroupEnd()); }
 
+void NCCLCommContext::RedOpCreatePreMulSum(ncclRedOp_t* op,
+                                           void* scalar,
+                                           ncclDataType_t dtype,
+                                           ncclScalarResidence_t residence) {
+  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclRedOpCreatePreMulSum(
+      op, scalar, dtype, residence, nccl_comm_));
+}
+
+void NCCLCommContext::RedOpDestroy(ncclRedOp_t op) {
+  PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclRedOpDestroy(op, nccl_comm_));
+}
+
 }  // namespace distributed
 }  // namespace phi
diff --git a/paddle/phi/core/distributed/nccl_comm_context.h b/paddle/phi/core/distributed/nccl_comm_context.h
index fdd45793a6387..61c3fb06c0e33 100644
--- a/paddle/phi/core/distributed/nccl_comm_context.h
+++ b/paddle/phi/core/distributed/nccl_comm_context.h
@@ -40,7 +40,9 @@ namespace distributed {
 class NCCLCommContext final : public CommContext {
  public:
   NCCLCommContext(int rank, int size, ncclUniqueId nccl_id);
-  ~NCCLCommContext() {}
+  ~NCCLCommContext() override = default;
+
+  int GetNcclVersion();
 
   ncclComm_t GetNcclComm();
 
@@ -65,6 +67,7 @@ class NCCLCommContext final : public CommContext {
                  const phi::DenseTensor& in_tensor,
                  int root,
                  gpuStream_t stream);
+
   void Send(const phi::DenseTensor& in_tensor,
             const int64_t& count,
             const int& peer,
@@ -95,6 +98,13 @@ class NCCLCommContext final : public CommContext {
               int root,
               gpuStream_t stream);
 
+  void RedOpCreatePreMulSum(ncclRedOp_t* op,
+                            void* scalar,
+                            ncclDataType_t dtype,
+                            ncclScalarResidence_t residence);
+
+  void RedOpDestroy(ncclRedOp_t op);
+
   void GroupStart();
 
   void GroupEnd();
@@ -102,6 +112,8 @@ class NCCLCommContext final : public CommContext {
  private:
   DISABLE_COPY_AND_ASSIGN(NCCLCommContext);
 
+  int nccl_version_;
+
   ncclComm_t nccl_comm_;
 
   std::unique_ptr<phi::GPUContext> dev_ctx_;
diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt
index 46a0136167e9e..4056e43e475c0 100644
--- a/test/legacy_test/CMakeLists.txt
+++ b/test/legacy_test/CMakeLists.txt
@@ -1013,11 +1013,11 @@ set_tests_properties(test_row_conv_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_parallel_executor_seresnext_with_fuse_all_reduce_gpu
                      PROPERTIES TIMEOUT 120)
 set_tests_properties(test_distributed_fused_lamb_op_with_clip PROPERTIES TIMEOUT
-                                                                        120)
+                                                                        240)
 set_tests_properties(test_distributed_fused_lamb_op_without_clip
                      PROPERTIES TIMEOUT 120)
 set_tests_properties(test_distributed_fused_lamb_op_with_gradient_merge
-                     PROPERTIES TIMEOUT 120)
+                     PROPERTIES TIMEOUT 240)
 set_tests_properties(test_elementwise_min_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_nan_inf PROPERTIES TIMEOUT 120)
 set_tests_properties(test_deformable_conv_v1_op PROPERTIES TIMEOUT 300)
diff --git a/test/legacy_test/test_distributed_fused_lamb_op_with_clip.py b/test/legacy_test/test_distributed_fused_lamb_op_with_clip.py
index 671e11e7702fe..32ee6fd8b3958 100644
--- a/test/legacy_test/test_distributed_fused_lamb_op_with_clip.py
+++ b/test/legacy_test/test_distributed_fused_lamb_op_with_clip.py
@@ -41,6 +41,7 @@ def run_test(
     max_global_norm=-1.0,
     gradient_merge_steps=1,
     use_master_acc_grad=True,
+    need_env={},
 ):
     temp_dir = tempfile.TemporaryDirectory()
     if not paddle.is_compiled_with_cuda():
@@ -54,6 +55,8 @@ def run_test(
         '-u',
         '-m',
         'paddle.distributed.launch',
+        '--devices',
+        '0,1',
         '--log_dir',
         log_dir,
         get_test_file(),
@@ -65,6 +68,7 @@ def run_test(
     os.environ['MAX_GLOBAL_NORM'] = str(max_global_norm)
     os.environ['GRADIENT_MERGE_STEPS'] = str(gradient_merge_steps)
     os.environ['USE_MASTER_ACC_GRAD'] = str(1 if use_master_acc_grad else 0)
+    os.environ.update(need_env)
 
     touch_file_env = 'SUCCESS_TOUCH_FILE'
     touch_file_name = os.path.join(
@@ -87,6 +91,20 @@ def test_1(self):
     def test_2(self):
         run_test(clip_after_allreduce=False, max_global_norm=0.01)
 
+    def test_1_new_comm(self):
+        run_test(
+            clip_after_allreduce=True,
+            max_global_norm=0.01,
+            need_env={"FLAGS_dynamic_static_unified_comm": "1"},
+        )
+
+    def test_2_new_comm(self):
+        run_test(
+            clip_after_allreduce=False,
+            max_global_norm=0.01,
+            need_env={"FLAGS_dynamic_static_unified_comm": "1"},
+        )
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/test/legacy_test/test_distributed_fused_lamb_op_with_gradient_merge.py b/test/legacy_test/test_distributed_fused_lamb_op_with_gradient_merge.py
index 0c7096f5dae1a..f236be3a8d150 100644
--- a/test/legacy_test/test_distributed_fused_lamb_op_with_gradient_merge.py
+++ b/test/legacy_test/test_distributed_fused_lamb_op_with_gradient_merge.py
@@ -33,6 +33,23 @@ def test_gm_with_fp16_acc_grad(self):
             use_master_acc_grad=False,
         )
 
+    def test_gm_new_comm(self):
+        run_test(
+            clip_after_allreduce=True,
+            max_global_norm=-1.0,
+            gradient_merge_steps=2,
+            need_env={"FLAGS_dynamic_static_unified_comm": "1"},
+        )
+
+    def test_gm_with_fp16_acc_grad_new_comm(self):
+        run_test(
+            clip_after_allreduce=True,
+            max_global_norm=-1.0,
+            gradient_merge_steps=2,
+            use_master_acc_grad=False,
+            need_env={"FLAGS_dynamic_static_unified_comm": "1"},
+        )
+
 
 if __name__ == "__main__":
     unittest.main()