Skip to content

Commit

Permalink
TL/NCCL: make ncclGroupEnd nb
Browse files Browse the repository at this point in the history
  • Loading branch information
shimmybalsam committed Jun 25, 2023
1 parent 1280664 commit 6239705
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 39 deletions.
24 changes: 18 additions & 6 deletions src/components/tl/nccl/allgatherv/allgatherv.c
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,13 @@ ucc_status_t ucc_tl_nccl_allgatherv_p2p_start(ucc_coll_task_t *coll_task)
exit_coll, status, UCC_TL_TEAM_LIB(team));
}
}
#if NCCL_USE_NON_BLOCKING
NCCLCHECK_INPROGRESS_GOTO(ncclGroupEnd(), exit_coll, status,
UCC_TL_TEAM_LIB(team),
task->nccl_progress_st, team->nccl_comm);
#else
NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team));
#endif
status = ucc_tl_nccl_collective_sync(task, stream);
exit_coll:
return status;
Expand All @@ -106,8 +112,8 @@ ucc_status_t ucc_tl_nccl_allgatherv_p2p_init(ucc_base_coll_args_t *coll_args,
if (ucc_unlikely(status != UCC_OK)) {
return status;
}
task->super.post = ucc_tl_nccl_allgatherv_p2p_start;
*task_h = &task->super;
task->super.post = ucc_tl_nccl_allgatherv_p2p_start;
*task_h = &task->super;

return status;
}
Expand Down Expand Up @@ -233,8 +239,8 @@ ucc_status_t ucc_tl_nccl_allgatherv_bcast_start(ucc_coll_task_t *coll_task)
ucc_status_t status = UCC_OK;
void *sbuf = args->src.info.buffer;
ptrdiff_t rbuf = (ptrdiff_t)args->dst.info_v.buffer;
size_t rdt_size, count, displ;
ucc_rank_t peer;
size_t rdt_size, count, displ;
ucc_rank_t peer;

task->super.status = UCC_INPROGRESS;
rdt_size = ucc_dt_size(args->dst.info_v.datatype);
Expand All @@ -253,7 +259,13 @@ ucc_status_t ucc_tl_nccl_allgatherv_bcast_start(ucc_coll_task_t *coll_task)
team->nccl_comm, stream),
exit_coll, status, UCC_TL_TEAM_LIB(team));
}
#if NCCL_USE_NON_BLOCKING
NCCLCHECK_INPROGRESS_GOTO(ncclGroupEnd(), exit_coll, status,
UCC_TL_TEAM_LIB(team),
task->nccl_progress_st, team->nccl_comm);
#else
NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team));
#endif
status = ucc_tl_nccl_collective_sync(task, stream);
exit_coll:
return status;
Expand All @@ -263,7 +275,7 @@ ucc_status_t ucc_tl_nccl_allgatherv_bcast_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t * team,
ucc_coll_task_t ** task_h)
{
ucc_status_t status = UCC_OK;
ucc_status_t status = UCC_OK;
ucc_tl_nccl_task_t *task;

status = ucc_tl_nccl_init_task(coll_args, team, &task);
Expand All @@ -272,6 +284,6 @@ ucc_status_t ucc_tl_nccl_allgatherv_bcast_init(ucc_base_coll_args_t *coll_args,
}

task->super.post = ucc_tl_nccl_allgatherv_bcast_start;
*task_h = &task->super;
*task_h = &task->super;
return status;
}
17 changes: 16 additions & 1 deletion src/components/tl/nccl/tl_nccl.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@
#define UCC_TL_NCCL_PROFILE_REQUEST_NEW UCC_PROFILE_REQUEST_NEW
#define UCC_TL_NCCL_PROFILE_REQUEST_EVENT UCC_PROFILE_REQUEST_EVENT
#define UCC_TL_NCCL_PROFILE_REQUEST_FREE UCC_PROFILE_REQUEST_FREE

