From 90bbc66755e6241342fa8d5ceac7ae4b613d69e8 Mon Sep 17 00:00:00 2001 From: Shimmy Balsam Date: Wed, 21 Jun 2023 16:02:05 +0300 Subject: [PATCH] TL/NCCL: make ncclGroupEnd nb --- .../tl/nccl/allgatherv/allgatherv.c | 24 +++- src/components/tl/nccl/tl_nccl.h | 17 ++- src/components/tl/nccl/tl_nccl_coll.c | 121 ++++++++++++++---- src/components/tl/nccl/tl_nccl_context.c | 42 +++++- src/components/tl/nccl/tl_nccl_team.c | 3 - 5 files changed, 173 insertions(+), 34 deletions(-) diff --git a/src/components/tl/nccl/allgatherv/allgatherv.c b/src/components/tl/nccl/allgatherv/allgatherv.c index fef021af3a..5c1883fe0a 100644 --- a/src/components/tl/nccl/allgatherv/allgatherv.c +++ b/src/components/tl/nccl/allgatherv/allgatherv.c @@ -89,7 +89,13 @@ ucc_status_t ucc_tl_nccl_allgatherv_p2p_start(ucc_coll_task_t *coll_task) exit_coll, status, UCC_TL_TEAM_LIB(team)); } } +#if NCCL_USE_NON_BLOCKING + NCCLCHECK_INPROGRESS_GOTO(ncclGroupEnd(), exit_coll, status, + UCC_TL_TEAM_LIB(team), + task->nccl_progress_st, team->nccl_comm); +#else NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team)); +#endif status = ucc_tl_nccl_collective_sync(task, stream); exit_coll: return status; @@ -106,8 +112,8 @@ ucc_status_t ucc_tl_nccl_allgatherv_p2p_init(ucc_base_coll_args_t *coll_args, if (ucc_unlikely(status != UCC_OK)) { return status; } - task->super.post = ucc_tl_nccl_allgatherv_p2p_start; - *task_h = &task->super; + task->super.post = ucc_tl_nccl_allgatherv_p2p_start; + *task_h = &task->super; return status; } @@ -233,8 +239,8 @@ ucc_status_t ucc_tl_nccl_allgatherv_bcast_start(ucc_coll_task_t *coll_task) ucc_status_t status = UCC_OK; void *sbuf = args->src.info.buffer; ptrdiff_t rbuf = (ptrdiff_t)args->dst.info_v.buffer; - size_t rdt_size, count, displ; - ucc_rank_t peer; + size_t rdt_size, count, displ; + ucc_rank_t peer; task->super.status = UCC_INPROGRESS; rdt_size = ucc_dt_size(args->dst.info_v.datatype); @@ -253,7 +259,13 @@ ucc_status_t ucc_tl_nccl_allgatherv_bcast_start(ucc_coll_task_t *coll_task) team->nccl_comm, stream), exit_coll, status, UCC_TL_TEAM_LIB(team)); } +#if NCCL_USE_NON_BLOCKING + NCCLCHECK_INPROGRESS_GOTO(ncclGroupEnd(), exit_coll, status, + UCC_TL_TEAM_LIB(team), + task->nccl_progress_st, team->nccl_comm); +#else NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team)); +#endif status = ucc_tl_nccl_collective_sync(task, stream); exit_coll: return status; @@ -263,7 +275,7 @@ ucc_status_t ucc_tl_nccl_allgatherv_bcast_init(ucc_base_coll_args_t *coll_args, ucc_base_team_t * team, ucc_coll_task_t ** task_h) { - ucc_status_t status = UCC_OK; + ucc_status_t status = UCC_OK; ucc_tl_nccl_task_t *task; status = ucc_tl_nccl_init_task(coll_args, team, &task); @@ -272,6 +284,6 @@ ucc_status_t ucc_tl_nccl_allgatherv_bcast_init(ucc_base_coll_args_t *coll_args, } task->super.post = ucc_tl_nccl_allgatherv_bcast_start; - *task_h = &task->super; + *task_h = &task->super; return status; } diff --git a/src/components/tl/nccl/tl_nccl.h b/src/components/tl/nccl/tl_nccl.h index 6f3daee474..2425f05fd9 100644 --- a/src/components/tl/nccl/tl_nccl.h +++ b/src/components/tl/nccl/tl_nccl.h @@ -42,7 +42,8 @@ #define UCC_TL_NCCL_PROFILE_REQUEST_NEW UCC_PROFILE_REQUEST_NEW #define UCC_TL_NCCL_PROFILE_REQUEST_EVENT UCC_PROFILE_REQUEST_EVENT #define UCC_TL_NCCL_PROFILE_REQUEST_FREE UCC_PROFILE_REQUEST_FREE - +#define NCCL_VERSION_COMM_INIT_NB NCCL_VERSION(2,14,3) +#define NCCL_USE_NON_BLOCKING NCCL_VERSION_CODE >= NCCL_VERSION_COMM_INIT_NB typedef struct ucc_tl_nccl_iface { ucc_tl_iface_t super; } ucc_tl_nccl_iface_t; @@ -92,6 +93,7 @@ typedef struct ucc_tl_nccl_team { typedef struct ucc_tl_nccl_task { ucc_coll_task_t super; ucc_status_t host_status; + ucc_status_t nccl_progress_st; ucc_status_t *dev_status; void *completed; union { @@ -132,6 +134,19 @@ UCC_CLASS_DECLARE(ucc_tl_nccl_team_t, ucc_base_context_t *, } \ } while (0) +#define NCCLCHECK_INPROGRESS_GOTO(_cmd, _label, _st, _lib, _task_st, _comm) \ + do { \ + ncclResult_t e = _cmd; \ + ncclCommGetAsyncError(_comm, &e); \ + if (ncclInProgress == e) { \ + _task_st = UCC_INPROGRESS; \ + } else if (ncclSuccess != e) { \ + tl_error(_lib, "NCCL error %d %s", e, ncclGetErrorString(e)); \ + _st = UCC_ERR_NO_MESSAGE; \ + goto _label; \ + } \ + } while (0) + #define UCC_TL_NCCL_TEAM_LIB(_team) \ (ucc_derived_of((_team)->super.super.context->lib, ucc_tl_nccl_lib_t)) diff --git a/src/components/tl/nccl/tl_nccl_coll.c b/src/components/tl/nccl/tl_nccl_coll.c index ba3d246489..17d64d3c3a 100644 --- a/src/components/tl/nccl/tl_nccl_coll.c +++ b/src/components/tl/nccl/tl_nccl_coll.c @@ -258,7 +258,7 @@ ucc_status_t ucc_tl_nccl_alltoall_start(ucc_coll_task_t *coll_task) ucc_status_t status = UCC_OK; ptrdiff_t sbuf = (ptrdiff_t)args->src.info.buffer; ptrdiff_t rbuf = (ptrdiff_t)args->dst.info.buffer; - size_t data_size; + size_t data_size; ucc_rank_t peer; task->super.status = UCC_INPROGRESS; @@ -279,7 +279,13 @@ ucc_status_t ucc_tl_nccl_alltoall_start(ucc_coll_task_t *coll_task) ncclChar, peer, team->nccl_comm, stream), exit_coll, status, UCC_TL_TEAM_LIB(team)); } +#if NCCL_USE_NON_BLOCKING + NCCLCHECK_INPROGRESS_GOTO(ncclGroupEnd(), exit_coll, status, + UCC_TL_TEAM_LIB(team), + task->nccl_progress_st, team->nccl_comm); +#else NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team)); +#endif status = ucc_tl_nccl_collective_sync(task, stream); exit_coll: return status; @@ -292,7 +298,7 @@ ucc_status_t ucc_tl_nccl_alltoall_init(ucc_tl_nccl_task_t *task) return UCC_ERR_NOT_SUPPORTED; } - task->super.post = ucc_tl_nccl_alltoall_start; + task->super.post = ucc_tl_nccl_alltoall_start; return UCC_OK; } @@ -306,7 +312,7 @@ ucc_status_t ucc_tl_nccl_alltoallv_start(ucc_coll_task_t *coll_task) ucc_status_t status = UCC_OK; ptrdiff_t sbuf = (ptrdiff_t)args->src.info_v.buffer; ptrdiff_t rbuf = (ptrdiff_t)args->dst.info_v.buffer; - size_t sdt_size, rdt_size, count, displ; + size_t sdt_size, rdt_size, count, displ; ucc_rank_t peer; task->super.status = UCC_INPROGRESS; @@ -334,7 +340,13 @@ ucc_status_t ucc_tl_nccl_alltoallv_start(ucc_coll_task_t *coll_task) exit_coll, status, UCC_TL_TEAM_LIB(team)); } } +#if NCCL_USE_NON_BLOCKING + NCCLCHECK_INPROGRESS_GOTO(ncclGroupEnd(), exit_coll, status, + UCC_TL_TEAM_LIB(team), + task->nccl_progress_st, team->nccl_comm); +#else NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team)); +#endif status = ucc_tl_nccl_collective_sync(task, stream); exit_coll: return status; @@ -347,7 +359,7 @@ ucc_status_t ucc_tl_nccl_alltoallv_init(ucc_tl_nccl_task_t *task) return UCC_ERR_NOT_SUPPORTED; } - task->super.post = ucc_tl_nccl_alltoallv_start; + task->super.post = ucc_tl_nccl_alltoallv_start; return UCC_OK; } @@ -397,7 +409,7 @@ ucc_status_t ucc_tl_nccl_allreduce_init(ucc_tl_nccl_task_t *task) return UCC_ERR_NOT_SUPPORTED; } - task->super.post = ucc_tl_nccl_allreduce_start; + task->super.post = ucc_tl_nccl_allreduce_start; return UCC_OK; } @@ -493,12 +505,25 @@ ucc_status_t ucc_tl_nccl_bcast_start(ucc_coll_task_t *coll_task) team->nccl_comm, stream), exit_coll, status, UCC_TL_TEAM_LIB(team)); } - NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, - UCC_TL_TEAM_LIB(team)); +#if NCCL_USE_NON_BLOCKING + NCCLCHECK_INPROGRESS_GOTO(ncclGroupEnd(), exit_coll, status, + UCC_TL_TEAM_LIB(team), + task->nccl_progress_st, team->nccl_comm); +#else + NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team)); +#endif } else { +#if NCCL_USE_NON_BLOCKING + NCCLCHECK_INPROGRESS_GOTO(ncclRecv(src, count, dt, root, + team->nccl_comm, stream), + exit_coll, status, + UCC_TL_TEAM_LIB(team), + task->nccl_progress_st, team->nccl_comm); +#else NCCLCHECK_GOTO(ncclRecv(src, count, dt, root, team->nccl_comm, stream), exit_coll, status, UCC_TL_TEAM_LIB(team)); +#endif } } else { NCCLCHECK_GOTO(ncclBroadcast(src, src, count, dt, root, team->nccl_comm, @@ -658,7 +683,7 @@ ucc_status_t ucc_tl_nccl_gather_start(ucc_coll_task_t *coll_task) void *dst = args->dst.info.buffer; void *src = args->src.info.buffer; ucc_status_t status = UCC_OK; - size_t send_size; + size_t send_size; ucc_rank_t peer; if (rank == args->root) { @@ -687,12 +712,25 @@ ucc_status_t ucc_tl_nccl_gather_start(ucc_coll_task_t *coll_task) stream), exit_coll, status, UCC_TL_TEAM_LIB(team)); } - NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, - UCC_TL_TEAM_LIB(team)); +#if NCCL_USE_NON_BLOCKING + NCCLCHECK_INPROGRESS_GOTO(ncclGroupEnd(), exit_coll, status, + UCC_TL_TEAM_LIB(team), + task->nccl_progress_st, team->nccl_comm); +#else + NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team)); +#endif } else { +#if NCCL_USE_NON_BLOCKING + NCCLCHECK_INPROGRESS_GOTO(ncclSend(src, send_size, ncclChar, + args->root, team->nccl_comm, + stream), + exit_coll, status, UCC_TL_TEAM_LIB(team), + task->nccl_progress_st, team->nccl_comm); +#else NCCLCHECK_GOTO(ncclSend(src, send_size, ncclChar, args->root, team->nccl_comm, stream), exit_coll, status, UCC_TL_TEAM_LIB(team)); +#endif } task->super.status = UCC_INPROGRESS; status = ucc_tl_nccl_collective_sync(task, stream); @@ -702,8 +740,8 @@ ucc_status_t ucc_tl_nccl_gather_start(ucc_coll_task_t *coll_task) ucc_status_t ucc_tl_nccl_gather_init(ucc_tl_nccl_task_t *task) { - task->super.post = ucc_tl_nccl_gather_start; - return UCC_OK; + task->super.post = ucc_tl_nccl_gather_start; + return UCC_OK; } ucc_status_t ucc_tl_nccl_gatherv_start(ucc_coll_task_t *coll_task) @@ -718,7 +756,7 @@ ucc_status_t ucc_tl_nccl_gatherv_start(ucc_coll_task_t *coll_task) void *dst = args->dst.info_v.buffer; void *src = args->src.info.buffer; ucc_status_t status = UCC_OK; - size_t count, displ, dt_size; + size_t count, displ, dt_size; ucc_rank_t peer; if (rank == args->root) { @@ -754,13 +792,26 @@ ucc_status_t ucc_tl_nccl_gatherv_start(ucc_coll_task_t *coll_task) peer,team->nccl_comm, stream), exit_coll, status, UCC_TL_TEAM_LIB(team)); } - NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, - UCC_TL_TEAM_LIB(team)); +#if NCCL_USE_NON_BLOCKING + NCCLCHECK_INPROGRESS_GOTO(ncclGroupEnd(), exit_coll, status, + UCC_TL_TEAM_LIB(team), + task->nccl_progress_st, team->nccl_comm); +#else + NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team)); +#endif } else { +#if NCCL_USE_NON_BLOCKING + NCCLCHECK_INPROGRESS_GOTO(ncclSend(src, args->src.info.count * dt_size, + ncclChar, args->root, + team->nccl_comm, stream), + exit_coll, status, UCC_TL_TEAM_LIB(team), + task->nccl_progress_st, team->nccl_comm); +#else NCCLCHECK_GOTO(ncclSend(src, args->src.info.count * dt_size, ncclChar, args->root, team->nccl_comm, stream), exit_coll, status, UCC_TL_TEAM_LIB(team)); +#endif } task->super.status = UCC_INPROGRESS; status = ucc_tl_nccl_collective_sync(task, stream); @@ -786,8 +837,8 @@ ucc_status_t ucc_tl_nccl_scatter_start(ucc_coll_task_t *coll_task) void *dst = args->dst.info.buffer; void *src = args->src.info.buffer; ucc_status_t status = UCC_OK; - size_t send_size; - ucc_rank_t peer; + size_t send_size; + ucc_rank_t peer; if (rank == args->root) { send_size = ucc_dt_size(args->src.info.datatype) * @@ -816,12 +867,25 @@ ucc_status_t ucc_tl_nccl_scatter_start(ucc_coll_task_t *coll_task) stream), exit_coll, status, UCC_TL_TEAM_LIB(team)); } - NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, - UCC_TL_TEAM_LIB(team)); +#if NCCL_USE_NON_BLOCKING + NCCLCHECK_INPROGRESS_GOTO(ncclGroupEnd(), exit_coll, status, + UCC_TL_TEAM_LIB(team), + task->nccl_progress_st, team->nccl_comm); +#else + NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team)); +#endif } else { +#if NCCL_USE_NON_BLOCKING + NCCLCHECK_INPROGRESS_GOTO(ncclRecv(dst, send_size, ncclChar, + args->root, team->nccl_comm, + stream), + exit_coll, status, UCC_TL_TEAM_LIB(team), + task->nccl_progress_st, team->nccl_comm); +#else NCCLCHECK_GOTO(ncclRecv(dst, send_size, ncclChar, args->root, team->nccl_comm, stream), exit_coll, status, UCC_TL_TEAM_LIB(team)); +#endif } task->super.status = UCC_INPROGRESS; status = ucc_tl_nccl_collective_sync(task, stream); @@ -847,7 +911,7 @@ ucc_status_t ucc_tl_nccl_scatterv_start(ucc_coll_task_t *coll_task) void *dst = args->dst.info.buffer; void *src = args->src.info_v.buffer; ucc_status_t status = UCC_OK; - size_t count, displ, dt_size; + size_t count, displ, dt_size; ucc_rank_t peer; if (rank == args->root) { @@ -885,12 +949,25 @@ ucc_status_t ucc_tl_nccl_scatterv_start(ucc_coll_task_t *coll_task) team->nccl_comm, stream), exit_coll, status, UCC_TL_TEAM_LIB(team)); } - NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, - UCC_TL_TEAM_LIB(team)); +#if NCCL_USE_NON_BLOCKING + NCCLCHECK_INPROGRESS_GOTO(ncclGroupEnd(), exit_coll, status, + UCC_TL_TEAM_LIB(team), + task->nccl_progress_st, team->nccl_comm); +#else + NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team)); +#endif } else { +#if NCCL_USE_NON_BLOCKING + NCCLCHECK_INPROGRESS_GOTO(ncclRecv(dst, args->dst.info.count * dt_size, + ncclChar, args->root, + team->nccl_comm, stream), + exit_coll, status, UCC_TL_TEAM_LIB(team), + task->nccl_progress_st, team->nccl_comm); +#else NCCLCHECK_GOTO(ncclRecv(dst, args->dst.info.count * dt_size, ncclChar, args->root, team->nccl_comm, stream), exit_coll, status, UCC_TL_TEAM_LIB(team)); +#endif } task->super.status = UCC_INPROGRESS; status = ucc_tl_nccl_collective_sync(task, stream); diff --git a/src/components/tl/nccl/tl_nccl_context.c b/src/components/tl/nccl/tl_nccl_context.c index 0da81cb419..14982204e2 100644 --- a/src/components/tl/nccl/tl_nccl_context.c +++ b/src/components/tl/nccl/tl_nccl_context.c @@ -11,10 +11,37 @@ #include "core/ucc_ee.h" #include "utils/arch/cpu.h" +#if NCCL_USE_NON_BLOCKING +static ucc_status_t ucc_tl_nccl_nb_progress(ucc_tl_nccl_task_t *task) { + ucc_tl_nccl_team_t *team = TASK_TEAM(task); + ncclResult_t nccl_status; + + if (task->nccl_progress_st == UCC_INPROGRESS) { + ncclCommGetAsyncError(team->nccl_comm, &nccl_status); + if (nccl_status == ncclInProgress) { + return UCC_INPROGRESS; + } + if (nccl_status != ncclSuccess) { + tl_error(UCC_TL_TEAM_LIB(team), "NCCL error %d %s", nccl_status, + ncclGetErrorString(nccl_status)); + return UCC_ERR_NO_MESSAGE; + } + } + return UCC_OK; +} +#endif + void ucc_tl_nccl_event_collective_progress(ucc_coll_task_t *coll_task) { ucc_tl_nccl_task_t *task = ucc_derived_of(coll_task, ucc_tl_nccl_task_t); ucc_status_t status; +#if NCCL_USE_NON_BLOCKING + status = ucc_tl_nccl_nb_progress(task); + if (status != UCC_OK) { + coll_task->status = status; + return; + } +#endif ucc_assert(task->completed != NULL); status = ucc_ec_event_test(task->completed, UCC_EE_CUDA_STREAM); @@ -29,6 +56,15 @@ void ucc_tl_nccl_event_collective_progress(ucc_coll_task_t *coll_task) void ucc_tl_nccl_driver_collective_progress(ucc_coll_task_t *coll_task) { ucc_tl_nccl_task_t *task = ucc_derived_of(coll_task, ucc_tl_nccl_task_t); +#if NCCL_USE_NON_BLOCKING + ucc_status_t status; + + status = ucc_tl_nccl_nb_progress(task); + if (status != UCC_OK) { + coll_task->status = status; + return; + } +#endif coll_task->status = task->host_status; #ifdef HAVE_PROFILING_TL_NCCL @@ -49,7 +85,8 @@ static void ucc_tl_nccl_req_mpool_obj_init(ucc_mpool_t *mp, void *obj, ucc_tl_nccl_task_t *req = (ucc_tl_nccl_task_t*) obj; ucc_coll_task_construct(&req->super); - req->super.progress = ucc_tl_nccl_event_collective_progress; + req->super.progress = ucc_tl_nccl_event_collective_progress; + req->nccl_progress_st = UCC_OK; } @@ -91,7 +128,8 @@ static void ucc_tl_nccl_req_mapped_mpool_obj_init(ucc_mpool_t *mp, void *obj, req->super.status = UCC_ERR_NO_MESSAGE; } ucc_coll_task_construct(&req->super); - req->super.progress = ucc_tl_nccl_driver_collective_progress; + req->super.progress = ucc_tl_nccl_driver_collective_progress; + req->nccl_progress_st = UCC_OK; } static ucc_mpool_ops_t ucc_tl_nccl_req_mapped_mpool_ops = { diff --git a/src/components/tl/nccl/tl_nccl_team.c b/src/components/tl/nccl/tl_nccl_team.c index 6b67ad988f..ee082ab79d 100644 --- a/src/components/tl/nccl/tl_nccl_team.c +++ b/src/components/tl/nccl/tl_nccl_team.c @@ -12,9 +12,6 @@ #include "coll_score/ucc_coll_score.h" #include "utils/arch/cuda_def.h" -#define NCCL_VERSION_COMM_INIT_NB NCCL_VERSION(2,14,3) -#define NCCL_USE_NON_BLOCKING NCCL_VERSION_CODE >= NCCL_VERSION_COMM_INIT_NB - UCC_CLASS_INIT_FUNC(ucc_tl_nccl_team_t, ucc_base_context_t *tl_context, const ucc_base_team_params_t *params) {