Skip to content

Commit

Permalink
REVIEW: code review fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
shimmybalsam committed May 24, 2023
1 parent 3d946bd commit 6e98b6e
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 26 deletions.
20 changes: 13 additions & 7 deletions src/components/tl/nccl/tl_nccl.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ typedef enum ucc_tl_nccl_completion_sync_type {
UCC_TL_NCCL_COMPLETION_SYNC_TYPE_LAST
} ucc_tl_nccl_completion_sync_type_t;

typedef enum ucc_tl_nccl_nb_state {
UCC_TL_NCCL_NB_UNUSED,
UCC_TL_NCCL_NB_INIT_IN_PROGRESS,
UCC_TL_NCCL_NB_FINALIZE_IN_PROGRESS
} ucc_tl_nccl_nb_state_t;

typedef struct ucc_tl_nccl_context_config {
ucc_tl_context_config_t super;
ucc_tl_nccl_completion_sync_type_t sync_type;
Expand All @@ -81,13 +87,13 @@ UCC_CLASS_DECLARE(ucc_tl_nccl_context_t, const ucc_base_context_params_t *,
const ucc_base_config_t *);

typedef struct ucc_tl_nccl_team {
ucc_tl_team_t super;
ucc_status_t comm_state;
ncclUniqueId *unique_id;
void *oob_req;
int nccl_nb_state;
ncclComm_t nccl_comm;
cudaStream_t stream;
ucc_tl_team_t super;
ucc_status_t comm_state;
ncclUniqueId *unique_id;
void *oob_req;
ucc_tl_nccl_nb_state_t nccl_nb_state;
ncclComm_t nccl_comm;
cudaStream_t stream;
} ucc_tl_nccl_team_t;

typedef struct ucc_tl_nccl_task {
Expand Down
33 changes: 14 additions & 19 deletions src/components/tl/nccl/tl_nccl_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,8 @@
#include "coll_score/ucc_coll_score.h"
#include "utils/arch/cuda_def.h"

#define NCCL_VERSION_COMM_INIT_NONBLOCKING NCCL_VERSION(2,14,3)

enum {
NCCL_NB_UNUSED,
NCCL_NB_INIT_IN_PROGRESS,
NCCL_NB_FINALIZE_IN_PROGRESS
};
#define NCCL_VERSION_COMM_INIT_NB NCCL_VERSION(2,14,3)
#define NCCL_USE_NON_BLOCKING NCCL_VERSION_CODE >= NCCL_VERSION_COMM_INIT_NB

UCC_CLASS_INIT_FUNC(ucc_tl_nccl_team_t, ucc_base_context_t *tl_context,
const ucc_base_team_params_t *params)
Expand All @@ -31,9 +26,9 @@ UCC_CLASS_INIT_FUNC(ucc_tl_nccl_team_t, ucc_base_context_t *tl_context,

size = UCC_TL_TEAM_SIZE(self);
self->comm_state = UCC_OK;
self->nccl_nb_state = NCCL_NB_UNUSED;
self->nccl_nb_state = UCC_TL_NCCL_NB_UNUSED;
self->unique_id = ucc_malloc(sizeof(ncclUniqueId) * (size + 1),
"tl_nccl_unique_id");
"tl_nccl_unique_id");
if (!self->unique_id) {
tl_error(ctx->super.super.lib,
"failed to allocate %zd bytes for unique_id array",
Expand Down Expand Up @@ -75,10 +70,10 @@ ucc_status_t ucc_tl_nccl_team_destroy(ucc_base_team_t *tl_team)
{
ucc_tl_nccl_team_t *team = ucc_derived_of(tl_team, ucc_tl_nccl_team_t);

#if NCCL_VERSION_CODE >= NCCL_VERSION_COMM_INIT_NONBLOCKING
#if NCCL_USE_NON_BLOCKING
ncclResult_t nccl_status;

if (team->nccl_nb_state == NCCL_NB_FINALIZE_IN_PROGRESS) {
if (team->nccl_nb_state == UCC_TL_NCCL_NB_FINALIZE_IN_PROGRESS) {
goto check_finalize;
}
#endif
Expand All @@ -89,12 +84,12 @@ ucc_status_t ucc_tl_nccl_team_destroy(ucc_base_team_t *tl_team)
since ncclCommDestroy could block */
ncclCommAbort(team->nccl_comm);
} else {
#if NCCL_VERSION_CODE >= NCCL_VERSION_COMM_INIT_NONBLOCKING
#if NCCL_USE_NON_BLOCKING
ncclCommFinalize(team->nccl_comm);
check_finalize:
ncclCommGetAsyncError(team->nccl_comm, &nccl_status);
if (nccl_status == ncclInProgress) {
team->nccl_nb_state = NCCL_NB_FINALIZE_IN_PROGRESS;
team->nccl_nb_state = UCC_TL_NCCL_NB_FINALIZE_IN_PROGRESS;
return UCC_INPROGRESS;
}
if (nccl_status != ncclSuccess) {
Expand All @@ -105,7 +100,7 @@ ucc_status_t ucc_tl_nccl_team_destroy(ucc_base_team_t *tl_team)
} else {
ncclCommDestroy(team->nccl_comm);
}
team->nccl_nb_state = NCCL_NB_UNUSED;
team->nccl_nb_state = UCC_TL_NCCL_NB_UNUSED;
#else
ncclCommDestroy(team->nccl_comm);
#endif
Expand All @@ -124,10 +119,10 @@ ucc_status_t ucc_tl_nccl_team_create_test(ucc_base_team_t *tl_team)
ncclResult_t nccl_status;
ncclUniqueId errorid;

#if NCCL_VERSION_CODE >= NCCL_VERSION_COMM_INIT_NONBLOCKING
#if NCCL_USE_NON_BLOCKING
ncclConfig_t nccl_cfg = NCCL_CONFIG_INITIALIZER;

if (team->nccl_nb_state == NCCL_NB_INIT_IN_PROGRESS) {
if (team->nccl_nb_state == UCC_TL_NCCL_NB_INIT_IN_PROGRESS) {
goto ncclInitStage;
}
#endif
Expand Down Expand Up @@ -155,7 +150,7 @@ ucc_status_t ucc_tl_nccl_team_create_test(ucc_base_team_t *tl_team)

CUDA_CHECK_GOTO(cudaStreamCreateWithFlags(&team->stream,
cudaStreamNonBlocking), free_unique_id, status);
#if NCCL_VERSION_CODE >= NCCL_VERSION_COMM_INIT_NONBLOCKING
#if NCCL_USE_NON_BLOCKING
nccl_cfg.blocking = 0;
nccl_status = ncclCommInitRankConfig(&team->nccl_comm,
UCC_TL_TEAM_SIZE(team),
Expand All @@ -168,7 +163,7 @@ ucc_status_t ucc_tl_nccl_team_create_test(ucc_base_team_t *tl_team)
ncclInitStage:
ncclCommGetAsyncError(team->nccl_comm, &nccl_status);
if (nccl_status == ncclInProgress){
team->nccl_nb_state = NCCL_NB_INIT_IN_PROGRESS;
team->nccl_nb_state = UCC_TL_NCCL_NB_INIT_IN_PROGRESS;
return UCC_INPROGRESS;
}
#else
Expand All @@ -186,7 +181,7 @@ ucc_status_t ucc_tl_nccl_team_create_test(ucc_base_team_t *tl_team)
tl_debug(tl_team->context->lib, "NCCL error %d %s", nccl_status,
ncclGetErrorString(nccl_status));
status = UCC_ERR_NO_MESSAGE;
#if NCCL_VERSION_CODE >= NCCL_VERSION_COMM_INIT_NONBLOCKING
#if NCCL_USE_NON_BLOCKING
ncclCommAbort(team->nccl_comm);
#endif
cudaStreamDestroy(team->stream);
Expand Down

0 comments on commit 6e98b6e

Please sign in to comment.