Skip to content

Commit

Permalink
REVIEW: review fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
shimmybalsam committed Jun 27, 2023
1 parent 90bbc66 commit 7184609
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 174 deletions.
38 changes: 17 additions & 21 deletions src/components/tl/nccl/allgatherv/allgatherv.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down Expand Up @@ -70,12 +70,14 @@ ucc_status_t ucc_tl_nccl_allgatherv_p2p_start(ucc_coll_task_t *coll_task)

task->super.status = UCC_INPROGRESS;
UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_allgatherv_start", 0);
NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team));
NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team),
&task->nccl_progress_st, team->nccl_comm, 0);
if (count != 0) {
for (peer = 0; peer < size; peer++) {
NCCLCHECK_GOTO(ncclSend(sbuf, count * sdt_size, ncclChar, peer,
team->nccl_comm, stream),
exit_coll, status, UCC_TL_TEAM_LIB(team));
exit_coll, status, UCC_TL_TEAM_LIB(team),
&task->nccl_progress_st, team->nccl_comm, 0);
}
}
for (peer = 0; peer < size; peer++) {
Expand All @@ -86,16 +88,12 @@ ucc_status_t ucc_tl_nccl_allgatherv_p2p_start(ucc_coll_task_t *coll_task)
NCCLCHECK_GOTO(ncclRecv(PTR_OFFSET(rbuf, displ * rdt_size),
count * rdt_size, ncclChar, peer,
team->nccl_comm, stream),
exit_coll, status, UCC_TL_TEAM_LIB(team));
exit_coll, status, UCC_TL_TEAM_LIB(team),
&task->nccl_progress_st, team->nccl_comm, 0);
}
}
#if NCCL_USE_NON_BLOCKING
NCCLCHECK_INPROGRESS_GOTO(ncclGroupEnd(), exit_coll, status,
UCC_TL_TEAM_LIB(team),
task->nccl_progress_st, team->nccl_comm);
#else
NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team));
#endif
NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team),
&task->nccl_progress_st, team->nccl_comm, 1);
status = ucc_tl_nccl_collective_sync(task, stream);
exit_coll:
return status;
Expand Down Expand Up @@ -150,7 +148,8 @@ ucc_status_t ucc_tl_nccl_allgatherv_bcopy_start(ucc_coll_task_t *coll_task)
}
NCCLCHECK_GOTO(ncclAllGather(sbuf, scratch, max_count * rdt_size,
ncclChar, team->nccl_comm, stream),
exit_coll, status, UCC_TL_TEAM_LIB(team));
exit_coll, status, UCC_TL_TEAM_LIB(team),
&task->nccl_progress_st, team->nccl_comm, 0);
for (peer = 0; peer < size; peer++) {
rcount = ucc_coll_args_get_count(args,
args->dst.info_v.counts, peer);
Expand Down Expand Up @@ -245,7 +244,8 @@ ucc_status_t ucc_tl_nccl_allgatherv_bcast_start(ucc_coll_task_t *coll_task)
task->super.status = UCC_INPROGRESS;
rdt_size = ucc_dt_size(args->dst.info_v.datatype);
UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_allgatherv_start", 0);
NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team));
NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team),
&task->nccl_progress_st, team->nccl_comm, 0);
for (peer = 0; peer < size; peer++) {
count = ucc_coll_args_get_count(args, args->dst.info_v.counts, peer);
displ = ucc_coll_args_get_displacement(args,
Expand All @@ -257,15 +257,11 @@ ucc_status_t ucc_tl_nccl_allgatherv_bcast_start(ucc_coll_task_t *coll_task)
NCCLCHECK_GOTO(ncclBroadcast(sbuf, PTR_OFFSET(rbuf, displ * rdt_size),
count * rdt_size, ncclChar, peer,
team->nccl_comm, stream),
exit_coll, status, UCC_TL_TEAM_LIB(team));
exit_coll, status, UCC_TL_TEAM_LIB(team),
&task->nccl_progress_st, team->nccl_comm, 0);
}
#if NCCL_USE_NON_BLOCKING
NCCLCHECK_INPROGRESS_GOTO(ncclGroupEnd(), exit_coll, status,
UCC_TL_TEAM_LIB(team),
task->nccl_progress_st, team->nccl_comm);
#else
NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team));
#endif
NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team),
&task->nccl_progress_st, team->nccl_comm, 1);
status = ucc_tl_nccl_collective_sync(task, stream);
exit_coll:
return status;
Expand Down
39 changes: 24 additions & 15 deletions src/components/tl/nccl/tl_nccl.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) Facebook, Inc. and its affiliates. 2021.
*
* See file LICENSE for terms.
Expand Down Expand Up @@ -44,6 +44,7 @@
#define UCC_TL_NCCL_PROFILE_REQUEST_FREE UCC_PROFILE_REQUEST_FREE
#define NCCL_VERSION_COMM_INIT_NB NCCL_VERSION(2,14,3)
#define NCCL_USE_NON_BLOCKING NCCL_VERSION_CODE >= NCCL_VERSION_COMM_INIT_NB

typedef struct ucc_tl_nccl_iface {
ucc_tl_iface_t super;
} ucc_tl_nccl_iface_t;
Expand Down Expand Up @@ -124,23 +125,31 @@ typedef struct ucc_tl_nccl_task {
UCC_CLASS_DECLARE(ucc_tl_nccl_team_t, ucc_base_context_t *,
const ucc_base_team_params_t *);

#define NCCLCHECK_GOTO(_cmd, _label, _status, _lib) \
do { \
ncclResult_t e = _cmd; \
if (ncclSuccess != e) { \
tl_error(_lib, "NCCL error %d %s", e, ncclGetErrorString(e)); \
_status = UCC_ERR_NO_MESSAGE; \
goto _label; \
} \
} while (0)
static inline ucc_status_t ucc_tl_nccl_check_nb(ncclResult_t *nccl_status, // NOLINT
ucc_status_t *task_st, // NOLINT
ncclComm_t nccl_comm, //NOLINT
int check_nb) { //NOLINT
#if NCCL_USE_NON_BLOCKING
if (check_nb &&
(*nccl_status == ncclSuccess || *nccl_status == ncclInProgress)) {
ncclResult_t st = ncclCommGetAsyncError(nccl_comm, nccl_status);
if (st != ncclSuccess) {
return UCC_ERR_NO_MESSAGE;
}
if (ncclInProgress == *nccl_status) {
*task_st = UCC_INPROGRESS;
return UCC_INPROGRESS;
}
}
#endif
return UCC_OK;
}

#define NCCLCHECK_INPROGRESS_GOTO(_cmd, _label, _st, _lib, _task_st, _comm) \
#define NCCLCHECK_GOTO(_cmd, _label, _st, _lib, _task_st, _comm, _check_nb) \
do { \
ncclResult_t e = _cmd; \
ncclCommGetAsyncError(_comm, &e); \
if (ncclInProgress == e) { \
_task_st = UCC_INPROGRESS; \
} else if (ncclSuccess != e) { \
_st = ucc_tl_nccl_check_nb(&e, _task_st, _comm, _check_nb); \
if (_st != UCC_INPROGRESS && ncclSuccess != e) { \
tl_error(_lib, "NCCL error %d %s", e, ncclGetErrorString(e)); \
_st = UCC_ERR_NO_MESSAGE; \
goto _label; \
Expand Down
Loading

0 comments on commit 7184609

Please sign in to comment.