#define NCCL_VERSION_COMM_INIT_NB NCCL_VERSION(2,14,3)
#define NCCL_USE_NON_BLOCKING NCCL_VERSION_CODE >= NCCL_VERSION_COMM_INIT_NB
typedef struct ucc_tl_nccl_iface {
ucc_tl_iface_t super;
} ucc_tl_nccl_iface_t;
Expand Down Expand Up @@ -92,6 +93,7 @@ typedef struct ucc_tl_nccl_team {
typedef struct ucc_tl_nccl_task {
ucc_coll_task_t super;
ucc_status_t host_status;
ucc_status_t nccl_progress_st;
ucc_status_t *dev_status;
void *completed;
union {
Expand Down Expand Up @@ -132,6 +134,19 @@ UCC_CLASS_DECLARE(ucc_tl_nccl_team_t, ucc_base_context_t *,
} \
} while (0)

#define NCCLCHECK_INPROGRESS_GOTO(_cmd, _label, _st, _lib, _task_st, _comm) \
do { \
ncclResult_t e = _cmd; \
ncclCommGetAsyncError(_comm, &e); \
if (ncclInProgress == e) { \
_task_st = UCC_INPROGRESS; \
} else if (ncclSuccess != e) { \
tl_error(_lib, "NCCL error %d %s", e, ncclGetErrorString(e)); \
_st = UCC_ERR_NO_MESSAGE; \
goto _label; \
} \
} while (0)

#define UCC_TL_NCCL_TEAM_LIB(_team) \
(ucc_derived_of((_team)->super.super.context->lib, ucc_tl_nccl_lib_t))

Expand Down
131 changes: 104 additions & 27 deletions src/components/tl/nccl/tl_nccl_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,8 @@ ucc_status_t ucc_tl_nccl_alltoall_start(ucc_coll_task_t *coll_task)
ucc_status_t status = UCC_OK;
ptrdiff_t sbuf = (ptrdiff_t)args->src.info.buffer;
ptrdiff_t rbuf = (ptrdiff_t)args->dst.info.buffer;
size_t data_size;
ucc_rank_t peer;
size_t data_size;
ucc_rank_t peer;

task->super.status = UCC_INPROGRESS;
data_size = (size_t)(args->src.info.count / gsize) *
Expand All @@ -279,7 +279,13 @@ ucc_status_t ucc_tl_nccl_alltoall_start(ucc_coll_task_t *coll_task)
ncclChar, peer, team->nccl_comm, stream),
exit_coll, status, UCC_TL_TEAM_LIB(team));
}
#if NCCL_USE_NON_BLOCKING
NCCLCHECK_INPROGRESS_GOTO(ncclGroupEnd(), exit_coll, status,
UCC_TL_TEAM_LIB(team),
task->nccl_progress_st, team->nccl_comm);
#else
NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team));
#endif
status = ucc_tl_nccl_collective_sync(task, stream);
exit_coll:
return status;
Expand All @@ -292,7 +298,7 @@ ucc_status_t ucc_tl_nccl_alltoall_init(ucc_tl_nccl_task_t *task)
return UCC_ERR_NOT_SUPPORTED;
}

task->super.post = ucc_tl_nccl_alltoall_start;
task->super.post = ucc_tl_nccl_alltoall_start;
return UCC_OK;
}

Expand All @@ -306,8 +312,8 @@ ucc_status_t ucc_tl_nccl_alltoallv_start(ucc_coll_task_t *coll_task)
ucc_status_t status = UCC_OK;
ptrdiff_t sbuf = (ptrdiff_t)args->src.info_v.buffer;
ptrdiff_t rbuf = (ptrdiff_t)args->dst.info_v.buffer;
size_t sdt_size, rdt_size, count, displ;
ucc_rank_t peer;
size_t sdt_size, rdt_size, count, displ;
ucc_rank_t peer;

task->super.status = UCC_INPROGRESS;
sdt_size = ucc_dt_size(args->src.info_v.datatype);
Expand All @@ -334,7 +340,13 @@ ucc_status_t ucc_tl_nccl_alltoallv_start(ucc_coll_task_t *coll_task)
exit_coll, status, UCC_TL_TEAM_LIB(team));
}
}
#if NCCL_USE_NON_BLOCKING
NCCLCHECK_INPROGRESS_GOTO(ncclGroupEnd(), exit_coll, status,
UCC_TL_TEAM_LIB(team),
task->nccl_progress_st, team->nccl_comm);
#else
NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team));
#endif
status = ucc_tl_nccl_collective_sync(task, stream);
exit_coll:
return status;
Expand All @@ -347,7 +359,7 @@ ucc_status_t ucc_tl_nccl_alltoallv_init(ucc_tl_nccl_task_t *task)
return UCC_ERR_NOT_SUPPORTED;
}

task->super.post = ucc_tl_nccl_alltoallv_start;
task->super.post = ucc_tl_nccl_alltoallv_start;
return UCC_OK;
}

Expand Down Expand Up @@ -397,7 +409,7 @@ ucc_status_t ucc_tl_nccl_allreduce_init(ucc_tl_nccl_task_t *task)
return UCC_ERR_NOT_SUPPORTED;
}

