Skip to content

Commit

Permalink
TL/NCCL: make team init non-blocking
Browse files: browse the repository at this point in the history
  • Loading branch information
shimmybalsam committed May 9, 2023
1 parent a036a5f commit fc6ca00
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 9 deletions.
1 change: 1 addition & 0 deletions src/components/tl/nccl/tl_nccl.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ typedef struct ucc_tl_nccl_team {
ncclUniqueId *unique_id;
void *oob_req;
ncclComm_t nccl_comm;
int nccl_nb_state;
cudaStream_t stream;
} ucc_tl_nccl_team_t;

Expand Down
38 changes: 29 additions & 9 deletions src/components/tl/nccl/tl_nccl_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ UCC_CLASS_INIT_FUNC(ucc_tl_nccl_team_t, ucc_base_context_t *tl_context,
UCC_CLASS_CALL_SUPER_INIT(ucc_tl_team_t, &ctx->super, params);

size = UCC_TL_TEAM_SIZE(self);
self->comm_state = UCC_OK;
self->unique_id = ucc_malloc(sizeof(ncclUniqueId) * (size + 1),
"tl_nccl_unique_id");
self->comm_state = UCC_OK;
self->nccl_nb_state = 0;
self->unique_id = ucc_malloc(sizeof(ncclUniqueId) * (size + 1),
"tl_nccl_unique_id");
if (!self->unique_id) {
tl_error(ctx->super.super.lib,
"failed to allocate %zd bytes for unique_id array",
Expand Down Expand Up @@ -80,11 +81,16 @@ ucc_status_t ucc_tl_nccl_team_destroy(ucc_base_team_t *tl_team)

ucc_status_t ucc_tl_nccl_team_create_test(ucc_base_team_t *tl_team)
{
ucc_tl_nccl_team_t *team = ucc_derived_of(tl_team, ucc_tl_nccl_team_t);
ucc_tl_nccl_team_t *team = ucc_derived_of(tl_team, ucc_tl_nccl_team_t);
ncclConfig_t nccl_cfg = NCCL_CONFIG_INITIALIZER;
ucc_status_t status;
ncclResult_t nccl_status;
ncclUniqueId errorid;

if (team->nccl_nb_state) {
goto ncclInitStage;
}

status = UCC_TL_TEAM_OOB(team).req_test(team->oob_req);
if (status == UCC_INPROGRESS) {
return UCC_INPROGRESS;
Expand All @@ -108,19 +114,33 @@ ucc_status_t ucc_tl_nccl_team_create_test(ucc_base_team_t *tl_team)

CUDA_CHECK_GOTO(cudaStreamCreateWithFlags(&team->stream,
cudaStreamNonBlocking), free_unique_id, status);
nccl_status = ncclCommInitRank(&team->nccl_comm, UCC_TL_TEAM_SIZE(team),
team->unique_id[0], UCC_TL_TEAM_RANK(team));

nccl_cfg.blocking = 0;
nccl_status = ncclCommInitRankConfig(&team->nccl_comm,
UCC_TL_TEAM_SIZE(team),
team->unique_id[0],
UCC_TL_TEAM_RANK(team),
&nccl_cfg);
if (nccl_status != ncclInProgress && nccl_status != ncclSuccess) {
goto free_stream;
}
ncclInitStage:
ncclCommGetAsyncError(team->nccl_comm, &nccl_status);
if (nccl_status == ncclInProgress){
team->nccl_nb_state = 1;
return UCC_INPROGRESS;
}
if (nccl_status != ncclSuccess) {
tl_debug(tl_team->context->lib, "NCCL error %d %s",
nccl_status, ncclGetErrorString(nccl_status));
status = UCC_ERR_NO_MESSAGE;
goto free_stream;
}
ucc_free(team->unique_id);
tl_debug(tl_team->context->lib, "initialized tl team: %p", team);
return UCC_OK;

free_stream:
tl_debug(tl_team->context->lib, "NCCL error %d %s", nccl_status,
ncclGetErrorString(nccl_status));
status = UCC_ERR_NO_MESSAGE;
cudaStreamDestroy(team->stream);
free_unique_id:
ucc_free(team->unique_id);
Expand Down

0 comments on commit fc6ca00

Please sign in to comment.