Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TL/NCCL: lazy init nccl comm #851

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/components/tl/nccl/tl_nccl.c
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,17 @@ static ucs_config_field_t ucc_tl_nccl_context_config_table[] = {
UCS_CONFIG_TYPE_ENUM(ucc_tl_nccl_completion_sync_names)
},

{"BLOCKING", "1",
"If set to 0 will use non-blocking mode communicator behavior, "
"if set to 1 will use blocking mode",
{"BLOCKING", "yes",
"If set to no will use non-blocking mode communicator behavior, "
"if set to yes will use blocking mode",
ucs_offsetof(ucc_tl_nccl_context_config_t, nccl_cfg_blocking),
UCS_CONFIG_TYPE_BOOL},

{"LAZY_INIT", "yes",
"Initialize NCCL communicator on first collective",
ucc_offsetof(ucc_tl_nccl_context_config_t, nccl_lazy_init),
UCC_CONFIG_TYPE_BOOL},

{NULL}};

UCC_CLASS_DEFINE_NEW_FUNC(ucc_tl_nccl_lib_t, ucc_base_lib_t,
Expand Down
14 changes: 13 additions & 1 deletion src/components/tl/nccl/tl_nccl.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@
#define NCCL_VERSION_COMM_INIT_NB NCCL_VERSION(2,14,3)
#define NCCL_USE_NON_BLOCKING NCCL_VERSION_CODE >= NCCL_VERSION_COMM_INIT_NB

enum {
TL_NCCL_COMM_STATE_ERROR,
TL_NCCL_COMM_STATE_OOB,
TL_NCCL_COMM_STATE_INIT_TEAM,
TL_NCCL_COMM_STATE_INIT_COMM,
TL_NCCL_COMM_STATE_DESTROY_COMM,
TL_NCCL_COMM_STATE_READY,
};

typedef struct ucc_tl_nccl_iface {
ucc_tl_iface_t super;
} ucc_tl_nccl_iface_t;
Expand All @@ -66,6 +75,7 @@ typedef struct ucc_tl_nccl_context_config {
ucc_tl_context_config_t super;
ucc_tl_nccl_completion_sync_type_t sync_type;
int nccl_cfg_blocking;
int nccl_lazy_init;
} ucc_tl_nccl_context_config_t;

typedef struct ucc_tl_nccl_lib {
Expand All @@ -85,7 +95,7 @@ UCC_CLASS_DECLARE(ucc_tl_nccl_context_t, const ucc_base_context_params_t *,

typedef struct ucc_tl_nccl_team {
ucc_tl_team_t super;
ucc_status_t comm_state;
int comm_state;
ncclUniqueId *unique_id;
void *oob_req;
ncclComm_t nccl_comm;
Expand Down Expand Up @@ -146,6 +156,8 @@ static inline ucc_status_t ucc_tl_nccl_check_nb(ncclResult_t *nccl_status, // NO
return UCC_OK;
}

ucc_status_t ucc_tl_nccl_comm_init(ucc_tl_nccl_team_t *team);

#define NCCLCHECK_GOTO(_cmd, _label, _st, _lib, _task_st, _comm, _check_nb) \
do { \
ncclResult_t e = _cmd; \
Expand Down
10 changes: 9 additions & 1 deletion src/components/tl/nccl/tl_nccl_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ ucc_status_t ucc_tl_nccl_init_task(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_tl_nccl_task_t **coll_task)
{
ucc_tl_nccl_team_t *nccl_team = ucc_derived_of(team, ucc_tl_nccl_team_t);
ucc_tl_nccl_context_t *nccl_ctx = ucc_derived_of(team->context,
ucc_tl_nccl_context_t);
ucc_tl_nccl_task_t *task;
Expand All @@ -143,6 +144,13 @@ ucc_status_t ucc_tl_nccl_init_task(ucc_base_coll_args_t *coll_args,
return UCC_ERR_NOT_SUPPORTED;
}

if (ucc_unlikely(nccl_team->comm_state != TL_NCCL_COMM_STATE_READY)) {
status = ucc_tl_nccl_comm_init(nccl_team);
if (ucc_unlikely(status != UCC_OK)) {
return status;
}
}

task = ucc_mpool_get(&nccl_ctx->req_mp);
if (ucc_unlikely(!task)) {
tl_error(team->context->lib, "failed to get task from mpool");
Expand Down Expand Up @@ -206,7 +214,7 @@ ucc_status_t ucc_tl_nccl_coll_finalize(ucc_coll_task_t *coll_task)
ucc_status_t status = UCC_OK;

if (ucc_unlikely(task->super.super.status != UCC_OK)) {
team->comm_state = task->super.super.status;
team->comm_state = TL_NCCL_COMM_STATE_ERROR;
}
tl_debug(UCC_TASK_LIB(task), "finalizing coll task %p", task);
ucc_tl_nccl_free_task(task);
Expand Down
186 changes: 112 additions & 74 deletions src/components/tl/nccl/tl_nccl_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,17 @@
UCC_CLASS_INIT_FUNC(ucc_tl_nccl_team_t, ucc_base_context_t *tl_context,
const ucc_base_team_params_t *params)
{
ucc_tl_nccl_context_t *ctx =
ucc_derived_of(tl_context, ucc_tl_nccl_context_t);
ucc_tl_nccl_context_t *ctx = ucc_derived_of(tl_context,
ucc_tl_nccl_context_t);
ucc_team_oob_coll_t *oob;
ucc_status_t status;
ucc_rank_t size;
UCC_CLASS_CALL_SUPER_INIT(ucc_tl_team_t, &ctx->super, params);

UCC_CLASS_CALL_SUPER_INIT(ucc_tl_team_t, &ctx->super, params);
oob = &(UCC_TL_TEAM_OOB(self));
size = UCC_TL_TEAM_SIZE(self);
self->comm_state = UCC_OK;
self->stream = NULL;
self->nccl_comm = NULL;
self->unique_id = ucc_malloc(sizeof(ncclUniqueId) * (size + 1),
"tl_nccl_unique_id");
if (!self->unique_id) {
Expand All @@ -31,6 +34,7 @@ UCC_CLASS_INIT_FUNC(ucc_tl_nccl_team_t, ucc_base_context_t *tl_context,
sizeof(ncclUniqueId) * (size + 1));
return UCC_ERR_NO_MEMORY;
}

if (UCC_TL_TEAM_RANK(self) == 0) {
ncclResult_t st;
st = ncclGetUniqueId(&self->unique_id[size]);
Expand All @@ -39,14 +43,16 @@ UCC_CLASS_INIT_FUNC(ucc_tl_nccl_team_t, ucc_base_context_t *tl_context,
memset(&self->unique_id[size], 0, sizeof(ncclUniqueId));
}
}
status = UCC_TL_TEAM_OOB(self).allgather(
&self->unique_id[size], self->unique_id,
sizeof(ncclUniqueId), UCC_TL_TEAM_OOB(self).coll_info,
&self->oob_req);

status = oob->allgather(&self->unique_id[size],
self->unique_id, sizeof(ncclUniqueId),
oob->coll_info, &self->oob_req);
if (status != UCC_OK) {
tl_error(ctx->super.super.lib, "failed to start oob allgather");
goto free_unique_id;
}
self->comm_state = TL_NCCL_COMM_STATE_OOB;

return UCC_OK;

free_unique_id:
Expand All @@ -69,15 +75,17 @@ ucc_status_t ucc_tl_nccl_team_destroy(ucc_base_team_t *tl_team)
#if NCCL_USE_NON_BLOCKING
ncclResult_t nccl_status, st;

if (team->nccl_comm && team->comm_state == UCC_INPROGRESS) {
if (team->comm_state == TL_NCCL_COMM_STATE_DESTROY_COMM) {
samnordmann marked this conversation as resolved.
Show resolved Hide resolved
goto check_finalize;
}
#endif

if (team->stream) {
cudaStreamDestroy(team->stream);
team->stream = NULL;
}
if (team->nccl_comm) {
if (team->comm_state != UCC_OK && team->comm_state != UCC_INPROGRESS) {
/* if communication error was detected ncclCommAbort should be used
since ncclCommDestroy could block */
if (team->comm_state == TL_NCCL_COMM_STATE_ERROR) {
ncclCommAbort(team->nccl_comm);
} else {
#if NCCL_USE_NON_BLOCKING
Expand All @@ -91,7 +99,7 @@ ucc_status_t ucc_tl_nccl_team_destroy(ucc_base_team_t *tl_team)
ncclCommAbort(team->nccl_comm);
return UCC_ERR_NO_MESSAGE;
} else if (nccl_status == ncclInProgress) {
team->comm_state = UCC_INPROGRESS;
team->comm_state = TL_NCCL_COMM_STATE_DESTROY_COMM;
return UCC_INPROGRESS;
} else {
ncclCommDestroy(team->nccl_comm);
Expand All @@ -101,95 +109,125 @@ ucc_status_t ucc_tl_nccl_team_destroy(ucc_base_team_t *tl_team)
ncclCommDestroy(team->nccl_comm);
#endif
}
cudaStreamDestroy(team->stream);
}

UCC_CLASS_DELETE_FUNC_NAME(ucc_tl_nccl_team_t)(tl_team);
return UCC_OK;
}

ucc_status_t ucc_tl_nccl_team_create_test(ucc_base_team_t *tl_team)
ucc_status_t ucc_tl_nccl_comm_init(ucc_tl_nccl_team_t *team)
{
ucc_tl_nccl_team_t *team = ucc_derived_of(tl_team, ucc_tl_nccl_team_t);
ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team);
ucc_rank_t trank = UCC_TL_TEAM_RANK(team);
ucc_status_t status;
ncclResult_t nccl_status;
ncclUniqueId errorid;

#if NCCL_USE_NON_BLOCKING
ncclConfig_t nccl_cfg = NCCL_CONFIG_INITIALIZER;
ncclResult_t st;

if (team->comm_state == UCC_INPROGRESS) {
goto ncclInitStage;
}
ncclResult_t async_status;
#endif

status = UCC_TL_TEAM_OOB(team).req_test(team->oob_req);
if (status == UCC_INPROGRESS) {
return UCC_INPROGRESS;
}
if (status != UCC_OK) {
UCC_TL_TEAM_OOB(team).req_free(team->oob_req);
tl_error(tl_team->context->lib, "oob req test failed");
goto free_unique_id;
}
status = UCC_TL_TEAM_OOB(team).req_free(team->oob_req);
if (status != UCC_OK) {
tl_error(tl_team->context->lib, "oob req free failed");
goto free_unique_id;
}
/* check unique id is valid */
memset(&errorid, 0, sizeof(errorid));
if (!memcmp(&errorid, team->unique_id, sizeof(errorid))) {
tl_error(tl_team->context->lib, "incorrect unique id");
goto free_unique_id;
if (team->comm_state == TL_NCCL_COMM_STATE_READY) {
return UCC_OK;
} else if (team->comm_state == TL_NCCL_COMM_STATE_ERROR) {
return UCC_ERR_NOT_SUPPORTED;
} else if (team->comm_state == TL_NCCL_COMM_STATE_INIT_COMM) {
#if NCCL_USE_NON_BLOCKING
goto nccl_async_init;
#else
ucc_assert_always(0);
#endif
}

CUDA_CHECK_GOTO(cudaStreamCreateWithFlags(&team->stream,
cudaStreamNonBlocking), free_unique_id, status);
cudaStreamNonBlocking),
exit_err, status);
#if NCCL_USE_NON_BLOCKING
nccl_cfg.blocking = UCC_TL_NCCL_TEAM_CTX(team)->cfg.nccl_cfg_blocking;
nccl_status = ncclCommInitRankConfig(&team->nccl_comm,
UCC_TL_TEAM_SIZE(team),
team->unique_id[0],
UCC_TL_TEAM_RANK(team),
&nccl_cfg);
if (nccl_status != ncclInProgress && nccl_status != ncclSuccess) {
goto free_stream;
/*
* if NCCL comm initialized during first call to collective init a.k.a lazy init
* we need to use blocking init to correctly fallback to other TL in case of error
*/
nccl_cfg.blocking = (UCC_TL_NCCL_TEAM_CTX(team)->cfg.nccl_cfg_blocking ||
UCC_TL_NCCL_TEAM_CTX(team)->cfg.nccl_lazy_init) ? 1: 0;

nccl_status = ncclCommInitRankConfig(&team->nccl_comm, tsize,
team->unique_id[0], trank, &nccl_cfg);
if ((nccl_status != ncclInProgress) && (nccl_status != ncclSuccess)) {
goto nccl_comm_init_err;
}
ncclInitStage:
st = ncclCommGetAsyncError(team->nccl_comm, &nccl_status);
if (st != ncclSuccess) {
nccl_status = st;
nccl_async_init:
nccl_status = ncclCommGetAsyncError(team->nccl_comm, &async_status);
if (nccl_status != ncclSuccess) {
goto nccl_comm_init_err;
}
if (nccl_status == ncclInProgress){
team->comm_state = UCC_INPROGRESS;
return UCC_INPROGRESS;
if (async_status == ncclInProgress) {
team->comm_state = TL_NCCL_COMM_STATE_INIT_COMM;
}
#else
nccl_status = ncclCommInitRank(&team->nccl_comm, UCC_TL_TEAM_SIZE(team),
team->unique_id[0], UCC_TL_TEAM_RANK(team));
#endif
nccl_status = ncclCommInitRank(&team->nccl_comm, tsize, team->unique_id[0],
trank);
if (nccl_status != ncclSuccess) {
goto free_stream;
goto nccl_comm_init_err;
}
ucc_free(team->unique_id);
tl_debug(tl_team->context->lib, "initialized tl team: %p", team);
#endif

team->comm_state = TL_NCCL_COMM_STATE_READY;
return UCC_OK;

free_stream:
tl_debug(tl_team->context->lib, "NCCL error %d %s", nccl_status,
ncclGetErrorString(nccl_status));
status = UCC_ERR_NO_MESSAGE;
#if NCCL_USE_NON_BLOCKING
ncclCommAbort(team->nccl_comm);
#endif
cudaStreamDestroy(team->stream);
free_unique_id:
ucc_free(team->unique_id);
nccl_comm_init_err:
tl_debug(team->super.super.context->lib, "NCCL error %d %s",
nccl_status, ncclGetErrorString(nccl_status));
if (nccl_status == ncclInvalidUsage) {
/*
* handles the case when trying to inititize multiple ranks
* on the same GPU. Return "not supported" and fallback to other TL
*/
status = UCC_ERR_NOT_SUPPORTED;
} else {
status = UCC_ERR_NO_RESOURCE;
}
team->comm_state = TL_NCCL_COMM_STATE_ERROR;

exit_err:
return status;
}

ucc_status_t ucc_tl_nccl_team_create_test(ucc_base_team_t *tl_team)
{
ucc_tl_nccl_team_t *team = ucc_derived_of(tl_team, ucc_tl_nccl_team_t);
ucc_team_oob_coll_t *oob = &(UCC_TL_TEAM_OOB(team));
ncclUniqueId errorid;
ucc_status_t status;


if (team->comm_state == TL_NCCL_COMM_STATE_OOB) {
status = oob->req_test(team->oob_req);
if (status == UCC_INPROGRESS) {
return UCC_INPROGRESS;
}

oob->req_free(team->oob_req);
if (status != UCC_OK) {
tl_error(tl_team->context->lib, "oob req test failed");
return status;
}

/* check unique id is valid */
memset(&errorid, 0, sizeof(errorid));
if (!memcmp(&errorid, team->unique_id, sizeof(errorid))) {
tl_error(tl_team->context->lib, "incorrect unique id");
return status;
}

team->comm_state = TL_NCCL_COMM_STATE_INIT_TEAM;
}

if (UCC_TL_NCCL_TEAM_CTX(team)->cfg.nccl_lazy_init) {
return UCC_OK;
}

return ucc_tl_nccl_comm_init(team);
}

ucc_status_t ucc_tl_nccl_coll_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h)
Expand Down
5 changes: 5 additions & 0 deletions src/components/tl/ucc_tl.c
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,11 @@ ucc_status_t ucc_tl_team_create_multiple(ucc_team_multiple_req_t *req)
}
req->descs[*id].status = UCC_TL_CTX_IFACE(req->descs[*id].ctx)
->team.create_test(&req->descs[*id].team->super);
if (req->descs[*id].status < 0) {
/* if team create failed in team create test need to cleanup resources */
UCC_TL_CTX_IFACE(req->descs[*id].ctx)->team.destroy(
samnordmann marked this conversation as resolved.
Show resolved Hide resolved
&req->descs[*id].team->super);
}
return UCC_INPROGRESS;
}

Expand Down
10 changes: 10 additions & 0 deletions src/components/tl/ucc_tl.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,18 @@ typedef struct ucc_tl_lib_attr {
#define UCC_TL_TEAM_IFACE(_tl_team) \
(ucc_derived_of((_tl_team)->super.context->lib, ucc_tl_lib_t))->iface

/**
* Get TL team lib
* @param [in] _tl_team pointer to TL team object
* @return pointer to TL lib object
*/
#define UCC_TL_TEAM_LIB(_tl_team) (_tl_team)->super.super.context->lib

/**
* Get TL team context
* @param [in] _tl_team pointer to TL team object
* @return pointer to TL context object
*/
#define UCC_TL_TEAM_CTX(_tl_team) (_tl_team)->super.super.context

#define UCC_TL_CORE_CTX(_tl_team) ((_tl_team)->super.super.context->ucc_context)
Expand Down
Loading