Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CORE: skip zero size collectives #787

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/components/tl/nccl/tl_nccl_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ ucc_status_t ucc_tl_nccl_init_task(ucc_base_coll_args_t *coll_args,
ucc_tl_nccl_context_t);
ucc_tl_nccl_task_t *task;
ucc_status_t status;
ucc_coll_progress_fn_t progress_fn;

if (!ucc_coll_args_is_predefined_dt(&coll_args->args, team->params.rank)) {
tl_error(team->context->lib,
Expand All @@ -147,11 +148,13 @@ ucc_status_t ucc_tl_nccl_init_task(ucc_base_coll_args_t *coll_args,
tl_error(team->context->lib, "failed to get task from mpool");
return UCC_ERR_NO_MEMORY;
}
progress_fn = task->super.progress;

ucc_coll_task_init(&task->super, coll_args, team);
UCC_TL_NCCL_PROFILE_REQUEST_NEW(task, "tl_nccl_task", 0);
task->super.finalize = ucc_tl_nccl_coll_finalize;
task->super.triggered_post = ucc_tl_nccl_triggered_post;
task->super.progress = progress_fn;
task->completed = NULL;
if (nccl_ctx->cfg.sync_type == UCC_TL_NCCL_COMPLETION_SYNC_TYPE_EVENT) {
status = ucc_ec_create_event(&task->completed, UCC_EE_CUDA_STREAM);
Expand Down
78 changes: 57 additions & 21 deletions src/core/ucc_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,15 @@ static ucc_status_t ucc_coll_args_check_mem_type(ucc_coll_args_t *coll_args,
};
}

/* Bitmask of collective types that ucc_collective_init() tests against
 * coll_args->coll_type; when a matching collective has a message size of
 * zero, a stub task is returned instead of going through the normal
 * initialization path. */
#define UCC_COLL_TYPE_SKIP_ZERO_SIZE \
Sergei-Lebedev marked this conversation as resolved.
Show resolved Hide resolved
(UCC_COLL_TYPE_ALLREDUCE | \
UCC_COLL_TYPE_ALLGATHER | \
UCC_COLL_TYPE_ALLTOALL | \
UCC_COLL_TYPE_BCAST | \
UCC_COLL_TYPE_GATHER | \
UCC_COLL_TYPE_REDUCE | \
UCC_COLL_TYPE_SCATTER)

UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_init,
(coll_args, request, team), ucc_coll_args_t *coll_args,
ucc_coll_req_h *request, ucc_team_h team)
Expand All @@ -167,13 +176,34 @@ UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_init,
ucc_ee_executor_params_t params;
ucc_memory_type_t coll_mem_type;
ucc_ee_type_t coll_ee_type;
size_t coll_size;

if (ucc_unlikely(team->state != UCC_TEAM_ACTIVE)) {
ucc_error("team %p is used before team create is completed", team);
return UCC_ERR_INVALID_PARAM;
}
/* Global check to reduce the amount of checks throughout
all TLs */

if (UCC_COLL_TYPE_SKIP_ZERO_SIZE & coll_args->coll_type) {
coll_size = ucc_coll_args_msgsize(coll_args, team->rank, team->size);
if (coll_size == 0) {
task = ucc_mpool_get(&team->contexts[0]->lib->stub_tasks_mp);
if (ucc_unlikely(!task)) {
ucc_error("failed to allocate dummy task");
return UCC_ERR_NO_MEMORY;
}
op_args.mask = 0;
memcpy(&op_args.args, coll_args, sizeof(ucc_coll_args_t));
op_args.team = team;
op_args.args.flags = 0;
UCC_COPY_PARAM_BY_FIELD(&op_args.args, coll_args,
UCC_COLL_ARGS_FIELD_FLAGS, flags);
ucc_coll_task_init(task, &op_args, NULL);
goto print_trace;
}
}

if (UCC_COLL_ARGS_ACTIVE_SET(coll_args) &&
((UCC_COLL_TYPE_BCAST != coll_args->coll_type) ||
coll_args->active_set.size != 2)) {
Expand Down Expand Up @@ -246,6 +276,10 @@ UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_init,
}
task->seq_num = team->seq_num++;

ucc_assert(task->super.status == UCC_OPERATION_INITIALIZED);

print_trace:
*request = &task->super;
if (ucc_global_config.coll_trace.log_level >= UCC_LOG_LEVEL_DIAG) {
char coll_str[256];
ucc_coll_str(task, coll_str, sizeof(coll_str),
Expand All @@ -263,9 +297,6 @@ UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_init,
}
}

ucc_assert(task->super.status == UCC_OPERATION_INITIALIZED);
*request = &task->super;

return UCC_OK;

coll_finalize:
Expand Down Expand Up @@ -297,18 +328,21 @@ UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_post, (request),
ucc_status_t status;

if (ucc_global_config.coll_trace.log_level >= UCC_LOG_LEVEL_DEBUG) {
ucc_rank_t rank = task->team->params.team->rank;
if (ucc_global_config.coll_trace.log_level == UCC_LOG_LEVEL_DEBUG) {
if (rank == 0) {
/* team is NULL if task is a dummy task, e.g. collective of zero size */
if (task->team) {
Sergei-Lebedev marked this conversation as resolved.
Show resolved Hide resolved
ucc_rank_t rank = task->team->params.team->rank;
if (ucc_global_config.coll_trace.log_level == UCC_LOG_LEVEL_DEBUG) {
Sergei-Lebedev marked this conversation as resolved.
Show resolved Hide resolved
if (rank == 0) {
ucc_log_component_collective_trace(
ucc_global_config.coll_trace.log_level,
"coll post: req %p, seq_num %u", task, task->seq_num);
}
} else {
ucc_log_component_collective_trace(
ucc_global_config.coll_trace.log_level,
"coll post: req %p, seq_num %u", task, task->seq_num);
"coll post: rank %d req %p, seq_num %u", rank, task,
task->seq_num);
}
} else {
ucc_log_component_collective_trace(
ucc_global_config.coll_trace.log_level,
"coll post: rank %d req %p, seq_num %u", rank, task,
task->seq_num);
}
}

Expand Down Expand Up @@ -388,18 +422,20 @@ UCC_CORE_PROFILE_FUNC(ucc_status_t, ucc_collective_finalize, (request),
ucc_coll_task_t *task = ucc_derived_of(request, ucc_coll_task_t);

if (ucc_global_config.coll_trace.log_level >= UCC_LOG_LEVEL_DEBUG) {
ucc_rank_t rank = task->team->params.team->rank;
if (ucc_global_config.coll_trace.log_level == UCC_LOG_LEVEL_DEBUG) {
if (rank == 0) {
if (task->team) {
ucc_rank_t rank = task->team->params.team->rank;
if (ucc_global_config.coll_trace.log_level == UCC_LOG_LEVEL_DEBUG) {
Sergei-Lebedev marked this conversation as resolved.
Show resolved Hide resolved
if (rank == 0) {
ucc_log_component_collective_trace(
ucc_global_config.coll_trace.log_level,
"coll finalize: req %p, seq_num %u", task, task->seq_num);
}
} else {
ucc_log_component_collective_trace(
ucc_global_config.coll_trace.log_level,
"coll finalize: req %p, seq_num %u", task, task->seq_num);
"coll finalize: rank %d req %p, seq_num %u", rank, task,
task->seq_num);
}
} else {
ucc_log_component_collective_trace(
ucc_global_config.coll_trace.log_level,
"coll finalize: rank %d req %p, seq_num %u", rank, task,
task->seq_num);
}
}
return ucc_collective_finalize_internal(task);
Expand Down
10 changes: 10 additions & 0 deletions src/core/ucc_lib.c
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,14 @@ ucc_status_t ucc_init_version(unsigned api_major_version,
goto error;
}

status = ucc_mpool_init(&lib->stub_tasks_mp, 0, sizeof(ucc_coll_task_t), 0,
UCC_CACHE_LINE_SIZE, 8, UINT_MAX,
Sergei-Lebedev marked this conversation as resolved.
Show resolved Hide resolved
&ucc_coll_task_mpool_ops, UCC_THREAD_MULTIPLE,
"stub_tasks");
if (status != UCC_OK) {
goto error;
}

*lib_p = lib;
return UCC_OK;
error:
Expand Down Expand Up @@ -473,6 +481,8 @@ ucc_status_t ucc_finalize(ucc_lib_info_t *lib)
gl_status = UCC_OK;
ucc_assert(lib->n_cl_libs_opened > 0);
ucc_assert(lib->cl_libs != NULL);

ucc_mpool_cleanup(&lib->stub_tasks_mp, 1);
for (i = 0; i < lib->n_tl_libs_opened; i++) {
lib->tl_libs[i]->iface->lib.finalize(&lib->tl_libs[i]->super);
}
Expand Down
18 changes: 10 additions & 8 deletions src/core/ucc_lib.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "ucc/api/ucc.h"
#include "components/cl/ucc_cl_type.h"
#include "utils/ucc_parser.h"
#include "utils/ucc_mpool.h"

typedef struct ucc_cl_lib ucc_cl_lib_t;
typedef struct ucc_tl_lib ucc_tl_lib_t;
Expand All @@ -26,14 +27,15 @@ typedef struct ucc_lib_config {
} ucc_lib_config_t;

typedef struct ucc_lib_info {
char *full_prefix;
int n_cl_libs_opened;
int n_tl_libs_opened;
ucc_cl_lib_t **cl_libs;
ucc_tl_lib_t **tl_libs;
ucc_lib_attr_t attr;
int specific_cls_requested;
ucc_cl_lib_attr_t *cl_attrs;
char *full_prefix;
int n_cl_libs_opened;
int n_tl_libs_opened;
ucc_cl_lib_t **cl_libs;
ucc_tl_lib_t **tl_libs;
ucc_lib_attr_t attr;
int specific_cls_requested;
ucc_cl_lib_attr_t *cl_attrs;
ucc_mpool_t stub_tasks_mp;
} ucc_lib_info_t;

int ucc_tl_is_required(ucc_lib_info_t *lib, ucc_tl_iface_t *tl_iface,
Expand Down
22 changes: 22 additions & 0 deletions src/schedule/ucc_schedule.c
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,25 @@ void ucc_coll_task_destruct(ucc_coll_task_t *task)
}
}

/**
 * Default "post" callback installed by ucc_coll_task_init().
 *
 * No work is started: the task status is set to UCC_OK and the task is
 * handed straight to ucc_task_complete().
 *
 * @param task  Task being posted.
 *
 * @return Whatever ucc_task_complete() returns for this task.
 */
ucc_status_t ucc_dummy_post(ucc_coll_task_t *task)
{
    /* Nothing to execute, so the operation is already successful. */
    task->status = UCC_OK;

    return ucc_task_complete(task);
}

/**
 * Default "finalize" callback installed by ucc_coll_task_init().
 *
 * Releases the task object through ucc_mpool_put().
 * NOTE(review): presumably this returns it to the pool it was taken from
 * (e.g. the library's stub task pool) - confirm against ucc_mpool.
 *
 * @param task  Task to release; it is passed to ucc_mpool_put() here.
 *
 * @return Always UCC_OK.
 */
ucc_status_t ucc_dummy_finalize(ucc_coll_task_t *task)
{
    ucc_mpool_put(task);

    return UCC_OK;
}

/**
 * Default "progress" callback installed by ucc_coll_task_init().
 *
 * ucc_dummy_post() completes the task right away, so this callback is not
 * expected to run; reaching it trips an unconditional assertion.
 *
 * @param task  Unused.
 */
/* NOLINTNEXTLINE task argument is not used */
void ucc_dummy_progress(ucc_coll_task_t *task)
{
    /* Must never be reached: fail loudly if it is. */
    ucc_assert_always(0);
}

ucc_status_t ucc_coll_task_init(ucc_coll_task_t *task,
ucc_base_coll_args_t *bargs,
ucc_base_team_t *team)
Expand All @@ -97,6 +116,9 @@ ucc_status_t ucc_coll_task_init(ucc_coll_task_t *task,
task->super.status = UCC_OPERATION_INITIALIZED;
task->triggered_post_setup = NULL;
task->triggered_post = ucc_triggered_post;
task->post = ucc_dummy_post;
task->finalize = ucc_dummy_finalize;
task->progress = ucc_dummy_progress;
if (bargs) {
memcpy(&task->bargs, bargs, sizeof(*bargs));
}
Expand Down
7 changes: 6 additions & 1 deletion src/utils/ucc_coll_utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,12 @@ void ucc_coll_str(const ucc_coll_task_t *task, char *str, size_t len,
size_t tl_info_len = 0;
char task_info[64], cl_info[16], tl_info[32];

if (task->team->context->lib->log_component.name[0] == 'C') {
if (!task->team) {
/* zero size collective, no CL or TL */
strncpy(cl_info, "NoOp", sizeof(cl_info));
strncpy(tl_info, "NoOp", sizeof(tl_info));
}
else if (task->team->context->lib->log_component.name[0] == 'C') {
/* it's not CL BASIC task */
ucc_strncpy_safe(cl_info,
task->team->context->lib->log_component.name,
Expand Down
8 changes: 4 additions & 4 deletions tools/perf/ucc_pt_benchmark.cc
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
/**
<<<<<<< HEAD
* Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
=======
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
>>>>>>> REVIEW: fix review comments
*
* See file LICENSE for terms.
*/
Expand Down Expand Up @@ -115,6 +111,10 @@ ucc_status_t ucc_pt_benchmark::run_bench() noexcept
}
print_time(cnt, args, time);
coll->free_args(args);
if (max_count == 0) {
/* exit from loop when min_count == max_count == 0 */
break;
}
}

return UCC_OK;
Expand Down
Loading