task->super.post = ucc_tl_nccl_allreduce_start;
task->super.post = ucc_tl_nccl_allreduce_start;
return UCC_OK;
}

Expand Down Expand Up @@ -493,12 +505,25 @@ ucc_status_t ucc_tl_nccl_bcast_start(ucc_coll_task_t *coll_task)
team->nccl_comm, stream),
exit_coll, status, UCC_TL_TEAM_LIB(team));
}
NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status,
UCC_TL_TEAM_LIB(team));
#if NCCL_USE_NON_BLOCKING
NCCLCHECK_INPROGRESS_GOTO(ncclGroupEnd(), exit_coll, status,
UCC_TL_TEAM_LIB(team),
task->nccl_progress_st, team->nccl_comm);
#else
NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team));
#endif
} else {
#if NCCL_USE_NON_BLOCKING
NCCLCHECK_INPROGRESS_GOTO(ncclRecv(src, count, dt, root,
team->nccl_comm, stream),
exit_coll, status,
UCC_TL_TEAM_LIB(team),
task->nccl_progress_st, team->nccl_comm);
#else
NCCLCHECK_GOTO(ncclRecv(src, count, dt, root,
team->nccl_comm, stream),
exit_coll, status, UCC_TL_TEAM_LIB(team));
#endif
}
} else {
NCCLCHECK_GOTO(ncclBroadcast(src, src, count, dt, root, team->nccl_comm,
Expand Down Expand Up @@ -658,8 +683,8 @@ ucc_status_t ucc_tl_nccl_gather_start(ucc_coll_task_t *coll_task)
void *dst = args->dst.info.buffer;
void *src = args->src.info.buffer;
ucc_status_t status = UCC_OK;
size_t send_size;
ucc_rank_t peer;
size_t send_size;
ucc_rank_t peer;

if (rank == args->root) {
send_size = ucc_dt_size(args->dst.info.datatype) *
Expand Down Expand Up @@ -687,12 +712,25 @@ ucc_status_t ucc_tl_nccl_gather_start(ucc_coll_task_t *coll_task)
stream),
exit_coll, status, UCC_TL_TEAM_LIB(team));
}
NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status,
UCC_TL_TEAM_LIB(team));
#if NCCL_USE_NON_BLOCKING
NCCLCHECK_INPROGRESS_GOTO(ncclGroupEnd(), exit_coll, status,
UCC_TL_TEAM_LIB(team),
task->nccl_progress_st, team->nccl_comm);
#else
NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team));
#endif
} else {
#if NCCL_USE_NON_BLOCKING
NCCLCHECK_INPROGRESS_GOTO(ncclSend(src, send_size, ncclChar,
args->root, team->nccl_comm,
stream),
exit_coll, status, UCC_TL_TEAM_LIB(team),
task->nccl_progress_st, team->nccl_comm);
#else
NCCLCHECK_GOTO(ncclSend(src, send_size, ncclChar, args->root,
team->nccl_comm, stream),
exit_coll, status, UCC_TL_TEAM_LIB(team));
#endif
}
task->super.status = UCC_INPROGRESS;
status = ucc_tl_nccl_collective_sync(task, stream);
Expand All @@ -702,8 +740,8 @@ ucc_status_t ucc_tl_nccl_gather_start(ucc_coll_task_t *coll_task)

ucc_status_t ucc_tl_nccl_gather_init(ucc_tl_nccl_task_t *task)
{
task->super.post = ucc_tl_nccl_gather_start;
return UCC_OK;
task->super.post = ucc_tl_nccl_gather_start;
return UCC_OK;
}

ucc_status_t ucc_tl_nccl_gatherv_start(ucc_coll_task_t *coll_task)
Expand All @@ -718,8 +756,8 @@ ucc_status_t ucc_tl_nccl_gatherv_start(ucc_coll_task_t *coll_task)
void *dst = args->dst.info_v.buffer;
void *src = args->src.info.buffer;
ucc_status_t status = UCC_OK;
size_t count, displ, dt_size;
ucc_rank_t peer;
size_t count, displ, dt_size;
ucc_rank_t peer;

