diff --git a/src/components/tl/nccl/allgatherv/allgatherv.c b/src/components/tl/nccl/allgatherv/allgatherv.c index fef021af3a..7d4453574e 100644 --- a/src/components/tl/nccl/allgatherv/allgatherv.c +++ b/src/components/tl/nccl/allgatherv/allgatherv.c @@ -1,5 +1,5 @@ /** - * Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See file LICENSE for terms. */ @@ -70,12 +70,14 @@ ucc_status_t ucc_tl_nccl_allgatherv_p2p_start(ucc_coll_task_t *coll_task) task->super.status = UCC_INPROGRESS; UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_allgatherv_start", 0); - NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team)); + NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); if (count != 0) { for (peer = 0; peer < size; peer++) { NCCLCHECK_GOTO(ncclSend(sbuf, count * sdt_size, ncclChar, peer, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); } } for (peer = 0; peer < size; peer++) { @@ -86,10 +88,12 @@ ucc_status_t ucc_tl_nccl_allgatherv_p2p_start(ucc_coll_task_t *coll_task) NCCLCHECK_GOTO(ncclRecv(PTR_OFFSET(rbuf, displ * rdt_size), count * rdt_size, ncclChar, peer, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); } } - NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team)); + NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 1); status = ucc_tl_nccl_collective_sync(task, stream); exit_coll: return status; @@ -106,8 +110,8 @@ ucc_status_t ucc_tl_nccl_allgatherv_p2p_init(ucc_base_coll_args_t *coll_args, if (ucc_unlikely(status != UCC_OK)) { return status; } - task->super.post = ucc_tl_nccl_allgatherv_p2p_start; - *task_h = &task->super; + task->super.post = ucc_tl_nccl_allgatherv_p2p_start; + *task_h = &task->super; return status; } @@ -144,7 +148,8 @@ ucc_status_t ucc_tl_nccl_allgatherv_bcopy_start(ucc_coll_task_t *coll_task) } NCCLCHECK_GOTO(ncclAllGather(sbuf, scratch, max_count * rdt_size, ncclChar, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); for (peer = 0; peer < size; peer++) { rcount = ucc_coll_args_get_count(args, args->dst.info_v.counts, peer); @@ -233,13 +238,14 @@ ucc_status_t ucc_tl_nccl_allgatherv_bcast_start(ucc_coll_task_t *coll_task) ucc_status_t status = UCC_OK; void *sbuf = args->src.info.buffer; ptrdiff_t rbuf = (ptrdiff_t)args->dst.info_v.buffer; - size_t rdt_size, count, displ; - ucc_rank_t peer; + size_t rdt_size, count, displ; + ucc_rank_t peer; task->super.status = UCC_INPROGRESS; rdt_size = ucc_dt_size(args->dst.info_v.datatype); UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_allgatherv_start", 0); - NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team)); + NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); for (peer = 0; peer < size; peer++) { count = ucc_coll_args_get_count(args, args->dst.info_v.counts, peer); displ = ucc_coll_args_get_displacement(args, @@ -251,9 +257,11 @@ ucc_status_t ucc_tl_nccl_allgatherv_bcast_start(ucc_coll_task_t *coll_task) NCCLCHECK_GOTO(ncclBroadcast(sbuf, PTR_OFFSET(rbuf, displ * rdt_size), count * rdt_size, ncclChar, peer, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); } - NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team)); + NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 1); status = ucc_tl_nccl_collective_sync(task, stream); exit_coll: return status; @@ -263,7 +271,7 @@ ucc_status_t ucc_tl_nccl_allgatherv_bcast_init(ucc_base_coll_args_t *coll_args, ucc_base_team_t * team, ucc_coll_task_t ** task_h) { - ucc_status_t status = UCC_OK; + ucc_status_t status = UCC_OK; ucc_tl_nccl_task_t *task; status = ucc_tl_nccl_init_task(coll_args, team, &task); @@ -272,6 +280,6 @@ ucc_status_t ucc_tl_nccl_allgatherv_bcast_init(ucc_base_coll_args_t *coll_args, } task->super.post = ucc_tl_nccl_allgatherv_bcast_start; - *task_h = &task->super; + *task_h = &task->super; return status; } diff --git a/src/components/tl/nccl/tl_nccl.c b/src/components/tl/nccl/tl_nccl.c index 8465c66e20..8e71cdc1e2 100644 --- a/src/components/tl/nccl/tl_nccl.c +++ b/src/components/tl/nccl/tl_nccl.c @@ -39,6 +39,12 @@ static ucs_config_field_t ucc_tl_nccl_context_config_table[] = { UCS_CONFIG_TYPE_ENUM(ucc_tl_nccl_completion_sync_names) }, + {"BLOCKING", "1", + "If set to 0 will use non-blocking mode communicator behavior, " + "if set to 1 will use blocking mode", + ucs_offsetof(ucc_tl_nccl_context_config_t, nccl_cfg_blocking), + UCS_CONFIG_TYPE_BOOL}, + {NULL}}; UCC_CLASS_DEFINE_NEW_FUNC(ucc_tl_nccl_lib_t, ucc_base_lib_t, diff --git a/src/components/tl/nccl/tl_nccl.h b/src/components/tl/nccl/tl_nccl.h index 6f3daee474..06f32c0371 100644 --- a/src/components/tl/nccl/tl_nccl.h +++ b/src/components/tl/nccl/tl_nccl.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) Facebook, Inc. and its affiliates. 2021. * * See file LICENSE for terms. @@ -42,6 +42,8 @@ #define UCC_TL_NCCL_PROFILE_REQUEST_NEW UCC_PROFILE_REQUEST_NEW #define UCC_TL_NCCL_PROFILE_REQUEST_EVENT UCC_PROFILE_REQUEST_EVENT #define UCC_TL_NCCL_PROFILE_REQUEST_FREE UCC_PROFILE_REQUEST_FREE +#define NCCL_VERSION_COMM_INIT_NB NCCL_VERSION(2,14,3) +#define NCCL_USE_NON_BLOCKING NCCL_VERSION_CODE >= NCCL_VERSION_COMM_INIT_NB typedef struct ucc_tl_nccl_iface { ucc_tl_iface_t super; @@ -63,6 +65,7 @@ typedef enum ucc_tl_nccl_completion_sync_type { typedef struct ucc_tl_nccl_context_config { ucc_tl_context_config_t super; ucc_tl_nccl_completion_sync_type_t sync_type; + int nccl_cfg_blocking; } ucc_tl_nccl_context_config_t; typedef struct ucc_tl_nccl_lib { @@ -92,6 +95,7 @@ typedef struct ucc_tl_nccl_team { typedef struct ucc_tl_nccl_task { ucc_coll_task_t super; ucc_status_t host_status; + ucc_status_t nccl_progress_st; ucc_status_t *dev_status; void *completed; union { @@ -122,12 +126,33 @@ typedef struct ucc_tl_nccl_task { UCC_CLASS_DECLARE(ucc_tl_nccl_team_t, ucc_base_context_t *, const ucc_base_team_params_t *); -#define NCCLCHECK_GOTO(_cmd, _label, _status, _lib) \ +static inline ucc_status_t ucc_tl_nccl_check_nb(ncclResult_t *nccl_status, // NOLINT + ucc_status_t *task_st, // NOLINT + ncclComm_t nccl_comm, //NOLINT + int check_nb) { //NOLINT +#if NCCL_USE_NON_BLOCKING + if (check_nb && + (*nccl_status == ncclSuccess || *nccl_status == ncclInProgress)) { + ncclResult_t st = ncclCommGetAsyncError(nccl_comm, nccl_status); + if (st != ncclSuccess) { + return UCC_ERR_NO_MESSAGE; + } + if (ncclInProgress == *nccl_status) { + *task_st = UCC_INPROGRESS; + return UCC_INPROGRESS; + } + } +#endif + return UCC_OK; +} + +#define NCCLCHECK_GOTO(_cmd, _label, _st, _lib, _task_st, _comm, _check_nb) \ do { \ ncclResult_t e = _cmd; \ - if (ncclSuccess != e) { \ + _st = ucc_tl_nccl_check_nb(&e, _task_st, _comm, _check_nb); \ + if (_st != UCC_INPROGRESS && ncclSuccess != e) { \ tl_error(_lib, "NCCL error %d %s", e, ncclGetErrorString(e)); \ - _status = UCC_ERR_NO_MESSAGE; \ + _st = UCC_ERR_NO_MESSAGE; \ goto _label; \ } \ } while (0) @@ -135,4 +160,7 @@ UCC_CLASS_DECLARE(ucc_tl_nccl_team_t, ucc_base_context_t *, #define UCC_TL_NCCL_TEAM_LIB(_team) \ (ucc_derived_of((_team)->super.super.context->lib, ucc_tl_nccl_lib_t)) +#define UCC_TL_NCCL_TEAM_CTX(_team) \ + (ucc_derived_of((_team)->super.super.context, ucc_tl_nccl_context_t)) + #endif diff --git a/src/components/tl/nccl/tl_nccl_coll.c b/src/components/tl/nccl/tl_nccl_coll.c index ba3d246489..6e9daa72d1 100644 --- a/src/components/tl/nccl/tl_nccl_coll.c +++ b/src/components/tl/nccl/tl_nccl_coll.c @@ -258,7 +258,7 @@ ucc_status_t ucc_tl_nccl_alltoall_start(ucc_coll_task_t *coll_task) ucc_status_t status = UCC_OK; ptrdiff_t sbuf = (ptrdiff_t)args->src.info.buffer; ptrdiff_t rbuf = (ptrdiff_t)args->dst.info.buffer; - size_t data_size; + size_t data_size; ucc_rank_t peer; task->super.status = UCC_INPROGRESS; @@ -270,16 +270,20 @@ ucc_status_t ucc_tl_nccl_alltoall_start(ucc_coll_task_t *coll_task) return ucc_task_complete(&task->super); } UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_alltoall_start", 0); - NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team)); + NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); for (peer = 0; peer < gsize; peer++) { NCCLCHECK_GOTO(ncclSend((void *)(sbuf + peer * data_size), data_size, ncclChar, peer, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); NCCLCHECK_GOTO(ncclRecv((void *)(rbuf + peer * data_size), data_size, ncclChar, peer, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); } - NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team)); + NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 1); status = ucc_tl_nccl_collective_sync(task, stream); exit_coll: return status; @@ -292,7 +296,7 @@ ucc_status_t ucc_tl_nccl_alltoall_init(ucc_tl_nccl_task_t *task) return UCC_ERR_NOT_SUPPORTED; } - task->super.post = ucc_tl_nccl_alltoall_start; + task->super.post = ucc_tl_nccl_alltoall_start; return UCC_OK; } @@ -306,14 +310,15 @@ ucc_status_t ucc_tl_nccl_alltoallv_start(ucc_coll_task_t *coll_task) ucc_status_t status = UCC_OK; ptrdiff_t sbuf = (ptrdiff_t)args->src.info_v.buffer; ptrdiff_t rbuf = (ptrdiff_t)args->dst.info_v.buffer; - size_t sdt_size, rdt_size, count, displ; + size_t sdt_size, rdt_size, count, displ; ucc_rank_t peer; task->super.status = UCC_INPROGRESS; sdt_size = ucc_dt_size(args->src.info_v.datatype); rdt_size = ucc_dt_size(args->dst.info_v.datatype); UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_alltoallv_start", 0); - NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team)); + NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); for (peer = 0; peer < UCC_TL_TEAM_SIZE(team); peer++) { count = ucc_coll_args_get_count(args, args->src.info_v.counts, peer); if (count != 0) { @@ -322,7 +327,8 @@ ucc_status_t ucc_tl_nccl_alltoallv_start(ucc_coll_task_t *coll_task) NCCLCHECK_GOTO(ncclSend((void *)(sbuf + displ * sdt_size), count * sdt_size, ncclChar, peer, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); } count = ucc_coll_args_get_count(args, args->dst.info_v.counts, peer); if (count != 0) { @@ -331,10 +337,12 @@ ucc_status_t ucc_tl_nccl_alltoallv_start(ucc_coll_task_t *coll_task) NCCLCHECK_GOTO(ncclRecv((void *)(rbuf + displ * rdt_size), count * rdt_size, ncclChar, peer, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); } } - NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team)); + NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 1); status = ucc_tl_nccl_collective_sync(task, stream); exit_coll: return status; @@ -347,7 +355,7 @@ ucc_status_t ucc_tl_nccl_alltoallv_init(ucc_tl_nccl_task_t *task) return UCC_ERR_NOT_SUPPORTED; } - task->super.post = ucc_tl_nccl_alltoallv_start; + task->super.post = ucc_tl_nccl_alltoallv_start; return UCC_OK; } @@ -375,7 +383,8 @@ ucc_status_t ucc_tl_nccl_allreduce_start(ucc_coll_task_t *coll_task) 0); NCCLCHECK_GOTO(ncclAllReduce(src, dst, count, dt, op, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); status = ucc_tl_nccl_collective_sync(task, stream); exit_coll: return status; @@ -397,7 +406,7 @@ ucc_status_t ucc_tl_nccl_allreduce_init(ucc_tl_nccl_task_t *task) return UCC_ERR_NOT_SUPPORTED; } - task->super.post = ucc_tl_nccl_allreduce_start; + task->super.post = ucc_tl_nccl_allreduce_start; return UCC_OK; } @@ -425,7 +434,8 @@ ucc_status_t ucc_tl_nccl_allgather_start(ucc_coll_task_t *coll_task) UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_allgather_start", 0); NCCLCHECK_GOTO(ncclAllGather(src, dst, count / size, dt, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); status = ucc_tl_nccl_collective_sync(task, stream); exit_coll: return status; @@ -483,7 +493,8 @@ ucc_status_t ucc_tl_nccl_bcast_start(ucc_coll_task_t *coll_task) size = (ucc_rank_t)args->active_set.size; if (root == rank) { NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, - UCC_TL_TEAM_LIB(team)); + UCC_TL_TEAM_LIB(team), &task->nccl_progress_st, + team->nccl_comm, 0); for (peer = 0; peer < size; peer++) { if (ucc_ep_map_eval(map, peer) == rank) { continue; @@ -491,19 +502,22 @@ ucc_status_t ucc_tl_nccl_bcast_start(ucc_coll_task_t *coll_task) NCCLCHECK_GOTO(ncclSend(src, count, dt, ucc_ep_map_eval(map, peer), team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); } - NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, - UCC_TL_TEAM_LIB(team)); + NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 1); } else { NCCLCHECK_GOTO(ncclRecv(src, count, dt, root, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 1); } } else { NCCLCHECK_GOTO(ncclBroadcast(src, src, count, dt, root, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); } status = ucc_tl_nccl_collective_sync(task, stream); exit_coll: @@ -544,7 +558,8 @@ ucc_status_t ucc_tl_nccl_reduce_scatter_start(ucc_coll_task_t *coll_task) } NCCLCHECK_GOTO(ncclReduceScatter(src, dst, count, dt, op, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); status = ucc_tl_nccl_collective_sync(task, stream); exit_coll: return status; @@ -597,7 +612,8 @@ ucc_status_t ucc_tl_nccl_reduce_start(ucc_coll_task_t *coll_task) task->super.status = UCC_INPROGRESS; NCCLCHECK_GOTO(ncclReduce(src, dst, count, nccl_dt, op, args->root, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); status = ucc_tl_nccl_collective_sync(task, stream); exit_coll: return status; @@ -658,7 +674,7 @@ ucc_status_t ucc_tl_nccl_gather_start(ucc_coll_task_t *coll_task) void *dst = args->dst.info.buffer; void *src = args->src.info.buffer; ucc_status_t status = UCC_OK; - size_t send_size; + size_t send_size; ucc_rank_t peer; if (rank == args->root) { @@ -677,7 +693,8 @@ ucc_status_t ucc_tl_nccl_gather_start(ucc_coll_task_t *coll_task) exit_coll, status); } NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, - UCC_TL_TEAM_LIB(team)); + UCC_TL_TEAM_LIB(team), &task->nccl_progress_st, + team->nccl_comm, 0); for (peer = 0; peer < size; peer++) { if (peer == args->root) { continue; @@ -685,14 +702,17 @@ ucc_status_t ucc_tl_nccl_gather_start(ucc_coll_task_t *coll_task) NCCLCHECK_GOTO(ncclRecv(PTR_OFFSET(dst, peer * send_size), send_size, ncclChar, peer, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); } NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, - UCC_TL_TEAM_LIB(team)); + UCC_TL_TEAM_LIB(team), &task->nccl_progress_st, + team->nccl_comm, 1); } else { NCCLCHECK_GOTO(ncclSend(src, send_size, ncclChar, args->root, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 1); } task->super.status = UCC_INPROGRESS; status = ucc_tl_nccl_collective_sync(task, stream); @@ -702,8 +722,8 @@ ucc_status_t ucc_tl_nccl_gather_start(ucc_coll_task_t *coll_task) ucc_status_t ucc_tl_nccl_gather_init(ucc_tl_nccl_task_t *task) { - task->super.post = ucc_tl_nccl_gather_start; - return UCC_OK; + task->super.post = ucc_tl_nccl_gather_start; + return UCC_OK; } ucc_status_t ucc_tl_nccl_gatherv_start(ucc_coll_task_t *coll_task) @@ -718,7 +738,7 @@ ucc_status_t ucc_tl_nccl_gatherv_start(ucc_coll_task_t *coll_task) void *dst = args->dst.info_v.buffer; void *src = args->src.info.buffer; ucc_status_t status = UCC_OK; - size_t count, displ, dt_size; + size_t count, displ, dt_size; ucc_rank_t peer; if (rank == args->root) { @@ -740,7 +760,8 @@ ucc_status_t ucc_tl_nccl_gatherv_start(ucc_coll_task_t *coll_task) exit_coll, status); } NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, - UCC_TL_TEAM_LIB(team)); + UCC_TL_TEAM_LIB(team), &task->nccl_progress_st, + team->nccl_comm, 0); for (peer = 0; peer < size; peer++) { if (peer == args->root) { continue; @@ -752,15 +773,18 @@ ucc_status_t ucc_tl_nccl_gatherv_start(ucc_coll_task_t *coll_task) NCCLCHECK_GOTO(ncclRecv(PTR_OFFSET(dst, displ * dt_size), count * dt_size, ncclChar, peer,team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); } NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, - UCC_TL_TEAM_LIB(team)); + UCC_TL_TEAM_LIB(team), &task->nccl_progress_st, + team->nccl_comm, 1); } else { NCCLCHECK_GOTO(ncclSend(src, args->src.info.count * dt_size, ncclChar, args->root, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 1); } task->super.status = UCC_INPROGRESS; status = ucc_tl_nccl_collective_sync(task, stream); @@ -786,8 +810,8 @@ ucc_status_t ucc_tl_nccl_scatter_start(ucc_coll_task_t *coll_task) void *dst = args->dst.info.buffer; void *src = args->src.info.buffer; ucc_status_t status = UCC_OK; - size_t send_size; - ucc_rank_t peer; + size_t send_size; + ucc_rank_t peer; if (rank == args->root) { send_size = ucc_dt_size(args->src.info.datatype) * @@ -806,7 +830,8 @@ ucc_status_t ucc_tl_nccl_scatter_start(ucc_coll_task_t *coll_task) exit_coll, status); } NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, - UCC_TL_TEAM_LIB(team)); + UCC_TL_TEAM_LIB(team), &task->nccl_progress_st, + team->nccl_comm, 0); for (peer = 0; peer < size; peer++) { if (peer == args->root) { continue; @@ -814,14 +839,16 @@ ucc_status_t ucc_tl_nccl_scatter_start(ucc_coll_task_t *coll_task) NCCLCHECK_GOTO(ncclSend(PTR_OFFSET(src, peer * send_size), send_size, ncclChar, peer, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); } - NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, - UCC_TL_TEAM_LIB(team)); + NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 1); } else { NCCLCHECK_GOTO(ncclRecv(dst, send_size, ncclChar, args->root, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 1); } task->super.status = UCC_INPROGRESS; status = ucc_tl_nccl_collective_sync(task, stream); @@ -847,7 +874,7 @@ ucc_status_t ucc_tl_nccl_scatterv_start(ucc_coll_task_t *coll_task) void *dst = args->dst.info.buffer; void *src = args->src.info_v.buffer; ucc_status_t status = UCC_OK; - size_t count, displ, dt_size; + size_t count, displ, dt_size; ucc_rank_t peer; if (rank == args->root) { @@ -871,7 +898,8 @@ ucc_status_t ucc_tl_nccl_scatterv_start(ucc_coll_task_t *coll_task) exit_coll, status); } NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, - UCC_TL_TEAM_LIB(team)); + UCC_TL_TEAM_LIB(team), &task->nccl_progress_st, + team->nccl_comm, 0); for (peer = 0; peer < size; peer++) { if (peer == args->root) { continue; @@ -883,14 +911,16 @@ ucc_status_t ucc_tl_nccl_scatterv_start(ucc_coll_task_t *coll_task) NCCLCHECK_GOTO(ncclSend(PTR_OFFSET(src, displ * dt_size), count * dt_size, ncclChar, peer, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 0); } - NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, - UCC_TL_TEAM_LIB(team)); + NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 1); } else { NCCLCHECK_GOTO(ncclRecv(dst, args->dst.info.count * dt_size, ncclChar, args->root, team->nccl_comm, stream), - exit_coll, status, UCC_TL_TEAM_LIB(team)); + exit_coll, status, UCC_TL_TEAM_LIB(team), + &task->nccl_progress_st, team->nccl_comm, 1); } task->super.status = UCC_INPROGRESS; status = ucc_tl_nccl_collective_sync(task, stream); diff --git a/src/components/tl/nccl/tl_nccl_context.c b/src/components/tl/nccl/tl_nccl_context.c index 0da81cb419..ef67bed9ba 100644 --- a/src/components/tl/nccl/tl_nccl_context.c +++ b/src/components/tl/nccl/tl_nccl_context.c @@ -11,11 +11,39 @@ #include "core/ucc_ee.h" #include "utils/arch/cpu.h" +static ucc_status_t ucc_tl_nccl_nb_progress(ucc_tl_nccl_task_t *task) { +#if NCCL_USE_NON_BLOCKING + ucc_tl_nccl_team_t *team = TASK_TEAM(task); + ncclResult_t nccl_status, st; + + if (task->nccl_progress_st == UCC_INPROGRESS) { + st = ncclCommGetAsyncError(team->nccl_comm, &nccl_status); + if (st != ncclSuccess || + (nccl_status != ncclSuccess && nccl_status != ncclInProgress)) { + tl_error(UCC_TL_TEAM_LIB(team), "NCCL error %d %s", + st != ncclSuccess ? st : nccl_status, + ncclGetErrorString(st != ncclSuccess ? st : nccl_status)); + return UCC_ERR_NO_MESSAGE; + } + if (nccl_status == ncclInProgress) { + return UCC_INPROGRESS; + } + } +#endif + return UCC_OK; +} + void ucc_tl_nccl_event_collective_progress(ucc_coll_task_t *coll_task) { ucc_tl_nccl_task_t *task = ucc_derived_of(coll_task, ucc_tl_nccl_task_t); ucc_status_t status; + status = ucc_tl_nccl_nb_progress(task); + if (status != UCC_OK) { + coll_task->status = status; + return; + } + ucc_assert(task->completed != NULL); status = ucc_ec_event_test(task->completed, UCC_EE_CUDA_STREAM); coll_task->status = status; @@ -29,6 +57,13 @@ void ucc_tl_nccl_event_collective_progress(ucc_coll_task_t *coll_task) void ucc_tl_nccl_driver_collective_progress(ucc_coll_task_t *coll_task) { ucc_tl_nccl_task_t *task = ucc_derived_of(coll_task, ucc_tl_nccl_task_t); + ucc_status_t status; + + status = ucc_tl_nccl_nb_progress(task); + if (status != UCC_OK) { + coll_task->status = status; + return; + } coll_task->status = task->host_status; #ifdef HAVE_PROFILING_TL_NCCL @@ -49,7 +84,8 @@ static void ucc_tl_nccl_req_mpool_obj_init(ucc_mpool_t *mp, void *obj, ucc_tl_nccl_task_t *req = (ucc_tl_nccl_task_t*) obj; ucc_coll_task_construct(&req->super); - req->super.progress = ucc_tl_nccl_event_collective_progress; + req->super.progress = ucc_tl_nccl_event_collective_progress; + req->nccl_progress_st = UCC_OK; } @@ -91,7 +127,8 @@ static void ucc_tl_nccl_req_mapped_mpool_obj_init(ucc_mpool_t *mp, void *obj, req->super.status = UCC_ERR_NO_MESSAGE; } ucc_coll_task_construct(&req->super); - req->super.progress = ucc_tl_nccl_driver_collective_progress; + req->super.progress = ucc_tl_nccl_driver_collective_progress; + req->nccl_progress_st = UCC_OK; } static ucc_mpool_ops_t ucc_tl_nccl_req_mapped_mpool_ops = { diff --git a/src/components/tl/nccl/tl_nccl_team.c b/src/components/tl/nccl/tl_nccl_team.c index 6b67ad988f..af2aff2ac6 100644 --- a/src/components/tl/nccl/tl_nccl_team.c +++ b/src/components/tl/nccl/tl_nccl_team.c @@ -12,9 +12,6 @@ #include "coll_score/ucc_coll_score.h" #include "utils/arch/cuda_def.h" -#define NCCL_VERSION_COMM_INIT_NB NCCL_VERSION(2,14,3) -#define NCCL_USE_NON_BLOCKING NCCL_VERSION_CODE >= NCCL_VERSION_COMM_INIT_NB - UCC_CLASS_INIT_FUNC(ucc_tl_nccl_team_t, ucc_base_context_t *tl_context, const ucc_base_team_params_t *params) { @@ -70,7 +67,7 @@ ucc_status_t ucc_tl_nccl_team_destroy(ucc_base_team_t *tl_team) ucc_tl_nccl_team_t *team = ucc_derived_of(tl_team, ucc_tl_nccl_team_t); #if NCCL_USE_NON_BLOCKING - ncclResult_t nccl_status; + ncclResult_t nccl_status, st; if (team->nccl_comm && team->comm_state == UCC_INPROGRESS) { goto check_finalize; @@ -86,16 +83,16 @@ ucc_status_t ucc_tl_nccl_team_destroy(ucc_base_team_t *tl_team) #if NCCL_USE_NON_BLOCKING ncclCommFinalize(team->nccl_comm); check_finalize: - ncclCommGetAsyncError(team->nccl_comm, &nccl_status); - if (nccl_status == ncclInProgress) { - team->comm_state = UCC_INPROGRESS; - return UCC_INPROGRESS; - } - if (nccl_status != ncclSuccess) { - tl_debug(tl_team->context->lib, "NCCL error %d %s", nccl_status, - ncclGetErrorString(nccl_status)); + st = ncclCommGetAsyncError(team->nccl_comm, &nccl_status); + if (st != ncclSuccess || (nccl_status != ncclSuccess)) { + tl_debug(tl_team->context->lib, "NCCL error %d %s", + st != ncclSuccess ? st : nccl_status, + ncclGetErrorString(st != ncclSuccess ? st : nccl_status)); ncclCommAbort(team->nccl_comm); return UCC_ERR_NO_MESSAGE; + } else if (nccl_status == ncclInProgress) { + team->comm_state = UCC_INPROGRESS; + return UCC_INPROGRESS; } else { ncclCommDestroy(team->nccl_comm); } @@ -120,6 +117,7 @@ ucc_status_t ucc_tl_nccl_team_create_test(ucc_base_team_t *tl_team) #if NCCL_USE_NON_BLOCKING ncclConfig_t nccl_cfg = NCCL_CONFIG_INITIALIZER; + ncclResult_t st; if (team->comm_state == UCC_INPROGRESS) { goto ncclInitStage; @@ -150,7 +148,7 @@ ucc_status_t ucc_tl_nccl_team_create_test(ucc_base_team_t *tl_team) CUDA_CHECK_GOTO(cudaStreamCreateWithFlags(&team->stream, cudaStreamNonBlocking), free_unique_id, status); #if NCCL_USE_NON_BLOCKING - nccl_cfg.blocking = 0; + nccl_cfg.blocking = UCC_TL_NCCL_TEAM_CTX(team)->cfg.nccl_cfg_blocking; nccl_status = ncclCommInitRankConfig(&team->nccl_comm, UCC_TL_TEAM_SIZE(team), team->unique_id[0], @@ -160,7 +158,10 @@ ucc_status_t ucc_tl_nccl_team_create_test(ucc_base_team_t *tl_team) goto free_stream; } ncclInitStage: - ncclCommGetAsyncError(team->nccl_comm, &nccl_status); + st = ncclCommGetAsyncError(team->nccl_comm, &nccl_status); + if (st != ncclSuccess) { + nccl_status = st; + } if (nccl_status == ncclInProgress){ team->comm_state = UCC_INPROGRESS; return UCC_INPROGRESS;