From bca65c8d7841482c3bb93514c8e9d031c0868e82 Mon Sep 17 00:00:00 2001 From: Sergey Lebedev Date: Thu, 5 Oct 2023 09:07:34 +0200 Subject: [PATCH] CORE: skip zero size collectives (#787) * CORE: skip zero size collectives * REVIEW: fix review comments --- src/components/tl/nccl/tl_nccl_coll.c | 3 ++ src/core/ucc_coll.c | 78 +++++++++++++++++++-------- src/core/ucc_lib.c | 10 ++++ src/core/ucc_lib.h | 18 ++++--- src/schedule/ucc_schedule.c | 22 ++++++++ src/utils/ucc_coll_utils.c | 7 ++- tools/perf/ucc_pt_benchmark.cc | 8 +-- 7 files changed, 112 insertions(+), 34 deletions(-) diff --git a/src/components/tl/nccl/tl_nccl_coll.c b/src/components/tl/nccl/tl_nccl_coll.c index fdb1ade4a0..8a225c268b 100644 --- a/src/components/tl/nccl/tl_nccl_coll.c +++ b/src/components/tl/nccl/tl_nccl_coll.c @@ -135,6 +135,7 @@ ucc_status_t ucc_tl_nccl_init_task(ucc_base_coll_args_t *coll_args, ucc_tl_nccl_context_t); ucc_tl_nccl_task_t *task; ucc_status_t status; + ucc_coll_progress_fn_t progress_fn; if (!ucc_coll_args_is_predefined_dt(&coll_args->args, team->params.rank)) { tl_error(team->context->lib, @@ -147,11 +148,13 @@ ucc_status_t ucc_tl_nccl_init_task(ucc_base_coll_args_t *coll_args, tl_error(team->context->lib, "failed to get task from mpool"); return UCC_ERR_NO_MEMORY; } + progress_fn = task->super.progress; ucc_coll_task_init(&task->super, coll_args, team); UCC_TL_NCCL_PROFILE_REQUEST_NEW(task, "tl_nccl_task", 0); task->super.finalize = ucc_tl_nccl_coll_finalize; task->super.triggered_post = ucc_tl_nccl_triggered_post; + task->super.progress = progress_fn; task->completed = NULL; if (nccl_ctx->cfg.sync_type == UCC_TL_NCCL_COMPLETION_SYNC_TYPE_EVENT) { status = ucc_ec_create_event(&task->completed, UCC_EE_CUDA_STREAM); diff --git a/src/core/ucc_coll.c b/src/core/ucc_coll.c index 8a67332a30..8cf3658570 100644 --- a/src/core/ucc_coll.c +++ b/src/core/ucc_coll.c @@ -157,6 +157,15 @@ static ucc_status_t ucc_coll_args_check_mem_type(ucc_coll_args_t *coll_args, }; } +#define UCC_COLL_TYPE_SKIP_ZERO_SIZE \ + (UCC_COLL_TYPE_ALLREDUCE | \ + UCC_COLL_TYPE_ALLGATHER | \ + UCC_COLL_TYPE_ALLTOALL | \ + UCC_COLL_TYPE_BCAST | \ + UCC_COLL_TYPE_GATHER | \ + UCC_COLL_TYPE_REDUCE | \ + UCC_COLL_TYPE_SCATTER) + UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_init, (coll_args, request, team), ucc_coll_args_t *coll_args, ucc_coll_req_h *request, ucc_team_h team) @@ -167,6 +176,7 @@ UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_init, ucc_ee_executor_params_t params; ucc_memory_type_t coll_mem_type; ucc_ee_type_t coll_ee_type; + size_t coll_size; if (ucc_unlikely(team->state != UCC_TEAM_ACTIVE)) { ucc_error("team %p is used before team create is completed", team); @@ -174,6 +184,26 @@ UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_init, } /* Global check to reduce the amount of checks throughout all TLs */ + + if (UCC_COLL_TYPE_SKIP_ZERO_SIZE & coll_args->coll_type) { + coll_size = ucc_coll_args_msgsize(coll_args, team->rank, team->size); + if (coll_size == 0) { + task = ucc_mpool_get(&team->contexts[0]->lib->stub_tasks_mp); + if (ucc_unlikely(!task)) { + ucc_error("failed to allocate dummy task"); + return UCC_ERR_NO_MEMORY; + } + op_args.mask = 0; + memcpy(&op_args.args, coll_args, sizeof(ucc_coll_args_t)); + op_args.team = team; + op_args.args.flags = 0; + UCC_COPY_PARAM_BY_FIELD(&op_args.args, coll_args, + UCC_COLL_ARGS_FIELD_FLAGS, flags); + ucc_coll_task_init(task, &op_args, NULL); + goto print_trace; + } + } + if (UCC_COLL_ARGS_ACTIVE_SET(coll_args) && ((UCC_COLL_TYPE_BCAST != coll_args->coll_type) || coll_args->active_set.size != 2)) { @@ -246,6 +276,10 @@ UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_init, } task->seq_num = team->seq_num++; + ucc_assert(task->super.status == UCC_OPERATION_INITIALIZED); + +print_trace: + *request = &task->super; if (ucc_global_config.coll_trace.log_level >= UCC_LOG_LEVEL_DIAG) { char coll_str[256]; ucc_coll_str(task, coll_str, sizeof(coll_str), @@ -263,9 +297,6 @@ UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_init, } } - ucc_assert(task->super.status == UCC_OPERATION_INITIALIZED); - *request = &task->super; - return UCC_OK; coll_finalize: @@ -297,18 +328,21 @@ UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_post, (request), ucc_status_t status; if (ucc_global_config.coll_trace.log_level >= UCC_LOG_LEVEL_DEBUG) { - ucc_rank_t rank = task->team->params.team->rank; - if (ucc_global_config.coll_trace.log_level == UCC_LOG_LEVEL_DEBUG) { - if (rank == 0) { + /* team is NULL if task is a dummy task, e.g. collective of zero size */ + if (task->team) { + ucc_rank_t rank = task->team->params.team->rank; + if (ucc_global_config.coll_trace.log_level == UCC_LOG_LEVEL_DEBUG) { + if (rank == 0) { + ucc_log_component_collective_trace( + ucc_global_config.coll_trace.log_level, + "coll post: req %p, seq_num %u", task, task->seq_num); + } + } else { ucc_log_component_collective_trace( ucc_global_config.coll_trace.log_level, - "coll post: req %p, seq_num %u", task, task->seq_num); + "coll post: rank %d req %p, seq_num %u", rank, task, + task->seq_num); } - } else { - ucc_log_component_collective_trace( - ucc_global_config.coll_trace.log_level, - "coll post: rank %d req %p, seq_num %u", rank, task, - task->seq_num); } } @@ -388,18 +422,20 @@ UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_finalize, (request), ucc_coll_task_t *task = ucc_derived_of(request, ucc_coll_task_t); if (ucc_global_config.coll_trace.log_level >= UCC_LOG_LEVEL_DEBUG) { - ucc_rank_t rank = task->team->params.team->rank; - if (ucc_global_config.coll_trace.log_level == UCC_LOG_LEVEL_DEBUG) { - if (rank == 0) { + if (task->team) { + ucc_rank_t rank = task->team->params.team->rank; + if (ucc_global_config.coll_trace.log_level == UCC_LOG_LEVEL_DEBUG) { + if (rank == 0) { + ucc_log_component_collective_trace( + ucc_global_config.coll_trace.log_level, + "coll finalize: req %p, seq_num %u", task, task->seq_num); + } + } else { ucc_log_component_collective_trace( ucc_global_config.coll_trace.log_level, - "coll finalize: req %p, seq_num %u", task, task->seq_num); + "coll finalize: rank %d req %p, seq_num %u", rank, task, + task->seq_num); } - } else { - ucc_log_component_collective_trace( - ucc_global_config.coll_trace.log_level, - "coll finalize: rank %d req %p, seq_num %u", rank, task, - task->seq_num); } } return ucc_collective_finalize_internal(task); diff --git a/src/core/ucc_lib.c b/src/core/ucc_lib.c index 6f909eda08..811b58b7ad 100644 --- a/src/core/ucc_lib.c +++ b/src/core/ucc_lib.c @@ -356,6 +356,14 @@ ucc_status_t ucc_init_version(unsigned api_major_version, goto error; } + status = ucc_mpool_init(&lib->stub_tasks_mp, 0, sizeof(ucc_coll_task_t), 0, + UCC_CACHE_LINE_SIZE, 8, UINT_MAX, + &ucc_coll_task_mpool_ops, UCC_THREAD_MULTIPLE, + "stub_tasks"); + if (status != UCC_OK) { + goto error; + } + *lib_p = lib; return UCC_OK; error: @@ -473,6 +481,8 @@ ucc_status_t ucc_finalize(ucc_lib_info_t *lib) gl_status = UCC_OK; ucc_assert(lib->n_cl_libs_opened > 0); ucc_assert(lib->cl_libs != NULL); + + ucc_mpool_cleanup(&lib->stub_tasks_mp, 1); for (i = 0; i < lib->n_tl_libs_opened; i++) { lib->tl_libs[i]->iface->lib.finalize(&lib->tl_libs[i]->super); } diff --git a/src/core/ucc_lib.h b/src/core/ucc_lib.h index d8beedd91a..3cde2eebf1 100644 --- a/src/core/ucc_lib.h +++ b/src/core/ucc_lib.h @@ -11,6 +11,7 @@ #include "ucc/api/ucc.h" #include "components/cl/ucc_cl_type.h" #include "utils/ucc_parser.h" +#include "utils/ucc_mpool.h" typedef struct ucc_cl_lib ucc_cl_lib_t; typedef struct ucc_tl_lib ucc_tl_lib_t; @@ -26,14 +27,15 @@ typedef struct ucc_lib_config { } ucc_lib_config_t; typedef struct ucc_lib_info { - char *full_prefix; - int n_cl_libs_opened; - int n_tl_libs_opened; - ucc_cl_lib_t **cl_libs; - ucc_tl_lib_t **tl_libs; - ucc_lib_attr_t attr; - int specific_cls_requested; - ucc_cl_lib_attr_t *cl_attrs; + char *full_prefix; + int n_cl_libs_opened; + int n_tl_libs_opened; + ucc_cl_lib_t **cl_libs; + ucc_tl_lib_t **tl_libs; + ucc_lib_attr_t attr; + int specific_cls_requested; + ucc_cl_lib_attr_t *cl_attrs; + ucc_mpool_t stub_tasks_mp; } ucc_lib_info_t; int ucc_tl_is_required(ucc_lib_info_t *lib, ucc_tl_iface_t *tl_iface, diff --git a/src/schedule/ucc_schedule.c b/src/schedule/ucc_schedule.c index bd3d9a408a..777b9c6b39 100644 --- a/src/schedule/ucc_schedule.c +++ b/src/schedule/ucc_schedule.c @@ -82,6 +82,25 @@ void ucc_coll_task_destruct(ucc_coll_task_t *task) } } +ucc_status_t ucc_dummy_post(ucc_coll_task_t *task) +{ + task->status = UCC_OK; + return ucc_task_complete(task); +} + +ucc_status_t ucc_dummy_finalize(ucc_coll_task_t *task) +{ + ucc_mpool_put(task); + return UCC_OK; +} + +/* NOLINTNEXTLINE task argument is not used*/ +void ucc_dummy_progress(ucc_coll_task_t *task) +{ + /* this function should never be called */ + ucc_assert_always(0); +} + ucc_status_t ucc_coll_task_init(ucc_coll_task_t *task, ucc_base_coll_args_t *bargs, ucc_base_team_t *team) @@ -97,6 +116,9 @@ ucc_status_t ucc_coll_task_init(ucc_coll_task_t *task, task->super.status = UCC_OPERATION_INITIALIZED; task->triggered_post_setup = NULL; task->triggered_post = ucc_triggered_post; + task->post = ucc_dummy_post; + task->finalize = ucc_dummy_finalize; + task->progress = ucc_dummy_progress; if (bargs) { memcpy(&task->bargs, bargs, sizeof(*bargs)); } diff --git a/src/utils/ucc_coll_utils.c b/src/utils/ucc_coll_utils.c index 42c9022b3d..3921f1262e 100644 --- a/src/utils/ucc_coll_utils.c +++ b/src/utils/ucc_coll_utils.c @@ -548,7 +548,12 @@ void ucc_coll_str(const ucc_coll_task_t *task, char *str, size_t len, size_t tl_info_len = 0; char task_info[64], cl_info[16], tl_info[32]; - if (task->team->context->lib->log_component.name[0] == 'C') { + if (!task->team) { + /* zero size collective, no CL or TL */ + strncpy(cl_info, "NoOp", sizeof(cl_info)); + strncpy(tl_info, "NoOp", sizeof(tl_info)); + } + else if (task->team->context->lib->log_component.name[0] == 'C') { /* it's not CL BASIC task */ ucc_strncpy_safe(cl_info, task->team->context->lib->log_component.name, diff --git a/tools/perf/ucc_pt_benchmark.cc b/tools/perf/ucc_pt_benchmark.cc index e9e313fb43..c4ef8c6289 100644 --- a/tools/perf/ucc_pt_benchmark.cc +++ b/tools/perf/ucc_pt_benchmark.cc @@ -1,9 +1,5 @@ /** -<<<<<<< HEAD * Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -======= - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. ->>>>>>> REVIEW: fix review comments * * See file LICENSE for terms. */ @@ -115,6 +111,10 @@ ucc_status_t ucc_pt_benchmark::run_bench() noexcept } print_time(cnt, args, time); coll->free_args(args); + if (max_count == 0) { + /* exit from loop when min_count == max_count == 0 */ + break; + } } return UCC_OK;