Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TL/NCCL: make team init non blocking #772

Merged
merged 3 commits into from
Jun 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 75 additions & 13 deletions src/components/tl/nccl/tl_nccl_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
#include "coll_score/ucc_coll_score.h"
#include "utils/arch/cuda_def.h"

#define NCCL_VERSION_COMM_INIT_NB NCCL_VERSION(2,14,3)
#define NCCL_USE_NON_BLOCKING NCCL_VERSION_CODE >= NCCL_VERSION_COMM_INIT_NB

UCC_CLASS_INIT_FUNC(ucc_tl_nccl_team_t, ucc_base_context_t *tl_context,
const ucc_base_team_params_t *params)
{
Expand Down Expand Up @@ -57,23 +60,53 @@ UCC_CLASS_INIT_FUNC(ucc_tl_nccl_team_t, ucc_base_context_t *tl_context,
UCC_CLASS_CLEANUP_FUNC(ucc_tl_nccl_team_t)
{
tl_debug(self->super.super.context->lib, "finalizing tl team: %p", self);
if (self->nccl_comm) {
if (self->comm_state != UCC_OK) {
/* if communication error was detected ncclCommAbort should be used
since ncclCommDestroy could block */
ncclCommAbort(self->nccl_comm);
} else {
ncclCommDestroy(self->nccl_comm);
}
cudaStreamDestroy(self->stream);
}
}

UCC_CLASS_DEFINE_DELETE_FUNC(ucc_tl_nccl_team_t, ucc_base_team_t);
UCC_CLASS_DEFINE(ucc_tl_nccl_team_t, ucc_tl_team_t);

ucc_status_t ucc_tl_nccl_team_destroy(ucc_base_team_t *tl_team)
{
ucc_tl_nccl_team_t *team = ucc_derived_of(tl_team, ucc_tl_nccl_team_t);

#if NCCL_USE_NON_BLOCKING
ncclResult_t nccl_status;

if (team->nccl_comm && team->comm_state == UCC_INPROGRESS) {
goto check_finalize;
}
#endif

if (team->nccl_comm) {
if (team->comm_state != UCC_OK && team->comm_state != UCC_INPROGRESS) {
/* if communication error was detected ncclCommAbort should be used
since ncclCommDestroy could block */
ncclCommAbort(team->nccl_comm);
} else {
#if NCCL_USE_NON_BLOCKING
ncclCommFinalize(team->nccl_comm);
check_finalize:
ncclCommGetAsyncError(team->nccl_comm, &nccl_status);
if (nccl_status == ncclInProgress) {
team->comm_state = UCC_INPROGRESS;
return UCC_INPROGRESS;
}
if (nccl_status != ncclSuccess) {
tl_debug(tl_team->context->lib, "NCCL error %d %s", nccl_status,
shimmybalsam marked this conversation as resolved.
Show resolved Hide resolved
ncclGetErrorString(nccl_status));
ncclCommAbort(team->nccl_comm);
return UCC_ERR_NO_MESSAGE;
} else {
ncclCommDestroy(team->nccl_comm);
}
team->comm_state = UCC_OK;
#else
ncclCommDestroy(team->nccl_comm);
#endif
}
cudaStreamDestroy(team->stream);
}

UCC_CLASS_DELETE_FUNC_NAME(ucc_tl_nccl_team_t)(tl_team);
return UCC_OK;
}
Expand All @@ -85,6 +118,14 @@ ucc_status_t ucc_tl_nccl_team_create_test(ucc_base_team_t *tl_team)
ncclResult_t nccl_status;
ncclUniqueId errorid;

#if NCCL_USE_NON_BLOCKING
ncclConfig_t nccl_cfg = NCCL_CONFIG_INITIALIZER;

if (team->comm_state == UCC_INPROGRESS) {
goto ncclInitStage;
}
#endif

status = UCC_TL_TEAM_OOB(team).req_test(team->oob_req);
if (status == UCC_INPROGRESS) {
return UCC_INPROGRESS;
Expand All @@ -108,19 +149,40 @@ ucc_status_t ucc_tl_nccl_team_create_test(ucc_base_team_t *tl_team)

CUDA_CHECK_GOTO(cudaStreamCreateWithFlags(&team->stream,
cudaStreamNonBlocking), free_unique_id, status);
#if NCCL_USE_NON_BLOCKING
nccl_cfg.blocking = 0;
nccl_status = ncclCommInitRankConfig(&team->nccl_comm,
shimmybalsam marked this conversation as resolved.
Show resolved Hide resolved
UCC_TL_TEAM_SIZE(team),
team->unique_id[0],
UCC_TL_TEAM_RANK(team),
&nccl_cfg);
if (nccl_status != ncclInProgress && nccl_status != ncclSuccess) {
goto free_stream;
}
ncclInitStage:
ncclCommGetAsyncError(team->nccl_comm, &nccl_status);
if (nccl_status == ncclInProgress){
team->comm_state = UCC_INPROGRESS;
return UCC_INPROGRESS;
}
#else
nccl_status = ncclCommInitRank(&team->nccl_comm, UCC_TL_TEAM_SIZE(team),
team->unique_id[0], UCC_TL_TEAM_RANK(team));
#endif
if (nccl_status != ncclSuccess) {
tl_debug(tl_team->context->lib, "NCCL error %d %s",
nccl_status, ncclGetErrorString(nccl_status));
status = UCC_ERR_NO_MESSAGE;
goto free_stream;
}
ucc_free(team->unique_id);
tl_debug(tl_team->context->lib, "initialized tl team: %p", team);
return UCC_OK;

free_stream:
tl_debug(tl_team->context->lib, "NCCL error %d %s", nccl_status,
shimmybalsam marked this conversation as resolved.
Show resolved Hide resolved
ncclGetErrorString(nccl_status));
status = UCC_ERR_NO_MESSAGE;
#if NCCL_USE_NON_BLOCKING
ncclCommAbort(team->nccl_comm);
#endif
cudaStreamDestroy(team->stream);
free_unique_id:
ucc_free(team->unique_id);
Expand Down
8 changes: 3 additions & 5 deletions test/mpi/test_mpi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,11 +246,9 @@ void UccTestMpi::destroy_team(ucc_test_team_t &team)
ucc_status_t status;

team.free_ee();
while (UCC_INPROGRESS == (status = ucc_team_destroy(team.team))) {
if (UCC_OK != status) {
std::cerr << "ucc_team_destroy failed\n";
break;
}
while (UCC_INPROGRESS == (status = ucc_team_destroy(team.team))) {}
if (UCC_OK != status) {
std::cerr << "ucc_team_destroy failed\n";
}
if (team.comm != MPI_COMM_WORLD) {
MPI_Comm_free(&team.comm);
Expand Down