if (rank == args->root) {
dt_size = ucc_dt_size(args->dst.info_v.datatype);
Expand Down Expand Up @@ -754,13 +792,26 @@ ucc_status_t ucc_tl_nccl_gatherv_start(ucc_coll_task_t *coll_task)
peer,team->nccl_comm, stream),
exit_coll, status, UCC_TL_TEAM_LIB(team));
}
NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status,
UCC_TL_TEAM_LIB(team));
#if NCCL_USE_NON_BLOCKING
NCCLCHECK_INPROGRESS_GOTO(ncclGroupEnd(), exit_coll, status,
UCC_TL_TEAM_LIB(team),
task->nccl_progress_st, team->nccl_comm);
#else
NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team));
#endif
} else {
#if NCCL_USE_NON_BLOCKING
NCCLCHECK_INPROGRESS_GOTO(ncclSend(src, args->src.info.count * dt_size,
ncclChar, args->root,
team->nccl_comm, stream),
exit_coll, status, UCC_TL_TEAM_LIB(team),
task->nccl_progress_st, team->nccl_comm);
#else
NCCLCHECK_GOTO(ncclSend(src, args->src.info.count * dt_size,
ncclChar, args->root, team->nccl_comm,
stream),
exit_coll, status, UCC_TL_TEAM_LIB(team));
#endif
}
task->super.status = UCC_INPROGRESS;
status = ucc_tl_nccl_collective_sync(task, stream);
Expand All @@ -786,8 +837,8 @@ ucc_status_t ucc_tl_nccl_scatter_start(ucc_coll_task_t *coll_task)
void *dst = args->dst.info.buffer;
void *src = args->src.info.buffer;
ucc_status_t status = UCC_OK;
size_t send_size;
ucc_rank_t peer;
size_t send_size;
ucc_rank_t peer;

if (rank == args->root) {
send_size = ucc_dt_size(args->src.info.datatype) *
Expand Down Expand Up @@ -816,12 +867,25 @@ ucc_status_t ucc_tl_nccl_scatter_start(ucc_coll_task_t *coll_task)
stream),
exit_coll, status, UCC_TL_TEAM_LIB(team));
}
NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status,
UCC_TL_TEAM_LIB(team));
#if NCCL_USE_NON_BLOCKING
NCCLCHECK_INPROGRESS_GOTO(ncclGroupEnd(), exit_coll, status,
UCC_TL_TEAM_LIB(team),
task->nccl_progress_st, team->nccl_comm);
#else
NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team));
#endif
} else {
#if NCCL_USE_NON_BLOCKING
NCCLCHECK_INPROGRESS_GOTO(ncclRecv(dst, send_size, ncclChar,
args->root, team->nccl_comm,
stream),
exit_coll, status, UCC_TL_TEAM_LIB(team),
task->nccl_progress_st, team->nccl_comm);
#else
NCCLCHECK_GOTO(ncclRecv(dst, send_size, ncclChar, args->root,
team->nccl_comm, stream),
exit_coll, status, UCC_TL_TEAM_LIB(team));
#endif
}
task->super.status = UCC_INPROGRESS;
status = ucc_tl_nccl_collective_sync(task, stream);
Expand All @@ -847,8 +911,8 @@ ucc_status_t ucc_tl_nccl_scatterv_start(ucc_coll_task_t *coll_task)
void *dst = args->dst.info.buffer;
void *src = args->src.info_v.buffer;
ucc_status_t status = UCC_OK;
size_t count, displ, dt_size;
ucc_rank_t peer;
size_t count, displ, dt_size;
ucc_rank_t peer;

if (rank == args->root) {
dt_size = ucc_dt_size(args->src.info_v.datatype);
Expand Down Expand Up @@ -885,12 +949,25 @@ ucc_status_t ucc_tl_nccl_scatterv_start(ucc_coll_task_t *coll_task)
team->nccl_comm, stream),
exit_coll, status, UCC_TL_TEAM_LIB(team));
}
NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status,
UCC_TL_TEAM_LIB(team));
#if NCCL_USE_NON_BLOCKING
NCCLCHECK_INPROGRESS_GOTO(ncclGroupEnd(), exit_coll, status,
UCC_TL_TEAM_LIB(team),
task->nccl_progress_st, team->nccl_comm);
#else
NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team));
#endif
} else {
#if NCCL_USE_NON_BLOCKING
NCCLCHECK_INPROGRESS_GOTO(ncclRecv(dst, args->dst.info.count * dt_size,
ncclChar, args->root,
team->nccl_comm, stream),
exit_coll, status, UCC_TL_TEAM_LIB(team),
task->nccl_progress_st, team->nccl_comm);
#else
NCCLCHECK_GOTO(ncclRecv(dst, args->dst.info.count * dt_size, ncclChar,
args->root, team->nccl_comm, stream),
exit_coll, status, UCC_TL_TEAM_LIB(team));
#endif
}
task->super.status = UCC_INPROGRESS;
status = ucc_tl_nccl_collective_sync(task, stream);
Expand Down
Loading

0 comments on commit 6239705

Please sign in to comment.