Skip to content

Commit

Permalink
TL/NCCL make ncclGroupEnd nb (#798)
Browse files Browse the repository at this point in the history
* TL/NCCL: make ncclGroupEnd nb

* REVIEW: review fixes

* TL/NCCL: default blocking and configurable

* REVIEW: second review fixes
  • Loading branch information
shimmybalsam authored Jul 19, 2023
1 parent 72c84c3 commit 2a9ed2a
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 83 deletions.
38 changes: 23 additions & 15 deletions src/components/tl/nccl/allgatherv/allgatherv.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down Expand Up @@ -70,12 +70,14 @@ ucc_status_t ucc_tl_nccl_allgatherv_p2p_start(ucc_coll_task_t *coll_task)

task->super.status = UCC_INPROGRESS;
UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_allgatherv_start", 0);
NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team));
NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team),
&task->nccl_progress_st, team->nccl_comm, 0);
if (count != 0) {
for (peer = 0; peer < size; peer++) {
NCCLCHECK_GOTO(ncclSend(sbuf, count * sdt_size, ncclChar, peer,
team->nccl_comm, stream),
exit_coll, status, UCC_TL_TEAM_LIB(team));
exit_coll, status, UCC_TL_TEAM_LIB(team),
&task->nccl_progress_st, team->nccl_comm, 0);
}
}
for (peer = 0; peer < size; peer++) {
Expand All @@ -86,10 +88,12 @@ ucc_status_t ucc_tl_nccl_allgatherv_p2p_start(ucc_coll_task_t *coll_task)
NCCLCHECK_GOTO(ncclRecv(PTR_OFFSET(rbuf, displ * rdt_size),
count * rdt_size, ncclChar, peer,
team->nccl_comm, stream),
exit_coll, status, UCC_TL_TEAM_LIB(team));
exit_coll, status, UCC_TL_TEAM_LIB(team),
&task->nccl_progress_st, team->nccl_comm, 0);
}
}
NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team));
NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team),
&task->nccl_progress_st, team->nccl_comm, 1);
status = ucc_tl_nccl_collective_sync(task, stream);
exit_coll:
return status;
Expand All @@ -106,8 +110,8 @@ ucc_status_t ucc_tl_nccl_allgatherv_p2p_init(ucc_base_coll_args_t *coll_args,
if (ucc_unlikely(status != UCC_OK)) {
return status;
}
task->super.post = ucc_tl_nccl_allgatherv_p2p_start;
*task_h = &task->super;
task->super.post = ucc_tl_nccl_allgatherv_p2p_start;
*task_h = &task->super;

return status;
}
Expand Down Expand Up @@ -144,7 +148,8 @@ ucc_status_t ucc_tl_nccl_allgatherv_bcopy_start(ucc_coll_task_t *coll_task)
}
NCCLCHECK_GOTO(ncclAllGather(sbuf, scratch, max_count * rdt_size,
ncclChar, team->nccl_comm, stream),
exit_coll, status, UCC_TL_TEAM_LIB(team));
exit_coll, status, UCC_TL_TEAM_LIB(team),
&task->nccl_progress_st, team->nccl_comm, 0);
for (peer = 0; peer < size; peer++) {
rcount = ucc_coll_args_get_count(args,
args->dst.info_v.counts, peer);
Expand Down Expand Up @@ -233,13 +238,14 @@ ucc_status_t ucc_tl_nccl_allgatherv_bcast_start(ucc_coll_task_t *coll_task)
ucc_status_t status = UCC_OK;
void *sbuf = args->src.info.buffer;
ptrdiff_t rbuf = (ptrdiff_t)args->dst.info_v.buffer;
size_t rdt_size, count, displ;
ucc_rank_t peer;
size_t rdt_size, count, displ;
ucc_rank_t peer;

task->super.status = UCC_INPROGRESS;
rdt_size = ucc_dt_size(args->dst.info_v.datatype);
UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_allgatherv_start", 0);
NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team));
NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team),
&task->nccl_progress_st, team->nccl_comm, 0);
for (peer = 0; peer < size; peer++) {
count = ucc_coll_args_get_count(args, args->dst.info_v.counts, peer);
displ = ucc_coll_args_get_displacement(args,
Expand All @@ -251,9 +257,11 @@ ucc_status_t ucc_tl_nccl_allgatherv_bcast_start(ucc_coll_task_t *coll_task)
NCCLCHECK_GOTO(ncclBroadcast(sbuf, PTR_OFFSET(rbuf, displ * rdt_size),
count * rdt_size, ncclChar, peer,
team->nccl_comm, stream),
exit_coll, status, UCC_TL_TEAM_LIB(team));
exit_coll, status, UCC_TL_TEAM_LIB(team),
&task->nccl_progress_st, team->nccl_comm, 0);
}
NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team));
NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team),
&task->nccl_progress_st, team->nccl_comm, 1);
status = ucc_tl_nccl_collective_sync(task, stream);
exit_coll:
return status;
Expand All @@ -263,7 +271,7 @@ ucc_status_t ucc_tl_nccl_allgatherv_bcast_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t * team,
ucc_coll_task_t ** task_h)
{
ucc_status_t status = UCC_OK;
ucc_status_t status = UCC_OK;
ucc_tl_nccl_task_t *task;

status = ucc_tl_nccl_init_task(coll_args, team, &task);
Expand All @@ -272,6 +280,6 @@ ucc_status_t ucc_tl_nccl_allgatherv_bcast_init(ucc_base_coll_args_t *coll_args,
}

task->super.post = ucc_tl_nccl_allgatherv_bcast_start;
*task_h = &task->super;
*task_h = &task->super;
return status;
}
6 changes: 6 additions & 0 deletions src/components/tl/nccl/tl_nccl.c
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ static ucs_config_field_t ucc_tl_nccl_context_config_table[] = {
UCS_CONFIG_TYPE_ENUM(ucc_tl_nccl_completion_sync_names)
},

