Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TL/NCCL make ncclGroupEnd nb #798

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 23 additions & 15 deletions src/components/tl/nccl/allgatherv/allgatherv.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down Expand Up @@ -70,12 +70,14 @@ ucc_status_t ucc_tl_nccl_allgatherv_p2p_start(ucc_coll_task_t *coll_task)

task->super.status = UCC_INPROGRESS;
UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_allgatherv_start", 0);
NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team));
NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team),
&task->nccl_progress_st, team->nccl_comm, 0);
if (count != 0) {
for (peer = 0; peer < size; peer++) {
NCCLCHECK_GOTO(ncclSend(sbuf, count * sdt_size, ncclChar, peer,
team->nccl_comm, stream),
exit_coll, status, UCC_TL_TEAM_LIB(team));
exit_coll, status, UCC_TL_TEAM_LIB(team),
&task->nccl_progress_st, team->nccl_comm, 0);
}
}
for (peer = 0; peer < size; peer++) {
Expand All @@ -86,10 +88,12 @@ ucc_status_t ucc_tl_nccl_allgatherv_p2p_start(ucc_coll_task_t *coll_task)
NCCLCHECK_GOTO(ncclRecv(PTR_OFFSET(rbuf, displ * rdt_size),
count * rdt_size, ncclChar, peer,
team->nccl_comm, stream),
exit_coll, status, UCC_TL_TEAM_LIB(team));
exit_coll, status, UCC_TL_TEAM_LIB(team),
&task->nccl_progress_st, team->nccl_comm, 0);
}
}
NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team));
NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team),
&task->nccl_progress_st, team->nccl_comm, 1);
status = ucc_tl_nccl_collective_sync(task, stream);
exit_coll:
return status;
Expand All @@ -106,8 +110,8 @@ ucc_status_t ucc_tl_nccl_allgatherv_p2p_init(ucc_base_coll_args_t *coll_args,
if (ucc_unlikely(status != UCC_OK)) {
return status;
}
task->super.post = ucc_tl_nccl_allgatherv_p2p_start;
*task_h = &task->super;
task->super.post = ucc_tl_nccl_allgatherv_p2p_start;
*task_h = &task->super;

return status;
}
Expand Down Expand Up @@ -144,7 +148,8 @@ ucc_status_t ucc_tl_nccl_allgatherv_bcopy_start(ucc_coll_task_t *coll_task)
}
NCCLCHECK_GOTO(ncclAllGather(sbuf, scratch, max_count * rdt_size,
ncclChar, team->nccl_comm, stream),
exit_coll, status, UCC_TL_TEAM_LIB(team));
exit_coll, status, UCC_TL_TEAM_LIB(team),
&task->nccl_progress_st, team->nccl_comm, 0);
for (peer = 0; peer < size; peer++) {
rcount = ucc_coll_args_get_count(args,
args->dst.info_v.counts, peer);
Expand Down Expand Up @@ -233,13 +238,14 @@ ucc_status_t ucc_tl_nccl_allgatherv_bcast_start(ucc_coll_task_t *coll_task)
ucc_status_t status = UCC_OK;
void *sbuf = args->src.info.buffer;
ptrdiff_t rbuf = (ptrdiff_t)args->dst.info_v.buffer;
size_t rdt_size, count, displ;
ucc_rank_t peer;
size_t rdt_size, count, displ;
ucc_rank_t peer;

task->super.status = UCC_INPROGRESS;
rdt_size = ucc_dt_size(args->dst.info_v.datatype);
UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_allgatherv_start", 0);
NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team));
NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team),
&task->nccl_progress_st, team->nccl_comm, 0);
for (peer = 0; peer < size; peer++) {
count = ucc_coll_args_get_count(args, args->dst.info_v.counts, peer);
displ = ucc_coll_args_get_displacement(args,
Expand All @@ -251,9 +257,11 @@ ucc_status_t ucc_tl_nccl_allgatherv_bcast_start(ucc_coll_task_t *coll_task)
NCCLCHECK_GOTO(ncclBroadcast(sbuf, PTR_OFFSET(rbuf, displ * rdt_size),
count * rdt_size, ncclChar, peer,
team->nccl_comm, stream),
exit_coll, status, UCC_TL_TEAM_LIB(team));
exit_coll, status, UCC_TL_TEAM_LIB(team),
&task->nccl_progress_st, team->nccl_comm, 0);
}
NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team));
NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team),
&task->nccl_progress_st, team->nccl_comm, 1);
status = ucc_tl_nccl_collective_sync(task, stream);
exit_coll:
return status;
Expand All @@ -263,7 +271,7 @@ ucc_status_t ucc_tl_nccl_allgatherv_bcast_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t * team,
ucc_coll_task_t ** task_h)
{
ucc_status_t status = UCC_OK;
ucc_status_t status = UCC_OK;
ucc_tl_nccl_task_t *task;

status = ucc_tl_nccl_init_task(coll_args, team, &task);
Expand All @@ -272,6 +280,6 @@ ucc_status_t ucc_tl_nccl_allgatherv_bcast_init(ucc_base_coll_args_t *coll_args,
}

task->super.post = ucc_tl_nccl_allgatherv_bcast_start;
*task_h = &task->super;
*task_h = &task->super;
return status;
}
6 changes: 6 additions & 0 deletions src/components/tl/nccl/tl_nccl.c
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ static ucs_config_field_t ucc_tl_nccl_context_config_table[] = {
UCS_CONFIG_TYPE_ENUM(ucc_tl_nccl_completion_sync_names)
},

