Skip to content

Commit

Permalink
TL/NCCL: lazy init nccl comm
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergei-Lebedev committed Oct 12, 2023
1 parent 63782b3 commit 016c3db
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 80 deletions.
11 changes: 8 additions & 3 deletions src/components/tl/nccl/tl_nccl.c
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,17 @@ static ucs_config_field_t ucc_tl_nccl_context_config_table[] = {
UCS_CONFIG_TYPE_ENUM(ucc_tl_nccl_completion_sync_names)
},

{"BLOCKING", "1",
"If set to 0 will use non-blocking mode communicator behavior, "
"if set to 1 will use blocking mode",
{"BLOCKING", "yes",
"If set to yes will use non-blocking mode communicator behavior, "
"if set to no will use blocking mode",
ucs_offsetof(ucc_tl_nccl_context_config_t, nccl_cfg_blocking),
UCS_CONFIG_TYPE_BOOL},

{"LAZY_INIT", "yes",
"Initialize NCCL communicator on first collective",
ucc_offsetof(ucc_tl_nccl_context_config_t, nccl_lazy_init),
UCC_CONFIG_TYPE_BOOL},

{NULL}};

UCC_CLASS_DEFINE_NEW_FUNC(ucc_tl_nccl_lib_t, ucc_base_lib_t,
Expand Down
14 changes: 13 additions & 1 deletion src/components/tl/nccl/tl_nccl.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@
#define NCCL_VERSION_COMM_INIT_NB NCCL_VERSION(2,14,3)
#define NCCL_USE_NON_BLOCKING NCCL_VERSION_CODE >= NCCL_VERSION_COMM_INIT_NB

enum {
TL_NCCL_COMM_STATE_ERROR,
TL_NCCL_COMM_STATE_OOB,
TL_NCCL_COMM_STATE_INIT_TEAM,
TL_NCCL_COMM_STATE_INIT_COMM,
TL_NCCL_COMM_STATE_DESTROY_COMM,
TL_NCCL_COMM_STATE_READY,
};

typedef struct ucc_tl_nccl_iface {
ucc_tl_iface_t super;
} ucc_tl_nccl_iface_t;
Expand All @@ -66,6 +75,7 @@ typedef struct ucc_tl_nccl_context_config {
ucc_tl_context_config_t super;
ucc_tl_nccl_completion_sync_type_t sync_type;
int nccl_cfg_blocking;
int nccl_lazy_init;
} ucc_tl_nccl_context_config_t;

typedef struct ucc_tl_nccl_lib {
Expand All @@ -85,7 +95,7 @@ UCC_CLASS_DECLARE(ucc_tl_nccl_context_t, const ucc_base_context_params_t *,

typedef struct ucc_tl_nccl_team {
ucc_tl_team_t super;
ucc_status_t comm_state;
int comm_state;
ncclUniqueId *unique_id;
void *oob_req;
ncclComm_t nccl_comm;
Expand Down Expand Up @@ -146,6 +156,8 @@ static inline ucc_status_t ucc_tl_nccl_check_nb(ncclResult_t *nccl_status, // NO
return UCC_OK;
}

ucc_status_t ucc_tl_nccl_comm_init(ucc_tl_nccl_team_t *team);

#define NCCLCHECK_GOTO(_cmd, _label, _st, _lib, _task_st, _comm, _check_nb) \
do { \
ncclResult_t e = _cmd; \
Expand Down
8 changes: 7 additions & 1 deletion src/components/tl/nccl/tl_nccl_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ ucc_status_t ucc_tl_nccl_init_task(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_tl_nccl_task_t **coll_task)
{
ucc_tl_nccl_team_t *nccl_team = ucc_derived_of(team, ucc_tl_nccl_team_t);
ucc_tl_nccl_context_t *nccl_ctx = ucc_derived_of(team->context,
ucc_tl_nccl_context_t);
ucc_tl_nccl_task_t *task;
Expand All @@ -143,6 +144,11 @@ ucc_status_t ucc_tl_nccl_init_task(ucc_base_coll_args_t *coll_args,
return UCC_ERR_NOT_SUPPORTED;
}

status = ucc_tl_nccl_comm_init(nccl_team);
if (ucc_unlikely(status != UCC_OK)) {
return status;
}

task = ucc_mpool_get(&nccl_ctx->req_mp);
if (ucc_unlikely(!task)) {
tl_error(team->context->lib, "failed to get task from mpool");
Expand Down Expand Up @@ -206,7 +212,7 @@ ucc_status_t ucc_tl_nccl_coll_finalize(ucc_coll_task_t *coll_task)
ucc_status_t status = UCC_OK;

if (ucc_unlikely(task->super.super.status != UCC_OK)) {
team->comm_state = task->super.super.status;
team->comm_state = TL_NCCL_COMM_STATE_ERROR;
}
tl_debug(UCC_TASK_LIB(task), "finalizing coll task %p", task);
ucc_tl_nccl_free_task(task);
Expand Down
187 changes: 112 additions & 75 deletions src/components/tl/nccl/tl_nccl_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,17 @@
UCC_CLASS_INIT_FUNC(ucc_tl_nccl_team_t, ucc_base_context_t *tl_context,
const ucc_base_team_params_t *params)
{
ucc_tl_nccl_context_t *ctx =
ucc_derived_of(tl_context, ucc_tl_nccl_context_t);
ucc_tl_nccl_context_t *ctx = ucc_derived_of(tl_context,
ucc_tl_nccl_context_t);
ucc_team_oob_coll_t *oob;
ucc_status_t status;
ucc_rank_t size;
UCC_CLASS_CALL_SUPER_INIT(ucc_tl_team_t, &ctx->super, params);

UCC_CLASS_CALL_SUPER_INIT(ucc_tl_team_t, &ctx->super, params);
oob = &(UCC_TL_TEAM_OOB(self));
size = UCC_TL_TEAM_SIZE(self);
self->comm_state = UCC_OK;
self->stream = NULL;
self->nccl_comm = NULL;
self->unique_id = ucc_malloc(sizeof(ncclUniqueId) * (size + 1),
"tl_nccl_unique_id");
if (!self->unique_id) {
Expand All @@ -31,6 +34,7 @@ UCC_CLASS_INIT_FUNC(ucc_tl_nccl_team_t, ucc_base_context_t *tl_context,
sizeof(ncclUniqueId) * (size + 1));
return UCC_ERR_NO_MEMORY;
}

if (UCC_TL_TEAM_RANK(self) == 0) {
ncclResult_t st;
st = ncclGetUniqueId(&self->unique_id[size]);
Expand All @@ -39,14 +43,16 @@ UCC_CLASS_INIT_FUNC(ucc_tl_nccl_team_t, ucc_base_context_t *tl_context,
memset(&self->unique_id[size], 0, sizeof(ncclUniqueId));
}
}
status = UCC_TL_TEAM_OOB(self).allgather(
&self->unique_id[size], self->unique_id,
sizeof(ncclUniqueId), UCC_TL_TEAM_OOB(self).coll_info,
&self->oob_req);

status = oob->allgather(&self->unique_id[size],
self->unique_id, sizeof(ncclUniqueId),
oob->coll_info, &self->oob_req);
if (status != UCC_OK) {
tl_error(ctx->super.super.lib, "failed to start oob allgather");
goto free_unique_id;
}
self->comm_state = TL_NCCL_COMM_STATE_OOB;

return UCC_OK;

free_unique_id:
Expand All @@ -69,15 +75,17 @@ ucc_status_t ucc_tl_nccl_team_destroy(ucc_base_team_t *tl_team)
#if NCCL_USE_NON_BLOCKING
ncclResult_t nccl_status, st;

if (team->nccl_comm && team->comm_state == UCC_INPROGRESS) {
if (team->comm_state == TL_NCCL_COMM_STATE_DESTROY_COMM) {
goto check_finalize;
}
#endif

if (team->stream) {
cudaStreamDestroy(team->stream);
team->stream = NULL;
}
if (team->nccl_comm) {
if (team->comm_state != UCC_OK && team->comm_state != UCC_INPROGRESS) {
/* if communication error was detected ncclCommAbort should be used
since ncclCommDestroy could block */
if (team->comm_state == TL_NCCL_COMM_STATE_ERROR) {
ncclCommAbort(team->nccl_comm);
} else {
#if NCCL_USE_NON_BLOCKING
Expand All @@ -91,7 +99,7 @@ ucc_status_t ucc_tl_nccl_team_destroy(ucc_base_team_t *tl_team)
ncclCommAbort(team->nccl_comm);
return UCC_ERR_NO_MESSAGE;
} else if (nccl_status == ncclInProgress) {
team->comm_state = UCC_INPROGRESS;
team->comm_state = TL_NCCL_COMM_STATE_DESTROY_COMM;
return UCC_INPROGRESS;
} else {
ncclCommDestroy(team->nccl_comm);
Expand All @@ -101,95 +109,124 @@ ucc_status_t ucc_tl_nccl_team_destroy(ucc_base_team_t *tl_team)
ncclCommDestroy(team->nccl_comm);
#endif
}
cudaStreamDestroy(team->stream);
}

UCC_CLASS_DELETE_FUNC_NAME(ucc_tl_nccl_team_t)(tl_team);
return UCC_OK;
}

ucc_status_t ucc_tl_nccl_team_create_test(ucc_base_team_t *tl_team)
ucc_status_t ucc_tl_nccl_comm_init(ucc_tl_nccl_team_t *team)
{
ucc_tl_nccl_team_t *team = ucc_derived_of(tl_team, ucc_tl_nccl_team_t);
ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team);
ucc_rank_t trank = UCC_TL_TEAM_RANK(team);
ucc_status_t status;
ncclResult_t nccl_status;
ncclUniqueId errorid;

if (team->comm_state == TL_NCCL_COMM_STATE_READY) {
return UCC_OK;
} else if (team->comm_state == TL_NCCL_COMM_STATE_ERROR) {
return UCC_ERR_NOT_SUPPORTED;
} else if (team->comm_state == TL_NCCL_COMM_STATE_INIT_COMM) {
#if NCCL_USE_NON_BLOCKING
ncclConfig_t nccl_cfg = NCCL_CONFIG_INITIALIZER;
ncclResult_t st;

if (team->comm_state == UCC_INPROGRESS) {
goto ncclInitStage;
}
goto nccl_async_init;
#else
ucc_assert_always(0);
#endif

status = UCC_TL_TEAM_OOB(team).req_test(team->oob_req);
if (status == UCC_INPROGRESS) {
return UCC_INPROGRESS;
}
if (status != UCC_OK) {
UCC_TL_TEAM_OOB(team).req_free(team->oob_req);
tl_error(tl_team->context->lib, "oob req test failed");
goto free_unique_id;
}
status = UCC_TL_TEAM_OOB(team).req_free(team->oob_req);
if (status != UCC_OK) {
tl_error(tl_team->context->lib, "oob req free failed");
goto free_unique_id;
}
/* check unique id is valid */
memset(&errorid, 0, sizeof(errorid));
if (!memcmp(&errorid, team->unique_id, sizeof(errorid))) {
tl_error(tl_team->context->lib, "incorrect unique id");
goto free_unique_id;
}

CUDA_CHECK_GOTO(cudaStreamCreateWithFlags(&team->stream,
cudaStreamNonBlocking), free_unique_id, status);
cudaStreamNonBlocking),
exit_err, status);
#if NCCL_USE_NON_BLOCKING
nccl_cfg.blocking = UCC_TL_NCCL_TEAM_CTX(team)->cfg.nccl_cfg_blocking;
nccl_status = ncclCommInitRankConfig(&team->nccl_comm,
UCC_TL_TEAM_SIZE(team),
team->unique_id[0],
UCC_TL_TEAM_RANK(team),
&nccl_cfg);
if (nccl_status != ncclInProgress && nccl_status != ncclSuccess) {
goto free_stream;
ncclConfig_t nccl_cfg = NCCL_CONFIG_INITIALIZER;
ncclResult_t async_status;

/*
* if NCCL comm initialized during first call to collective init a.k.a lazy init
* we need to use blocking init to correctly fallback to other TL in case of error
*/
nccl_cfg.blocking = (UCC_TL_NCCL_TEAM_CTX(team)->cfg.nccl_cfg_blocking ||
UCC_TL_NCCL_TEAM_CTX(team)->cfg.nccl_lazy_init) ? 1: 0;

nccl_status = ncclCommInitRankConfig(&team->nccl_comm, tsize,
team->unique_id[0], trank, &nccl_cfg);
if ((nccl_status != ncclInProgress) && (nccl_status != ncclSuccess)) {
goto nccl_comm_init_err;
}
ncclInitStage:
st = ncclCommGetAsyncError(team->nccl_comm, &nccl_status);
if (st != ncclSuccess) {
nccl_status = st;
nccl_async_init:
nccl_status = ncclCommGetAsyncError(team->nccl_comm, &async_status);
if (nccl_status != ncclSuccess) {
goto nccl_comm_init_err;
}
if (nccl_status == ncclInProgress){
team->comm_state = UCC_INPROGRESS;
return UCC_INPROGRESS;
if (async_status == ncclInProgress) {
team->comm_state = TL_NCCL_COMM_STATE_INIT_COMM;
}
#else
nccl_status = ncclCommInitRank(&team->nccl_comm, UCC_TL_TEAM_SIZE(team),
team->unique_id[0], UCC_TL_TEAM_RANK(team));
#endif
nccl_status = ncclCommInitRank(&team->nccl_comm, tsize, team->unique_id[0],
trank);
if (nccl_status != ncclSuccess) {
goto free_stream;
goto nccl_comm_init_err;
}
ucc_free(team->unique_id);
tl_debug(tl_team->context->lib, "initialized tl team: %p", team);
#endif

team->comm_state = TL_NCCL_COMM_STATE_READY;
return UCC_OK;

free_stream:
tl_debug(tl_team->context->lib, "NCCL error %d %s", nccl_status,
ncclGetErrorString(nccl_status));
status = UCC_ERR_NO_MESSAGE;
#if NCCL_USE_NON_BLOCKING
ncclCommAbort(team->nccl_comm);
#endif
cudaStreamDestroy(team->stream);
free_unique_id:
ucc_free(team->unique_id);
nccl_comm_init_err:
tl_debug(team->super.super.context->lib, "NCCL error %d %s",
nccl_status, ncclGetErrorString(nccl_status));
if (nccl_status == ncclInvalidUsage) {
/*
* handles the case when trying to inititize multiple ranks
* on the same GPU. Return "not supported" and fallback to other TL
*/
status = UCC_ERR_NOT_SUPPORTED;
} else {
status = UCC_ERR_NO_RESOURCE;
}
team->comm_state = TL_NCCL_COMM_STATE_ERROR;

exit_err:
return status;
}

ucc_status_t ucc_tl_nccl_team_create_test(ucc_base_team_t *tl_team)
{
ucc_tl_nccl_team_t *team = ucc_derived_of(tl_team, ucc_tl_nccl_team_t);
ucc_team_oob_coll_t *oob = &(UCC_TL_TEAM_OOB(team));
ncclUniqueId errorid;
ucc_status_t status;


if (team->comm_state == TL_NCCL_COMM_STATE_OOB) {
status = oob->req_test(team->oob_req);
if (status == UCC_INPROGRESS) {
return UCC_INPROGRESS;
}

oob->req_free(team->oob_req);
if (status != UCC_OK) {
tl_error(tl_team->context->lib, "oob req test failed");
return status;
}

/* check unique id is valid */
memset(&errorid, 0, sizeof(errorid));
if (!memcmp(&errorid, team->unique_id, sizeof(errorid))) {
tl_error(tl_team->context->lib, "incorrect unique id");
return status;
}

team->comm_state = TL_NCCL_COMM_STATE_INIT_TEAM;
}

if (UCC_TL_NCCL_TEAM_CTX(team)->cfg.nccl_lazy_init) {
return UCC_OK;
}

return ucc_tl_nccl_comm_init(team);
}

ucc_status_t ucc_tl_nccl_coll_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h)
Expand Down
5 changes: 5 additions & 0 deletions src/components/tl/ucc_tl.c
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,11 @@ ucc_status_t ucc_tl_team_create_multiple(ucc_team_multiple_req_t *req)
}
req->descs[*id].status = UCC_TL_CTX_IFACE(req->descs[*id].ctx)
->team.create_test(&req->descs[*id].team->super);
if (req->descs[*id].status < 0) {
/* if team create failed in team create test need to cleanup resources */
UCC_TL_CTX_IFACE(req->descs[*id].ctx)->team.destroy(
&req->descs[*id].team->super);
}
return UCC_INPROGRESS;
}

Expand Down
10 changes: 10 additions & 0 deletions src/components/tl/ucc_tl.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,18 @@ typedef struct ucc_tl_lib_attr {
#define UCC_TL_TEAM_IFACE(_tl_team) \
(ucc_derived_of((_tl_team)->super.context->lib, ucc_tl_lib_t))->iface

/**
* Get TL team lib
* @param [in] _tl_team pointer to TL team object
* @return pointer to TL lib object
*/
#define UCC_TL_TEAM_LIB(_tl_team) (_tl_team)->super.super.context->lib

/**
* Get TL team context
* @param [in] _tl_team pointer to TL team object
* @return pointer to TL context object
*/
#define UCC_TL_TEAM_CTX(_tl_team) (_tl_team)->super.super.context

#define UCC_TL_CORE_CTX(_tl_team) ((_tl_team)->super.super.context->ucc_context)
Expand Down

0 comments on commit 016c3db

Please sign in to comment.