Skip to content

Commit

Permalink
CORE: skip zero size collectives (#787)
Browse files Browse the repository at this point in the history
* CORE: skip zero size collectives

* REVIEW: fix review comments
  • Loading branch information
Sergei-Lebedev authored Oct 5, 2023
1 parent 2a012c4 commit dc1049b
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 34 deletions.
3 changes: 3 additions & 0 deletions src/components/tl/nccl/tl_nccl_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ ucc_status_t ucc_tl_nccl_init_task(ucc_base_coll_args_t *coll_args,
ucc_tl_nccl_context_t);
ucc_tl_nccl_task_t *task;
ucc_status_t status;
ucc_coll_progress_fn_t progress_fn;

if (!ucc_coll_args_is_predefined_dt(&coll_args->args, team->params.rank)) {
tl_error(team->context->lib,
Expand All @@ -147,11 +148,13 @@ ucc_status_t ucc_tl_nccl_init_task(ucc_base_coll_args_t *coll_args,
tl_error(team->context->lib, "failed to get task from mpool");
return UCC_ERR_NO_MEMORY;
}
progress_fn = task->super.progress;

ucc_coll_task_init(&task->super, coll_args, team);
UCC_TL_NCCL_PROFILE_REQUEST_NEW(task, "tl_nccl_task", 0);
task->super.finalize = ucc_tl_nccl_coll_finalize;
task->super.triggered_post = ucc_tl_nccl_triggered_post;
task->super.progress = progress_fn;
task->completed = NULL;
if (nccl_ctx->cfg.sync_type == UCC_TL_NCCL_COMPLETION_SYNC_TYPE_EVENT) {
status = ucc_ec_create_event(&task->completed, UCC_EE_CUDA_STREAM);
Expand Down
78 changes: 57 additions & 21 deletions src/core/ucc_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,15 @@ static ucc_status_t ucc_coll_args_check_mem_type(ucc_coll_args_t *coll_args,
};
}

/* Bitmask of collective types that ucc_collective_init tests against
 * coll_args->coll_type before computing the message size; when that size is
 * zero the collective is served by a stub task instead of a real one. */
#define UCC_COLL_TYPE_SKIP_ZERO_SIZE                                           \
    (UCC_COLL_TYPE_ALLGATHER | UCC_COLL_TYPE_ALLREDUCE |                       \
     UCC_COLL_TYPE_ALLTOALL | UCC_COLL_TYPE_BCAST | UCC_COLL_TYPE_GATHER |     \
     UCC_COLL_TYPE_REDUCE | UCC_COLL_TYPE_SCATTER)

UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_init,
(coll_args, request, team), ucc_coll_args_t *coll_args,
ucc_coll_req_h *request, ucc_team_h team)
Expand All @@ -167,13 +176,34 @@ UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_init,
ucc_ee_executor_params_t params;
ucc_memory_type_t coll_mem_type;
ucc_ee_type_t coll_ee_type;
size_t coll_size;

if (ucc_unlikely(team->state != UCC_TEAM_ACTIVE)) {
ucc_error("team %p is used before team create is completed", team);
return UCC_ERR_INVALID_PARAM;
}
/* Global check to reduce the amount of checks throughout
all TLs */

if (UCC_COLL_TYPE_SKIP_ZERO_SIZE & coll_args->coll_type) {
coll_size = ucc_coll_args_msgsize(coll_args, team->rank, team->size);
if (coll_size == 0) {
task = ucc_mpool_get(&team->contexts[0]->lib->stub_tasks_mp);
if (ucc_unlikely(!task)) {
ucc_error("failed to allocate dummy task");
return UCC_ERR_NO_MEMORY;
}
op_args.mask = 0;
memcpy(&op_args.args, coll_args, sizeof(ucc_coll_args_t));
op_args.team = team;
op_args.args.flags = 0;
UCC_COPY_PARAM_BY_FIELD(&op_args.args, coll_args,
UCC_COLL_ARGS_FIELD_FLAGS, flags);
ucc_coll_task_init(task, &op_args, NULL);
goto print_trace;
}
}

if (UCC_COLL_ARGS_ACTIVE_SET(coll_args) &&
((UCC_COLL_TYPE_BCAST != coll_args->coll_type) ||
coll_args->active_set.size != 2)) {
Expand Down Expand Up @@ -246,6 +276,10 @@ UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_init,
}
task->seq_num = team->seq_num++;

ucc_assert(task->super.status == UCC_OPERATION_INITIALIZED);

print_trace:
*request = &task->super;
if (ucc_global_config.coll_trace.log_level >= UCC_LOG_LEVEL_DIAG) {
char coll_str[256];
ucc_coll_str(task, coll_str, sizeof(coll_str),
Expand All @@ -263,9 +297,6 @@ UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_init,
}
}

ucc_assert(task->super.status == UCC_OPERATION_INITIALIZED);
*request = &task->super;

return UCC_OK;

coll_finalize:
Expand Down Expand Up @@ -297,18 +328,21 @@ UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_post, (request),
ucc_status_t status;

if (ucc_global_config.coll_trace.log_level >= UCC_LOG_LEVEL_DEBUG) {
ucc_rank_t rank = task->team->params.team->rank;
if (ucc_global_config.coll_trace.log_level == UCC_LOG_LEVEL_DEBUG) {
if (rank == 0) {
/* team is NULL if task is a dummy task, e.g. collective of zero size */
if (task->team) {
ucc_rank_t rank = task->team->params.team->rank;
if (ucc_global_config.coll_trace.log_level == UCC_LOG_LEVEL_DEBUG) {
if (rank == 0) {
ucc_log_component_collective_trace(
ucc_global_config.coll_trace.log_level,
"coll post: req %p, seq_num %u", task, task->seq_num);
}
} else {
ucc_log_component_collective_trace(
ucc_global_config.coll_trace.log_level,
"coll post: req %p, seq_num %u", task, task->seq_num);
"coll post: rank %d req %p, seq_num %u", rank, task,
task->seq_num);
}
} else {
ucc_log_component_collective_trace(
ucc_global_config.coll_trace.log_level,
"coll post: rank %d req %p, seq_num %u", rank, task,
task->seq_num);
}
}

Expand Down Expand Up @@ -388,18 +422,20 @@ UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_finalize, (request),
ucc_coll_task_t *task = ucc_derived_of(request, ucc_coll_task_t);

if (ucc_global_config.coll_trace.log_level >= UCC_LOG_LEVEL_DEBUG) {
ucc_rank_t rank = task->team->params.team->rank;
if (ucc_global_config.coll_trace.log_level == UCC_LOG_LEVEL_DEBUG) {
if (rank == 0) {
if (task->team) {
ucc_rank_t rank = task->team->params.team->rank;
if (ucc_global_config.coll_trace.log_level == UCC_LOG_LEVEL_DEBUG) {
if (rank == 0) {
ucc_log_component_collective_trace(
ucc_global_config.coll_trace.log_level,
"coll finalize: req %p, seq_num %u", task, task->seq_num);
}
} else {
ucc_log_component_collective_trace(
ucc_global_config.coll_trace.log_level,
"coll finalize: req %p, seq_num %u", task, task->seq_num);
"coll finalize: rank %d req %p, seq_num %u", rank, task,
task->seq_num);
}
} else {
ucc_log_component_collective_trace(
ucc_global_config.coll_trace.log_level,
"coll finalize: rank %d req %p, seq_num %u", rank, task,
task->seq_num);
}
}
return ucc_collective_finalize_internal(task);
Expand Down
10 changes: 10 additions & 0 deletions src/core/ucc_lib.c
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,14 @@ ucc_status_t ucc_init_version(unsigned api_major_version,
goto error;
}

status = ucc_mpool_init(&lib->stub_tasks_mp, 0, sizeof(ucc_coll_task_t), 0,
UCC_CACHE_LINE_SIZE, 8, UINT_MAX,
&ucc_coll_task_mpool_ops, UCC_THREAD_MULTIPLE,
"stub_tasks");
if (status != UCC_OK) {
goto error;
}

*lib_p = lib;
return UCC_OK;
error:
Expand Down Expand Up @@ -473,6 +481,8 @@ ucc_status_t ucc_finalize(ucc_lib_info_t *lib)
gl_status = UCC_OK;
ucc_assert(lib->n_cl_libs_opened > 0);
ucc_assert(lib->cl_libs != NULL);

ucc_mpool_cleanup(&lib->stub_tasks_mp, 1);
for (i = 0; i < lib->n_tl_libs_opened; i++) {
lib->tl_libs[i]->iface->lib.finalize(&lib->tl_libs[i]->super);
}
Expand Down
18 changes: 10 additions & 8 deletions src/core/ucc_lib.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "ucc/api/ucc.h"
#include "components/cl/ucc_cl_type.h"
#include "utils/ucc_parser.h"
#include "utils/ucc_mpool.h"

typedef struct ucc_cl_lib ucc_cl_lib_t;
typedef struct ucc_tl_lib ucc_tl_lib_t;
Expand All @@ -26,14 +27,15 @@ typedef struct ucc_lib_config {
} ucc_lib_config_t;

/* Per-library state object.
 *
 * Each member is declared exactly once: the previous text of this block
 * carried two copies of every pre-existing member (differing only in
 * alignment), which is a redeclaration error in C. */
typedef struct ucc_lib_info {
    char              *full_prefix;
    int                n_cl_libs_opened;
    /* ucc_finalize iterates tl_libs[0 .. n_tl_libs_opened - 1] */
    int                n_tl_libs_opened;
    ucc_cl_lib_t     **cl_libs;
    ucc_tl_lib_t     **tl_libs;
    ucc_lib_attr_t     attr;
    int                specific_cls_requested;
    ucc_cl_lib_attr_t *cl_attrs;
    /* pool of ucc_coll_task_t objects; ucc_collective_init takes its stub
     * ("dummy") task for a zero-size collective from here */
    ucc_mpool_t        stub_tasks_mp;
} ucc_lib_info_t;

int ucc_tl_is_required(ucc_lib_info_t *lib, ucc_tl_iface_t *tl_iface,
Expand Down
22 changes: 22 additions & 0 deletions src/schedule/ucc_schedule.c
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,25 @@ void ucc_coll_task_destruct(ucc_coll_task_t *task)
}
}

/* Default post callback (ucc_coll_task_init assigns it to task->post).
 * Sets the task status to UCC_OK and passes the task to ucc_task_complete,
 * returning whatever that call reports. */
ucc_status_t ucc_dummy_post(ucc_coll_task_t *task)
{
    ucc_status_t completion_status;

    task->status      = UCC_OK;
    completion_status = ucc_task_complete(task);

    return completion_status;
}

/* Default finalize callback (ucc_coll_task_init assigns it to
 * task->finalize). Hands the task object to ucc_mpool_put and always
 * returns UCC_OK. */
ucc_status_t ucc_dummy_finalize(ucc_coll_task_t *task)
{
    const ucc_status_t result = UCC_OK;

    ucc_mpool_put(task);

    return result;
}

/* Default progress callback (ucc_coll_task_init assigns it to
 * task->progress). It is a placeholder only: reaching it trips an
 * unconditional assertion. */
/* NOLINTNEXTLINE task argument is not used */
void ucc_dummy_progress(ucc_coll_task_t *task)
{
    /* not expected to be invoked for any task; fail loudly if it is */
    ucc_assert_always(0);
}

ucc_status_t ucc_coll_task_init(ucc_coll_task_t *task,
ucc_base_coll_args_t *bargs,
ucc_base_team_t *team)
Expand All @@ -97,6 +116,9 @@ ucc_status_t ucc_coll_task_init(ucc_coll_task_t *task,
task->super.status = UCC_OPERATION_INITIALIZED;
task->triggered_post_setup = NULL;
task->triggered_post = ucc_triggered_post;
task->post = ucc_dummy_post;
task->finalize = ucc_dummy_finalize;
task->progress = ucc_dummy_progress;
if (bargs) {
memcpy(&task->bargs, bargs, sizeof(*bargs));
}
Expand Down
7 changes: 6 additions & 1 deletion src/utils/ucc_coll_utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,12 @@ void ucc_coll_str(const ucc_coll_task_t *task, char *str, size_t len,
size_t tl_info_len = 0;
char task_info[64], cl_info[16], tl_info[32];

if (task->team->context->lib->log_component.name[0] == 'C') {
if (!task->team) {
/* zero size collective, no CL or TL */
strncpy(cl_info, "NoOp", sizeof(cl_info));
strncpy(tl_info, "NoOp", sizeof(tl_info));
}
else if (task->team->context->lib->log_component.name[0] == 'C') {
/* it's not CL BASIC task */
ucc_strncpy_safe(cl_info,
task->team->context->lib->log_component.name,
Expand Down
8 changes: 4 additions & 4 deletions tools/perf/ucc_pt_benchmark.cc
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
/**
* Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down Expand Up @@ -115,6 +111,10 @@ ucc_status_t ucc_pt_benchmark::run_bench() noexcept
}
print_time(cnt, args, time);
coll->free_args(args);
if (max_count == 0) {
/* exit from loop when min_count == max_count == 0 */
break;
}
}

return UCC_OK;
Expand Down

0 comments on commit dc1049b

Please sign in to comment.