diff --git a/src/components/tl/nccl/allgatherv/allgatherv.c b/src/components/tl/nccl/allgatherv/allgatherv.c index 5c1883fe0a..7d4453574e 100644 --- a/src/components/tl/nccl/allgatherv/allgatherv.c +++ b/src/components/tl/nccl/allgatherv/allgatherv.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -70,12 +70,14 @@ ucc_status_t ucc_tl_nccl_allgatherv_p2p_start(ucc_coll_task_t *coll_task) task->super.status = UCC_INPROGRESS; UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_allgatherv_start", 0); - NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team)); + NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); if (count != 0) { for (peer = 0; peer < size; peer++) { NCCLCHECK_GOTO(ncclSend(sbuf, count * sdt_size, ncclChar, peer, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); } } for (peer = 0; peer < size; peer++) { @@ -86,16 +88,12 @@ ucc_status_t ucc_tl_nccl_allgatherv_p2p_start(ucc_coll_task_t *coll_task) NCCLCHECK_GOTO(ncclRecv(PTR_OFFSET(rbuf, displ * rdt_size), count * rdt_size, ncclChar, peer, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); } } -#if NCCL_USE_NON_BLOCKING - NCCLCHECK_INPROGRESS_GOTO(ncclGroupEnd(), exit_coll, status, - UCC_TL_TEAM_LIB(team), - task->nccl_progress_st, team->nccl_comm); -#else - NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team)); -#endif + NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 1); status = ucc_tl_nccl_collective_sync(task, stream); exit_coll: return 
status; @@ -150,7 +148,8 @@ ucc_status_t ucc_tl_nccl_allgatherv_bcopy_start(ucc_coll_task_t *coll_task) } NCCLCHECK_GOTO(ncclAllGather(sbuf, scratch, max_count * rdt_size, ncclChar, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); for (peer = 0; peer < size; peer++) { rcount = ucc_coll_args_get_count(args, args->dst.info_v.counts, peer); @@ -245,7 +244,8 @@ ucc_status_t ucc_tl_nccl_allgatherv_bcast_start(ucc_coll_task_t *coll_task) task->super.status = UCC_INPROGRESS; rdt_size = ucc_dt_size(args->dst.info_v.datatype); UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_allgatherv_start", 0); - NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team)); + NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); for (peer = 0; peer < size; peer++) { count = ucc_coll_args_get_count(args, args->dst.info_v.counts, peer); displ = ucc_coll_args_get_displacement(args, @@ -257,15 +257,11 @@ ucc_status_t ucc_tl_nccl_allgatherv_bcast_start(ucc_coll_task_t *coll_task) NCCLCHECK_GOTO(ncclBroadcast(sbuf, PTR_OFFSET(rbuf, displ * rdt_size), count * rdt_size, ncclChar, peer, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); } -#if NCCL_USE_NON_BLOCKING - NCCLCHECK_INPROGRESS_GOTO(ncclGroupEnd(), exit_coll, status, - UCC_TL_TEAM_LIB(team), - task->nccl_progress_st, team->nccl_comm); -#else - NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team)); -#endif + NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 1); status = ucc_tl_nccl_collective_sync(task, stream); exit_coll: return status; diff --git a/src/components/tl/nccl/tl_nccl.h b/src/components/tl/nccl/tl_nccl.h index 2425f05fd9..458965a84c 
100644 --- a/src/components/tl/nccl/tl_nccl.h +++ b/src/components/tl/nccl/tl_nccl.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) Facebook, Inc. and its affiliates. 2021. * * See file LICENSE for terms. @@ -44,6 +44,7 @@ #define UCC_TL_NCCL_PROFILE_REQUEST_FREE UCC_PROFILE_REQUEST_FREE #define NCCL_VERSION_COMM_INIT_NB NCCL_VERSION(2,14,3) #define NCCL_USE_NON_BLOCKING NCCL_VERSION_CODE >= NCCL_VERSION_COMM_INIT_NB + typedef struct ucc_tl_nccl_iface { ucc_tl_iface_t super; } ucc_tl_nccl_iface_t; @@ -124,23 +125,37 @@ typedef struct ucc_tl_nccl_task { UCC_CLASS_DECLARE(ucc_tl_nccl_team_t, ucc_base_context_t *, const ucc_base_team_params_t *); -#define NCCLCHECK_GOTO(_cmd, _label, _status, _lib) \ - do { \ - ncclResult_t e = _cmd; \ - if (ncclSuccess != e) { \ - tl_error(_lib, "NCCL error %d %s", e, ncclGetErrorString(e)); \ - _status = UCC_ERR_NO_MESSAGE; \ - goto _label; \ - } \ - } while (0) +/* On non-blocking NCCL builds with check_nb set, re-reads *nccl_status via + * ncclCommGetAsyncError() unless it already holds an error. Returns + * UCC_INPROGRESS (also stored in *task_st) while NCCL is in progress, + * UCC_ERR_NO_MESSAGE with the failing code in *nccl_status if the query + * fails, and UCC_OK otherwise; callers still check *nccl_status. */ +static inline ucc_status_t ucc_tl_nccl_check_nb(ncclResult_t *nccl_status, // NOLINT + ucc_status_t *task_st, // NOLINT + ncclComm_t nccl_comm, //NOLINT + int check_nb) { //NOLINT +#if NCCL_USE_NON_BLOCKING + if (check_nb && + (*nccl_status == ncclSuccess || *nccl_status == ncclInProgress)) { + ncclResult_t st = ncclCommGetAsyncError(nccl_comm, nccl_status); + if (st != ncclSuccess) { + *nccl_status = st; /* let the caller see and log the failure */ + return UCC_ERR_NO_MESSAGE; + } + if (ncclInProgress == *nccl_status) { + *task_st = UCC_INPROGRESS; + return UCC_INPROGRESS; + } + } +#endif + return UCC_OK; +} -#define NCCLCHECK_INPROGRESS_GOTO(_cmd, _label, _st, _lib, _task_st, _comm) \ +#define NCCLCHECK_GOTO(_cmd, _label, _st, _lib, _task_st, _comm, _check_nb) \ do { \ ncclResult_t e = _cmd; \ - ncclCommGetAsyncError(_comm, &e); \ - if (ncclInProgress == e) { \ - _task_st = UCC_INPROGRESS; \ - } else if (ncclSuccess != e) { \ + _st = ucc_tl_nccl_check_nb(&e, _task_st, _comm, 
_check_nb); \ + if (_st != UCC_INPROGRESS && ncclSuccess != e) { \ tl_error(_lib, "NCCL error %d %s", e, ncclGetErrorString(e)); \ _st = UCC_ERR_NO_MESSAGE; \ goto _label; \ diff --git a/src/components/tl/nccl/tl_nccl_coll.c b/src/components/tl/nccl/tl_nccl_coll.c index 17d64d3c3a..6e9daa72d1 100644 --- a/src/components/tl/nccl/tl_nccl_coll.c +++ b/src/components/tl/nccl/tl_nccl_coll.c @@ -270,22 +270,20 @@ ucc_status_t ucc_tl_nccl_alltoall_start(ucc_coll_task_t *coll_task) return ucc_task_complete(&task->super); } UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_alltoall_start", 0); - NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team)); + NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); for (peer = 0; peer < gsize; peer++) { NCCLCHECK_GOTO(ncclSend((void *)(sbuf + peer * data_size), data_size, ncclChar, peer, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); NCCLCHECK_GOTO(ncclRecv((void *)(rbuf + peer * data_size), data_size, ncclChar, peer, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); } -#if NCCL_USE_NON_BLOCKING - NCCLCHECK_INPROGRESS_GOTO(ncclGroupEnd(), exit_coll, status, - UCC_TL_TEAM_LIB(team), - task->nccl_progress_st, team->nccl_comm); -#else - NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team)); -#endif + NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 1); status = ucc_tl_nccl_collective_sync(task, stream); exit_coll: return status; @@ -319,7 +317,8 @@ ucc_status_t ucc_tl_nccl_alltoallv_start(ucc_coll_task_t *coll_task) sdt_size = ucc_dt_size(args->src.info_v.datatype); rdt_size = ucc_dt_size(args->dst.info_v.datatype); 
UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_alltoallv_start", 0); - NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team)); + NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); for (peer = 0; peer < UCC_TL_TEAM_SIZE(team); peer++) { count = ucc_coll_args_get_count(args, args->src.info_v.counts, peer); if (count != 0) { @@ -328,7 +327,8 @@ ucc_status_t ucc_tl_nccl_alltoallv_start(ucc_coll_task_t *coll_task) NCCLCHECK_GOTO(ncclSend((void *)(sbuf + displ * sdt_size), count * sdt_size, ncclChar, peer, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); } count = ucc_coll_args_get_count(args, args->dst.info_v.counts, peer); if (count != 0) { @@ -337,16 +337,12 @@ ucc_status_t ucc_tl_nccl_alltoallv_start(ucc_coll_task_t *coll_task) NCCLCHECK_GOTO(ncclRecv((void *)(rbuf + displ * rdt_size), count * rdt_size, ncclChar, peer, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); } } -#if NCCL_USE_NON_BLOCKING - NCCLCHECK_INPROGRESS_GOTO(ncclGroupEnd(), exit_coll, status, - UCC_TL_TEAM_LIB(team), - task->nccl_progress_st, team->nccl_comm); -#else - NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team)); -#endif + NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 1); status = ucc_tl_nccl_collective_sync(task, stream); exit_coll: return status; @@ -387,7 +383,8 @@ ucc_status_t ucc_tl_nccl_allreduce_start(ucc_coll_task_t *coll_task) 0); NCCLCHECK_GOTO(ncclAllReduce(src, dst, count, dt, op, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); status = 
ucc_tl_nccl_collective_sync(task, stream); exit_coll: return status; @@ -437,7 +434,8 @@ ucc_status_t ucc_tl_nccl_allgather_start(ucc_coll_task_t *coll_task) UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_allgather_start", 0); NCCLCHECK_GOTO(ncclAllGather(src, dst, count / size, dt, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); status = ucc_tl_nccl_collective_sync(task, stream); exit_coll: return status; @@ -495,7 +493,8 @@ ucc_status_t ucc_tl_nccl_bcast_start(ucc_coll_task_t *coll_task) size = (ucc_rank_t)args->active_set.size; if (root == rank) { NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, - UCC_TL_TEAM_LIB(team)); + UCC_TL_TEAM_LIB(team), &task->nccl_progress_st, + team->nccl_comm, 0); for (peer = 0; peer < size; peer++) { if (ucc_ep_map_eval(map, peer) == rank) { continue; @@ -503,32 +502,22 @@ ucc_status_t ucc_tl_nccl_bcast_start(ucc_coll_task_t *coll_task) NCCLCHECK_GOTO(ncclSend(src, count, dt, ucc_ep_map_eval(map, peer), team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); } -#if NCCL_USE_NON_BLOCKING - NCCLCHECK_INPROGRESS_GOTO(ncclGroupEnd(), exit_coll, status, - UCC_TL_TEAM_LIB(team), - task->nccl_progress_st, team->nccl_comm); -#else - NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team)); -#endif + NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 1); } else { -#if NCCL_USE_NON_BLOCKING - NCCLCHECK_INPROGRESS_GOTO(ncclRecv(src, count, dt, root, - team->nccl_comm, stream), - exit_coll, status, - UCC_TL_TEAM_LIB(team), - task->nccl_progress_st, team->nccl_comm); -#else NCCLCHECK_GOTO(ncclRecv(src, count, dt, root, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); -#endif + exit_coll, status, UCC_TL_TEAM_LIB(team), 
+ &task->nccl_progress_st, team->nccl_comm, 1); } } else { NCCLCHECK_GOTO(ncclBroadcast(src, src, count, dt, root, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); } status = ucc_tl_nccl_collective_sync(task, stream); exit_coll: @@ -569,7 +558,8 @@ ucc_status_t ucc_tl_nccl_reduce_scatter_start(ucc_coll_task_t *coll_task) } NCCLCHECK_GOTO(ncclReduceScatter(src, dst, count, dt, op, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); status = ucc_tl_nccl_collective_sync(task, stream); exit_coll: return status; @@ -622,7 +612,8 @@ ucc_status_t ucc_tl_nccl_reduce_start(ucc_coll_task_t *coll_task) task->super.status = UCC_INPROGRESS; NCCLCHECK_GOTO(ncclReduce(src, dst, count, nccl_dt, op, args->root, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); status = ucc_tl_nccl_collective_sync(task, stream); exit_coll: return status; @@ -702,7 +693,8 @@ ucc_status_t ucc_tl_nccl_gather_start(ucc_coll_task_t *coll_task) exit_coll, status); } NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, - UCC_TL_TEAM_LIB(team)); + UCC_TL_TEAM_LIB(team), &task->nccl_progress_st, + team->nccl_comm, 0); for (peer = 0; peer < size; peer++) { if (peer == args->root) { continue; @@ -710,27 +702,17 @@ ucc_status_t ucc_tl_nccl_gather_start(ucc_coll_task_t *coll_task) NCCLCHECK_GOTO(ncclRecv(PTR_OFFSET(dst, peer * send_size), send_size, ncclChar, peer, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); } -#if NCCL_USE_NON_BLOCKING - NCCLCHECK_INPROGRESS_GOTO(ncclGroupEnd(), exit_coll, status, - UCC_TL_TEAM_LIB(team), - task->nccl_progress_st, 
team->nccl_comm); -#else - NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team)); -#endif + NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, + UCC_TL_TEAM_LIB(team), &task->nccl_progress_st, + team->nccl_comm, 1); } else { -#if NCCL_USE_NON_BLOCKING - NCCLCHECK_INPROGRESS_GOTO(ncclSend(src, send_size, ncclChar, - args->root, team->nccl_comm, - stream), - exit_coll, status, UCC_TL_TEAM_LIB(team), - task->nccl_progress_st, team->nccl_comm); -#else NCCLCHECK_GOTO(ncclSend(src, send_size, ncclChar, args->root, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); -#endif + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 1); } task->super.status = UCC_INPROGRESS; status = ucc_tl_nccl_collective_sync(task, stream); @@ -778,7 +760,8 @@ ucc_status_t ucc_tl_nccl_gatherv_start(ucc_coll_task_t *coll_task) exit_coll, status); } NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, - UCC_TL_TEAM_LIB(team)); + UCC_TL_TEAM_LIB(team), &task->nccl_progress_st, + team->nccl_comm, 0); for (peer = 0; peer < size; peer++) { if (peer == args->root) { continue; @@ -790,28 +773,18 @@ ucc_status_t ucc_tl_nccl_gatherv_start(ucc_coll_task_t *coll_task) NCCLCHECK_GOTO(ncclRecv(PTR_OFFSET(dst, displ * dt_size), count * dt_size, ncclChar, peer,team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); } -#if NCCL_USE_NON_BLOCKING - NCCLCHECK_INPROGRESS_GOTO(ncclGroupEnd(), exit_coll, status, - UCC_TL_TEAM_LIB(team), - task->nccl_progress_st, team->nccl_comm); -#else - NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team)); -#endif + NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, + UCC_TL_TEAM_LIB(team), &task->nccl_progress_st, + team->nccl_comm, 1); } else { -#if NCCL_USE_NON_BLOCKING - NCCLCHECK_INPROGRESS_GOTO(ncclSend(src, args->src.info.count * dt_size, - ncclChar, args->root, - team->nccl_comm, 
stream), - exit_coll, status, UCC_TL_TEAM_LIB(team), - task->nccl_progress_st, team->nccl_comm); -#else NCCLCHECK_GOTO(ncclSend(src, args->src.info.count * dt_size, ncclChar, args->root, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); -#endif + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 1); } task->super.status = UCC_INPROGRESS; status = ucc_tl_nccl_collective_sync(task, stream); @@ -857,7 +830,8 @@ ucc_status_t ucc_tl_nccl_scatter_start(ucc_coll_task_t *coll_task) exit_coll, status); } NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, - UCC_TL_TEAM_LIB(team)); + UCC_TL_TEAM_LIB(team), &task->nccl_progress_st, + team->nccl_comm, 0); for (peer = 0; peer < size; peer++) { if (peer == args->root) { continue; @@ -865,27 +839,16 @@ ucc_status_t ucc_tl_nccl_scatter_start(ucc_coll_task_t *coll_task) NCCLCHECK_GOTO(ncclSend(PTR_OFFSET(src, peer * send_size), send_size, ncclChar, peer, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); } -#if NCCL_USE_NON_BLOCKING - NCCLCHECK_INPROGRESS_GOTO(ncclGroupEnd(), exit_coll, status, - UCC_TL_TEAM_LIB(team), - task->nccl_progress_st, team->nccl_comm); -#else - NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team)); -#endif + NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 1); } else { -#if NCCL_USE_NON_BLOCKING - NCCLCHECK_INPROGRESS_GOTO(ncclRecv(dst, send_size, ncclChar, - args->root, team->nccl_comm, - stream), - exit_coll, status, UCC_TL_TEAM_LIB(team), - task->nccl_progress_st, team->nccl_comm); -#else NCCLCHECK_GOTO(ncclRecv(dst, send_size, ncclChar, args->root, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); -#endif + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 1); } task->super.status = 
UCC_INPROGRESS; status = ucc_tl_nccl_collective_sync(task, stream); @@ -935,7 +898,8 @@ ucc_status_t ucc_tl_nccl_scatterv_start(ucc_coll_task_t *coll_task) exit_coll, status); } NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, - UCC_TL_TEAM_LIB(team)); + UCC_TL_TEAM_LIB(team), &task->nccl_progress_st, + team->nccl_comm, 0); for (peer = 0; peer < size; peer++) { if (peer == args->root) { continue; @@ -947,27 +911,16 @@ ucc_status_t ucc_tl_nccl_scatterv_start(ucc_coll_task_t *coll_task) NCCLCHECK_GOTO(ncclSend(PTR_OFFSET(src, displ * dt_size), count * dt_size, ncclChar, peer, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); } -#if NCCL_USE_NON_BLOCKING - NCCLCHECK_INPROGRESS_GOTO(ncclGroupEnd(), exit_coll, status, - UCC_TL_TEAM_LIB(team), - task->nccl_progress_st, team->nccl_comm); -#else - NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team)); -#endif + NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 1); } else { -#if NCCL_USE_NON_BLOCKING - NCCLCHECK_INPROGRESS_GOTO(ncclRecv(dst, args->dst.info.count * dt_size, - ncclChar, args->root, - team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team), - task->nccl_progress_st, team->nccl_comm); -#else NCCLCHECK_GOTO(ncclRecv(dst, args->dst.info.count * dt_size, ncclChar, args->root, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); -#endif + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 1); } task->super.status = UCC_INPROGRESS; status = ucc_tl_nccl_collective_sync(task, stream); diff --git a/src/components/tl/nccl/tl_nccl_context.c b/src/components/tl/nccl/tl_nccl_context.c index 14982204e2..ef67bed9ba 100644 --- a/src/components/tl/nccl/tl_nccl_context.c +++ b/src/components/tl/nccl/tl_nccl_context.c @@ -11,37 +11,38 @@ #include "core/ucc_ee.h" 
#include "utils/arch/cpu.h" -#if NCCL_USE_NON_BLOCKING static ucc_status_t ucc_tl_nccl_nb_progress(ucc_tl_nccl_task_t *task) { +#if NCCL_USE_NON_BLOCKING ucc_tl_nccl_team_t *team = TASK_TEAM(task); - ncclResult_t nccl_status; + ncclResult_t nccl_status, st; if (task->nccl_progress_st == UCC_INPROGRESS) { - ncclCommGetAsyncError(team->nccl_comm, &nccl_status); + st = ncclCommGetAsyncError(team->nccl_comm, &nccl_status); + if (st != ncclSuccess || + (nccl_status != ncclSuccess && nccl_status != ncclInProgress)) { + tl_error(UCC_TL_TEAM_LIB(team), "NCCL error %d %s", + st != ncclSuccess ? st : nccl_status, + ncclGetErrorString(st != ncclSuccess ? st : nccl_status)); + return UCC_ERR_NO_MESSAGE; + } if (nccl_status == ncclInProgress) { return UCC_INPROGRESS; } - if (nccl_status != ncclSuccess) { - tl_error(UCC_TL_TEAM_LIB(team), "NCCL error %d %s", nccl_status, - ncclGetErrorString(nccl_status)); - return UCC_ERR_NO_MESSAGE; - } } +#endif return UCC_OK; } -#endif void ucc_tl_nccl_event_collective_progress(ucc_coll_task_t *coll_task) { ucc_tl_nccl_task_t *task = ucc_derived_of(coll_task, ucc_tl_nccl_task_t); ucc_status_t status; -#if NCCL_USE_NON_BLOCKING + status = ucc_tl_nccl_nb_progress(task); if (status != UCC_OK) { coll_task->status = status; return; } -#endif ucc_assert(task->completed != NULL); status = ucc_ec_event_test(task->completed, UCC_EE_CUDA_STREAM); @@ -56,7 +57,6 @@ void ucc_tl_nccl_event_collective_progress(ucc_coll_task_t *coll_task) void ucc_tl_nccl_driver_collective_progress(ucc_coll_task_t *coll_task) { ucc_tl_nccl_task_t *task = ucc_derived_of(coll_task, ucc_tl_nccl_task_t); -#if NCCL_USE_NON_BLOCKING ucc_status_t status; status = ucc_tl_nccl_nb_progress(task); @@ -64,7 +64,6 @@ void ucc_tl_nccl_driver_collective_progress(ucc_coll_task_t *coll_task) coll_task->status = status; return; } -#endif coll_task->status = task->host_status; #ifdef HAVE_PROFILING_TL_NCCL diff --git a/src/components/tl/nccl/tl_nccl_team.c 
b/src/components/tl/nccl/tl_nccl_team.c index ee082ab79d..0497490d83 100644 --- a/src/components/tl/nccl/tl_nccl_team.c +++ b/src/components/tl/nccl/tl_nccl_team.c @@ -67,7 +67,7 @@ ucc_status_t ucc_tl_nccl_team_destroy(ucc_base_team_t *tl_team) ucc_tl_nccl_team_t *team = ucc_derived_of(tl_team, ucc_tl_nccl_team_t); #if NCCL_USE_NON_BLOCKING - ncclResult_t nccl_status; + ncclResult_t nccl_status, st; if (team->nccl_comm && team->comm_state == UCC_INPROGRESS) { goto check_finalize; @@ -83,16 +83,18 @@ ucc_status_t ucc_tl_nccl_team_destroy(ucc_base_team_t *tl_team) #if NCCL_USE_NON_BLOCKING ncclCommFinalize(team->nccl_comm); check_finalize: - ncclCommGetAsyncError(team->nccl_comm, &nccl_status); - if (nccl_status == ncclInProgress) { - team->comm_state = UCC_INPROGRESS; - return UCC_INPROGRESS; - } - if (nccl_status != ncclSuccess) { - tl_debug(tl_team->context->lib, "NCCL error %d %s", nccl_status, - ncclGetErrorString(nccl_status)); + st = ncclCommGetAsyncError(team->nccl_comm, &nccl_status); + /* ncclInProgress is not an error, it is handled below */ + if (st != ncclSuccess || + (nccl_status != ncclSuccess && nccl_status != ncclInProgress)) { + tl_debug(tl_team->context->lib, "NCCL error %d %s", + st != ncclSuccess ? st : nccl_status, + ncclGetErrorString(st != ncclSuccess ? 
st : nccl_status)); ncclCommAbort(team->nccl_comm); return UCC_ERR_NO_MESSAGE; + } else if (nccl_status == ncclInProgress) { + team->comm_state = UCC_INPROGRESS; + return UCC_INPROGRESS; } else { ncclCommDestroy(team->nccl_comm); } @@ -117,6 +119,7 @@ ucc_status_t ucc_tl_nccl_team_create_test(ucc_base_team_t *tl_team) #if NCCL_USE_NON_BLOCKING ncclConfig_t nccl_cfg = NCCL_CONFIG_INITIALIZER; + ncclResult_t st; if (team->comm_state == UCC_INPROGRESS) { goto ncclInitStage; @@ -157,7 +160,10 @@ ucc_status_t ucc_tl_nccl_team_create_test(ucc_base_team_t *tl_team) goto free_stream; } ncclInitStage: - ncclCommGetAsyncError(team->nccl_comm, &nccl_status); + st = ncclCommGetAsyncError(team->nccl_comm, &nccl_status); + if (st != ncclSuccess) { + nccl_status = st; + } if (nccl_status == ncclInProgress){ team->comm_state = UCC_INPROGRESS; return UCC_INPROGRESS;