{"BLOCKING", "1",
"If set to 0 will use non-blocking mode communicator behavior, "
"if set to 1 will use blocking mode",
ucs_offsetof(ucc_tl_nccl_context_config_t, nccl_cfg_blocking),
UCS_CONFIG_TYPE_BOOL},

{NULL}};

UCC_CLASS_DEFINE_NEW_FUNC(ucc_tl_nccl_lib_t, ucc_base_lib_t,
Expand Down
36 changes: 32 additions & 4 deletions src/components/tl/nccl/tl_nccl.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) Facebook, Inc. and its affiliates. 2021.
*
* See file LICENSE for terms.
Expand Down Expand Up @@ -42,6 +42,8 @@
#define UCC_TL_NCCL_PROFILE_REQUEST_NEW UCC_PROFILE_REQUEST_NEW
#define UCC_TL_NCCL_PROFILE_REQUEST_EVENT UCC_PROFILE_REQUEST_EVENT
#define UCC_TL_NCCL_PROFILE_REQUEST_FREE UCC_PROFILE_REQUEST_FREE
#define NCCL_VERSION_COMM_INIT_NB NCCL_VERSION(2,14,3)
#define NCCL_USE_NON_BLOCKING NCCL_VERSION_CODE >= NCCL_VERSION_COMM_INIT_NB

typedef struct ucc_tl_nccl_iface {
shimmybalsam marked this conversation as resolved.
Show resolved Hide resolved
ucc_tl_iface_t super;
Expand All @@ -63,6 +65,7 @@ typedef enum ucc_tl_nccl_completion_sync_type {
typedef struct ucc_tl_nccl_context_config {
ucc_tl_context_config_t super;
ucc_tl_nccl_completion_sync_type_t sync_type;
int nccl_cfg_blocking;
} ucc_tl_nccl_context_config_t;

typedef struct ucc_tl_nccl_lib {
Expand Down Expand Up @@ -92,6 +95,7 @@ typedef struct ucc_tl_nccl_team {
typedef struct ucc_tl_nccl_task {
ucc_coll_task_t super;
ucc_status_t host_status;
ucc_status_t nccl_progress_st;
ucc_status_t *dev_status;
void *completed;
union {
Expand Down Expand Up @@ -122,17 +126,41 @@ typedef struct ucc_tl_nccl_task {
UCC_CLASS_DECLARE(ucc_tl_nccl_team_t, ucc_base_context_t *,
const ucc_base_team_params_t *);

#define NCCLCHECK_GOTO(_cmd, _label, _status, _lib) \
static inline ucc_status_t ucc_tl_nccl_check_nb(ncclResult_t *nccl_status, // NOLINT
ucc_status_t *task_st, // NOLINT
ncclComm_t nccl_comm, //NOLINT
int check_nb) { //NOLINT
#if NCCL_USE_NON_BLOCKING
if (check_nb &&
(*nccl_status == ncclSuccess || *nccl_status == ncclInProgress)) {
ncclResult_t st = ncclCommGetAsyncError(nccl_comm, nccl_status);
if (st != ncclSuccess) {
return UCC_ERR_NO_MESSAGE;
}
if (ncclInProgress == *nccl_status) {
*task_st = UCC_INPROGRESS;
return UCC_INPROGRESS;
}
}
#endif
return UCC_OK;
}

#define NCCLCHECK_GOTO(_cmd, _label, _st, _lib, _task_st, _comm, _check_nb) \
shimmybalsam marked this conversation as resolved.
Show resolved Hide resolved
do { \
ncclResult_t e = _cmd; \
if (ncclSuccess != e) { \
_st = ucc_tl_nccl_check_nb(&e, _task_st, _comm, _check_nb); \
if (_st != UCC_INPROGRESS && ncclSuccess != e) { \
tl_error(_lib, "NCCL error %d %s", e, ncclGetErrorString(e)); \
_status = UCC_ERR_NO_MESSAGE; \
_st = UCC_ERR_NO_MESSAGE; \
goto _label; \
} \
} while (0)

#define UCC_TL_NCCL_TEAM_LIB(_team) \
(ucc_derived_of((_team)->super.super.context->lib, ucc_tl_nccl_lib_t))

#define UCC_TL_NCCL_TEAM_CTX(_team) \
(ucc_derived_of((_team)->super.super.context, ucc_tl_nccl_context_t))

#endif
Loading