diff --git a/src/components/tl/nccl/tl_nccl.c b/src/components/tl/nccl/tl_nccl.c index 8465c66e20..592a6dfdef 100644 --- a/src/components/tl/nccl/tl_nccl.c +++ b/src/components/tl/nccl/tl_nccl.c @@ -39,6 +39,11 @@ static ucs_config_field_t ucc_tl_nccl_context_config_table[] = { UCS_CONFIG_TYPE_ENUM(ucc_tl_nccl_completion_sync_names) }, + {"NCCL_CFG_BLOCKING", "1", + "If set to 0 will use non-blocking mode, if set to 1 will use blocking", + ucs_offsetof(ucc_tl_nccl_context_config_t, nccl_cfg_blocking), + UCS_CONFIG_TYPE_BOOL}, + {NULL}}; UCC_CLASS_DEFINE_NEW_FUNC(ucc_tl_nccl_lib_t, ucc_base_lib_t, diff --git a/src/components/tl/nccl/tl_nccl.h b/src/components/tl/nccl/tl_nccl.h index 458965a84c..06f32c0371 100644 --- a/src/components/tl/nccl/tl_nccl.h +++ b/src/components/tl/nccl/tl_nccl.h @@ -65,6 +65,7 @@ typedef enum ucc_tl_nccl_completion_sync_type { typedef struct ucc_tl_nccl_context_config { ucc_tl_context_config_t super; ucc_tl_nccl_completion_sync_type_t sync_type; + int nccl_cfg_blocking; } ucc_tl_nccl_context_config_t; typedef struct ucc_tl_nccl_lib { @@ -159,4 +160,7 @@ static inline ucc_status_t ucc_tl_nccl_check_nb(ncclResult_t *nccl_status, // NO #define UCC_TL_NCCL_TEAM_LIB(_team) \ (ucc_derived_of((_team)->super.super.context->lib, ucc_tl_nccl_lib_t)) +#define UCC_TL_NCCL_TEAM_CTX(_team) \ + (ucc_derived_of((_team)->super.super.context, ucc_tl_nccl_context_t)) + #endif diff --git a/src/components/tl/nccl/tl_nccl_team.c b/src/components/tl/nccl/tl_nccl_team.c index 0497490d83..af2aff2ac6 100644 --- a/src/components/tl/nccl/tl_nccl_team.c +++ b/src/components/tl/nccl/tl_nccl_team.c @@ -148,7 +148,7 @@ ucc_status_t ucc_tl_nccl_team_create_test(ucc_base_team_t *tl_team) CUDA_CHECK_GOTO(cudaStreamCreateWithFlags(&team->stream, cudaStreamNonBlocking), free_unique_id, status); #if NCCL_USE_NON_BLOCKING - nccl_cfg.blocking = 0; + nccl_cfg.blocking = UCC_TL_NCCL_TEAM_CTX(team)->cfg.nccl_cfg_blocking; nccl_status = ncclCommInitRankConfig(&team->nccl_comm, UCC_TL_TEAM_SIZE(team), team->unique_id[0],