{"BLOCKING", "1",
"If set to 0 will use non-blocking mode communicator behavior, "
"if set to 1 will use blocking mode",
ucs_offsetof(ucc_tl_nccl_context_config_t, nccl_cfg_blocking),
UCS_CONFIG_TYPE_BOOL},

{NULL}};

UCC_CLASS_DEFINE_NEW_FUNC(ucc_tl_nccl_lib_t, ucc_base_lib_t,
Expand Down
36 changes: 32 additions & 4 deletions src/components/tl/nccl/tl_nccl.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) Facebook, Inc. and its affiliates. 2021.
*
* See file LICENSE for terms.
Expand Down Expand Up @@ -42,6 +42,8 @@
#define UCC_TL_NCCL_PROFILE_REQUEST_NEW UCC_PROFILE_REQUEST_NEW
#define UCC_TL_NCCL_PROFILE_REQUEST_EVENT UCC_PROFILE_REQUEST_EVENT
#define UCC_TL_NCCL_PROFILE_REQUEST_FREE UCC_PROFILE_REQUEST_FREE
#define NCCL_VERSION_COMM_INIT_NB NCCL_VERSION(2,14,3)
#define NCCL_USE_NON_BLOCKING NCCL_VERSION_CODE >= NCCL_VERSION_COMM_INIT_NB

typedef struct ucc_tl_nccl_iface {
ucc_tl_iface_t super;
Expand All @@ -63,6 +65,7 @@ typedef enum ucc_tl_nccl_completion_sync_type {
typedef struct ucc_tl_nccl_context_config {
ucc_tl_context_config_t super;
ucc_tl_nccl_completion_sync_type_t sync_type;
int nccl_cfg_blocking;
} ucc_tl_nccl_context_config_t;

typedef struct ucc_tl_nccl_lib {
Expand Down Expand Up @@ -92,6 +95,7 @@ typedef struct ucc_tl_nccl_team {
typedef struct ucc_tl_nccl_task {
ucc_coll_task_t super;
ucc_status_t host_status;
ucc_status_t nccl_progress_st;
ucc_status_t *dev_status;
void *completed;
union {
Expand Down Expand Up @@ -122,17 +126,41 @@ typedef struct ucc_tl_nccl_task {
UCC_CLASS_DECLARE(ucc_tl_nccl_team_t, ucc_base_context_t *,
const ucc_base_team_params_t *);

#define NCCLCHECK_GOTO(_cmd, _label, _status, _lib) \
static inline ucc_status_t ucc_tl_nccl_check_nb(ncclResult_t *nccl_status, // NOLINT
ucc_status_t *task_st, // NOLINT
ncclComm_t nccl_comm, //NOLINT
int check_nb) { //NOLINT
#if NCCL_USE_NON_BLOCKING
if (check_nb &&
(*nccl_status == ncclSuccess || *nccl_status == ncclInProgress)) {
ncclResult_t st = ncclCommGetAsyncError(nccl_comm, nccl_status);
if (st != ncclSuccess) {
return UCC_ERR_NO_MESSAGE;
}
if (ncclInProgress == *nccl_status) {
*task_st = UCC_INPROGRESS;
return UCC_INPROGRESS;
}
}
#endif
return UCC_OK;
}

#define NCCLCHECK_GOTO(_cmd, _label, _st, _lib, _task_st, _comm, _check_nb) \
do { \
ncclResult_t e = _cmd; \
if (ncclSuccess != e) { \
_st = ucc_tl_nccl_check_nb(&e, _task_st, _comm, _check_nb); \
if (_st != UCC_INPROGRESS && ncclSuccess != e) { \
tl_error(_lib, "NCCL error %d %s", e, ncclGetErrorString(e)); \
_status = UCC_ERR_NO_MESSAGE; \
_st = UCC_ERR_NO_MESSAGE; \
goto _label; \
} \
} while (0)

#define UCC_TL_NCCL_TEAM_LIB(_team) \
(ucc_derived_of((_team)->super.super.context->lib, ucc_tl_nccl_lib_t))

#define UCC_TL_NCCL_TEAM_CTX(_team) \
(ucc_derived_of((_team)->super.super.context, ucc_tl_nccl_context_t))

#endif
Loading

0 comments on commit 2a9ed2a

Please sign in to comment.