From fb479170f417ed28c931fe881b2e4d32ae0eee87 Mon Sep 17 00:00:00 2001 From: valentin petrov Date: Wed, 1 Feb 2023 14:41:48 +0300 Subject: [PATCH] CL/HIER: bcast 2step algorithm (#620) * CL/HIER: bcast 2step algorithm * TEST: gtest cl_hier_rab allreduce gtest * TEST: bcast cl_hier_2step gtest * CI: add enable-assert to default clang-tidy * CL/HIER: fix oob mem leak * REVIEW: address comments * TEST: bcast gtest "reset" fix * CL/HIER: bcast 2 step persistent fallback --------- Co-authored-by: Valentin Petrov --- .github/workflows/clang-tidy.yaml | 2 +- src/components/cl/hier/Makefile.am | 8 +- src/components/cl/hier/bcast/bcast.c | 17 ++ src/components/cl/hier/bcast/bcast.h | 38 +++ src/components/cl/hier/bcast/bcast_2step.c | 263 +++++++++++++++++++++ src/components/cl/hier/cl_hier.c | 7 + src/components/cl/hier/cl_hier.h | 2 +- src/components/cl/hier/cl_hier_coll.c | 17 +- src/components/cl/hier/cl_hier_coll.h | 3 +- src/components/cl/hier/cl_hier_team.c | 14 +- src/core/ucc_team.h | 5 + test/gtest/coll/test_allreduce.cc | 34 +++ test/gtest/coll/test_bcast.cc | 51 +++- 13 files changed, 447 insertions(+), 14 deletions(-) create mode 100644 src/components/cl/hier/bcast/bcast.c create mode 100644 src/components/cl/hier/bcast/bcast.h create mode 100644 src/components/cl/hier/bcast/bcast_2step.c diff --git a/.github/workflows/clang-tidy.yaml b/.github/workflows/clang-tidy.yaml index cde225b4b0..f10720bee3 100644 --- a/.github/workflows/clang-tidy.yaml +++ b/.github/workflows/clang-tidy.yaml @@ -28,7 +28,7 @@ jobs: - name: Build UCC run: | ./autogen.sh - CC=clang-${CLANG_VER} CXX=clang++-${CLANG_VER} ./configure --prefix=/tmp/ucc/install --with-ucx=/tmp/ucx/install + CC=clang-${CLANG_VER} CXX=clang++-${CLANG_VER} ./configure --prefix=/tmp/ucc/install --with-ucx=/tmp/ucx/install --enable-assert bear --cdb /tmp/compile_commands.json make - name: Run clang-tidy run: | diff --git a/src/components/cl/hier/Makefile.am b/src/components/cl/hier/Makefile.am index 
0280074e09..c99d72e96f 100644 --- a/src/components/cl/hier/Makefile.am +++ b/src/components/cl/hier/Makefile.am @@ -20,6 +20,11 @@ barrier = \ barrier/barrier.h \ barrier/barrier.c +bcast = \ + bcast/bcast.h \ + bcast/bcast.c \ + bcast/bcast_2step.c + sources = \ cl_hier.h \ cl_hier.c \ @@ -31,7 +36,8 @@ sources = \ $(allreduce) \ $(alltoallv) \ $(alltoall) \ - $(barrier) + $(barrier) \ + $(bcast) module_LTLIBRARIES = libucc_cl_hier.la libucc_cl_hier_la_SOURCES = $(sources) diff --git a/src/components/cl/hier/bcast/bcast.c b/src/components/cl/hier/bcast/bcast.c new file mode 100644 index 0000000000..beff550a3e --- /dev/null +++ b/src/components/cl/hier/bcast/bcast.c @@ -0,0 +1,17 @@ +/** + * Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. + */ + +#include "bcast.h" +#include "../bcast/bcast.h" + +ucc_base_coll_alg_info_t + ucc_cl_hier_bcast_algs[UCC_CL_HIER_BCAST_ALG_LAST + 1] = { + [UCC_CL_HIER_BCAST_ALG_2STEP] = + {.id = UCC_CL_HIER_BCAST_ALG_2STEP, + .name = "2step", + .desc = "intra-node and inter-node bcasts executed in parallel"}, + [UCC_CL_HIER_BCAST_ALG_LAST] = { + .id = 0, .name = NULL, .desc = NULL}}; diff --git a/src/components/cl/hier/bcast/bcast.h b/src/components/cl/hier/bcast/bcast.h new file mode 100644 index 0000000000..eb4193dd84 --- /dev/null +++ b/src/components/cl/hier/bcast/bcast.h @@ -0,0 +1,38 @@ +/** + * Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. 
+ */ + +#ifndef BCAST_H_ +#define BCAST_H_ +#include "../cl_hier.h" + +enum +{ + UCC_CL_HIER_BCAST_ALG_2STEP, + UCC_CL_HIER_BCAST_ALG_LAST, +}; + +extern ucc_base_coll_alg_info_t + ucc_cl_hier_bcast_algs[UCC_CL_HIER_BCAST_ALG_LAST + 1]; + +#define UCC_CL_HIER_BCAST_DEFAULT_ALG_SELECT_STR "bcast:0-4k:@2step" + +ucc_status_t ucc_cl_hier_bcast_2step_init(ucc_base_coll_args_t *coll_args, + ucc_base_team_t *team, + ucc_coll_task_t **task); + +static inline int ucc_cl_hier_bcast_alg_from_str(const char *str) +{ + int i; + + for (i = 0; i < UCC_CL_HIER_BCAST_ALG_LAST; i++) { + if (0 == strcasecmp(str, ucc_cl_hier_bcast_algs[i].name)) { + break; + } + } + return i; +} + +#endif diff --git a/src/components/cl/hier/bcast/bcast_2step.c b/src/components/cl/hier/bcast/bcast_2step.c new file mode 100644 index 0000000000..acf4c6c815 --- /dev/null +++ b/src/components/cl/hier/bcast/bcast_2step.c @@ -0,0 +1,263 @@ +/** + * Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See file LICENSE for terms. 
+ */ + +#include "bcast.h" +#include "core/ucc_team.h" +#include "../cl_hier_coll.h" + +static ucc_status_t ucc_cl_hier_bcast_2step_start(ucc_coll_task_t *task) +{ + UCC_CL_HIER_PROFILE_REQUEST_EVENT(task, "cl_hier_bcast_2step_start", 0); + return ucc_schedule_start(task); +} + +static ucc_status_t ucc_cl_hier_bcast_2step_finalize(ucc_coll_task_t *task) +{ + ucc_schedule_t *schedule = ucc_derived_of(task, ucc_schedule_t); + ucc_status_t status; + + UCC_CL_HIER_PROFILE_REQUEST_EVENT(task, "cl_hier_bcast_2step_finalize", + 0); + status = ucc_schedule_finalize(task); + ucc_cl_hier_put_schedule(schedule); + return status; +} + +static ucc_status_t +ucc_cl_hier_bcast_2step_schedule_finalize(ucc_coll_task_t *task) +{ + ucc_cl_hier_schedule_t *schedule = + ucc_derived_of(task, ucc_cl_hier_schedule_t); + ucc_status_t status; + + status = ucc_schedule_pipelined_finalize(&schedule->super.super.super); + ucc_cl_hier_put_schedule(&schedule->super.super); + return status; +} + +static ucc_status_t +ucc_cl_hier_bcast_2step_frag_setup(ucc_schedule_pipelined_t *schedule_p, + ucc_schedule_t *frag, int frag_num) +{ + ucc_coll_args_t *args = &schedule_p->super.super.bargs.args; + size_t dt_size = ucc_dt_size(args->src.info.datatype); + int n_frags = schedule_p->super.n_tasks; + size_t frag_count, frag_offset; + ucc_coll_task_t *task; + int i; + + frag_count = + ucc_buffer_block_count(args->src.info.count, n_frags, frag_num); + frag_offset = + ucc_buffer_block_offset(args->src.info.count, n_frags, frag_num); + + for (i = 0; i < frag->n_tasks; i++) { + task = frag->tasks[i]; + task->bargs.args.src.info.count = frag_count; + task->bargs.args.src.info.buffer = + PTR_OFFSET(args->src.info.buffer, frag_offset * dt_size); + } + return UCC_OK; +} + +static inline ucc_rank_t +find_root_net_rank(ucc_host_id_t root_host_id, ucc_cl_hier_team_t *cl_team) +{ + ucc_sbgp_t *sbgp = cl_team->sbgps[UCC_HIER_SBGP_NODE_LEADERS].sbgp; + ucc_team_t *core_team = cl_team->super.super.params.team; + 
ucc_rank_t i, rank; + + for (i = 0; i < sbgp->group_size; i++) { + rank = ucc_ep_map_eval(sbgp->map, i); + if (ucc_team_rank_host_id(rank, core_team) == root_host_id) { + return i; + } + } + return UCC_RANK_INVALID; +} + +static inline ucc_rank_t +find_root_node_rank(ucc_rank_t root, ucc_cl_hier_team_t *cl_team) +{ + ucc_sbgp_t *sbgp = cl_team->sbgps[UCC_HIER_SBGP_NODE].sbgp; + ucc_rank_t i; + + for (i = 0; i < sbgp->group_size; i++) { + if (ucc_ep_map_eval(sbgp->map, i) == root) { + return i; + } + } + return UCC_RANK_INVALID; +} + +static ucc_status_t +ucc_cl_hier_bcast_2step_init_schedule(ucc_base_coll_args_t *coll_args, + ucc_base_team_t *team, + ucc_schedule_t **sched_p, int n_frags) +{ + ucc_cl_hier_team_t *cl_team = ucc_derived_of(team, ucc_cl_hier_team_t); + ucc_team_t *core_team = team->params.team; + ucc_coll_task_t *tasks[2] = {NULL, NULL}; + ucc_rank_t root = coll_args->args.root; + ucc_rank_t rank = UCC_TL_TEAM_RANK(cl_team); + int root_on_local_node = ucc_team_ranks_on_same_node(root, rank, + core_team); + ucc_base_coll_args_t args = *coll_args; + int n_tasks = 0; + int first_task = 0; + ucc_schedule_t *schedule; + ucc_status_t status; + int i; + + schedule = &ucc_cl_hier_get_schedule(cl_team)->super.super; + if (ucc_unlikely(!schedule)) { + return UCC_ERR_NO_MEMORY; + } + status = ucc_schedule_init(schedule, &args, team); + if (ucc_unlikely(UCC_OK != status)) { + goto out; + } + + if (n_frags > 1) { + args.max_frag_count = + ucc_buffer_block_count(args.args.src.info.count, n_frags, 0); + args.mask |= UCC_BASE_CARGS_MAX_FRAG_COUNT; + } + + ucc_assert(SBGP_ENABLED(cl_team, NODE_LEADERS) || + SBGP_ENABLED(cl_team, NODE)); + if (SBGP_ENABLED(cl_team, NODE_LEADERS)) { + args.args.root = find_root_net_rank( + ucc_team_rank_host_id(root, core_team), cl_team); + status = ucc_coll_init(SCORE_MAP(cl_team, NODE_LEADERS), &args, + &tasks[n_tasks]); + if (ucc_unlikely(UCC_OK != status)) { + goto out; + } + n_tasks++; + if (root_on_local_node && (root != rank)) { 
+ first_task = 1; + } + } + + if (SBGP_ENABLED(cl_team, NODE)) { + args.args.root = root_on_local_node + ? find_root_node_rank(root, cl_team) + : core_team->topo->node_leader_rank_id; + status = + ucc_coll_init(SCORE_MAP(cl_team, NODE), &args, &tasks[n_tasks]); + if (ucc_unlikely(UCC_OK != status)) { + goto out; + } + n_tasks++; + } + + ucc_task_subscribe_dep(&schedule->super, tasks[first_task], + UCC_EVENT_SCHEDULE_STARTED); + ucc_schedule_add_task(schedule, tasks[first_task]); + if (n_tasks > 1) { + if (root == rank) { + ucc_task_subscribe_dep(&schedule->super, tasks[(first_task + 1) % 2], + UCC_EVENT_SCHEDULE_STARTED); + } else { + ucc_task_subscribe_dep(tasks[first_task], + tasks[(first_task + 1) % 2], UCC_EVENT_COMPLETED); + } + ucc_schedule_add_task(schedule, tasks[(first_task + 1) % 2]); + } + + schedule->super.post = ucc_cl_hier_bcast_2step_start; + schedule->super.progress = NULL; + schedule->super.finalize = ucc_cl_hier_bcast_2step_finalize; + schedule->super.triggered_post = ucc_triggered_post; + *sched_p = schedule; + return UCC_OK; + +out: + for (i = 0; i < n_tasks; i++) { + tasks[i]->finalize(tasks[i]); + } + ucc_cl_hier_put_schedule(schedule); + return status; +} + +static ucc_status_t ucc_cl_hier_bcast_2step_frag_init( + ucc_base_coll_args_t *coll_args, ucc_schedule_pipelined_t *sp, + ucc_base_team_t *team, ucc_schedule_t **frag_p) +{ + int n_frags = sp->super.n_tasks; + + return ucc_cl_hier_bcast_2step_init_schedule(coll_args, team, frag_p, + n_frags); +} + +static ucc_status_t ucc_cl_hier_2step_bcast_start(ucc_coll_task_t *task) +{ + ucc_schedule_pipelined_t *schedule = + ucc_derived_of(task, ucc_schedule_pipelined_t); + + cl_debug(task->team->context->lib, + "posting 2step bcast, buf %p, count %zd, dt %s" + " pdepth %d, frags_total %d", + task->bargs.args.src.info.buffer, + task->bargs.args.src.info.count, + ucc_datatype_str(task->bargs.args.src.info.datatype), + schedule->n_frags, schedule->super.n_tasks); + + return 
ucc_schedule_pipelined_post(task); +} + +UCC_CL_HIER_PROFILE_FUNC(ucc_status_t, ucc_cl_hier_bcast_2step_init, + (coll_args, team, task), + ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, + ucc_coll_task_t **task) +{ + ucc_cl_hier_team_t *cl_team = ucc_derived_of(team, ucc_cl_hier_team_t); + ucc_cl_hier_lib_config_t *cfg = &UCC_CL_HIER_TEAM_LIB(cl_team)->cfg; + ucc_cl_hier_schedule_t * schedule; + int n_frags, pipeline_depth; + ucc_status_t status; + + if (UCC_IS_PERSISTENT(coll_args->args)) { + return UCC_ERR_NOT_SUPPORTED; + } + ucc_pipeline_nfrags_pdepth(&cfg->bcast_2step_pipeline, + coll_args->args.src.info.count * + ucc_dt_size(coll_args->args.src.info.datatype), + &n_frags, &pipeline_depth); + + if (n_frags == 1) { + return ucc_cl_hier_bcast_2step_init_schedule( + coll_args, team, (ucc_schedule_t **)task, n_frags); + } + + schedule = ucc_cl_hier_get_schedule(cl_team); + if (ucc_unlikely(!schedule)) { + return UCC_ERR_NO_MEMORY; + } + + status = ucc_schedule_pipelined_init( + coll_args, team, ucc_cl_hier_bcast_2step_frag_init, + ucc_cl_hier_bcast_2step_frag_setup, pipeline_depth, n_frags, + cfg->bcast_2step_pipeline.order, &schedule->super); + + if (ucc_unlikely(status != UCC_OK)) { + cl_error(team->context->lib, + "failed to init pipelined 2step bcast schedule"); + goto err_pipe_init; + } + + schedule->super.super.super.post = ucc_cl_hier_2step_bcast_start; + schedule->super.super.super.triggered_post = ucc_triggered_post; + schedule->super.super.super.finalize = + ucc_cl_hier_bcast_2step_schedule_finalize; + *task = &schedule->super.super.super; + return UCC_OK; + +err_pipe_init: + ucc_cl_hier_put_schedule(&schedule->super.super); + return status; +} diff --git a/src/components/cl/hier/cl_hier.c b/src/components/cl/hier/cl_hier.c index 0ea1381b03..676f870441 100644 --- a/src/components/cl/hier/cl_hier.c +++ b/src/components/cl/hier/cl_hier.c @@ -63,6 +63,11 @@ static ucc_config_field_t ucc_cl_hier_lib_config_table[] = { 
ucc_offsetof(ucc_cl_hier_lib_config_t, allreduce_rab_pipeline), UCC_CONFIG_TYPE_PIPELINE_PARAMS}, + {"BCAST_2STEP_PIPELINE", "n", + "Pipelining settings for 2step bcast algorithm", + ucc_offsetof(ucc_cl_hier_lib_config_t, bcast_2step_pipeline), + UCC_CONFIG_TYPE_PIPELINE_PARAMS}, + {NULL}}; static ucs_config_field_t ucc_cl_hier_context_config_table[] = { @@ -102,4 +107,6 @@ __attribute__((constructor)) static void cl_hier_iface_init(void) ucc_cl_hier_alltoall_algs; ucc_cl_hier.super.alg_info[ucc_ilog2(UCC_COLL_TYPE_ALLTOALLV)] = ucc_cl_hier_alltoallv_algs; + ucc_cl_hier.super.alg_info[ucc_ilog2(UCC_COLL_TYPE_BCAST)] = + ucc_cl_hier_bcast_algs; } diff --git a/src/components/cl/hier/cl_hier.h b/src/components/cl/hier/cl_hier.h index cae41cea32..8f538c1d7b 100644 --- a/src/components/cl/hier/cl_hier.h +++ b/src/components/cl/hier/cl_hier.h @@ -52,7 +52,7 @@ typedef struct ucc_cl_hier_lib_config { size_t a2av_node_thresh; ucc_pipeline_params_t allreduce_split_rail_pipeline; ucc_pipeline_params_t allreduce_rab_pipeline; - + ucc_pipeline_params_t bcast_2step_pipeline; } ucc_cl_hier_lib_config_t; typedef struct ucc_cl_hier_context_config { diff --git a/src/components/cl/hier/cl_hier_coll.c b/src/components/cl/hier/cl_hier_coll.c index e32c3687ce..acdb243ddd 100644 --- a/src/components/cl/hier/cl_hier_coll.c +++ b/src/components/cl/hier/cl_hier_coll.c @@ -12,7 +12,8 @@ const char * ucc_cl_hier_default_alg_select_str[UCC_CL_HIER_N_DEFAULT_ALG_SELECT_STR] = { - UCC_CL_HIER_ALLREDUCE_DEFAULT_ALG_SELECT_STR}; + UCC_CL_HIER_ALLREDUCE_DEFAULT_ALG_SELECT_STR, + UCC_CL_HIER_BCAST_DEFAULT_ALG_SELECT_STR}; ucc_status_t ucc_cl_hier_coll_init(ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, @@ -27,6 +28,8 @@ ucc_status_t ucc_cl_hier_coll_init(ucc_base_coll_args_t *coll_args, return ucc_cl_hier_alltoall_init(coll_args, team, task); case UCC_COLL_TYPE_ALLTOALLV: return ucc_cl_hier_alltoallv_init(coll_args, team, task); + case UCC_COLL_TYPE_BCAST: + return 
ucc_cl_hier_bcast_2step_init(coll_args, team, task); default: cl_error(team->context->lib, "coll_type %s is not supported", ucc_coll_type_str(coll_args->args.coll_type)); @@ -44,6 +47,8 @@ static inline int alg_id_from_str(ucc_coll_type_t coll_type, const char *str) return ucc_cl_hier_alltoall_alg_from_str(str); case UCC_COLL_TYPE_ALLREDUCE: return ucc_cl_hier_allreduce_alg_from_str(str); + case UCC_COLL_TYPE_BCAST: + return ucc_cl_hier_bcast_alg_from_str(str); default: break; } @@ -94,6 +99,16 @@ ucc_status_t ucc_cl_hier_alg_id_to_init(int alg_id, const char *alg_id_str, break; }; break; + case UCC_COLL_TYPE_BCAST: + switch (alg_id) { + case UCC_CL_HIER_BCAST_ALG_2STEP: + *init = ucc_cl_hier_bcast_2step_init; + break; + default: + status = UCC_ERR_INVALID_PARAM; + break; + }; + break; default: status = UCC_ERR_NOT_SUPPORTED; break; diff --git a/src/components/cl/hier/cl_hier_coll.h b/src/components/cl/hier/cl_hier_coll.h index 7f5b1d4d0d..3258796675 100644 --- a/src/components/cl/hier/cl_hier_coll.h +++ b/src/components/cl/hier/cl_hier_coll.h @@ -13,8 +13,9 @@ #include "alltoallv/alltoallv.h" #include "alltoall/alltoall.h" #include "barrier/barrier.h" +#include "bcast/bcast.h" -#define UCC_CL_HIER_N_DEFAULT_ALG_SELECT_STR 1 +#define UCC_CL_HIER_N_DEFAULT_ALG_SELECT_STR 2 extern const char *ucc_cl_hier_default_alg_select_str[UCC_CL_HIER_N_DEFAULT_ALG_SELECT_STR]; diff --git a/src/components/cl/hier/cl_hier_team.c b/src/components/cl/hier/cl_hier_team.c index fad6707309..3e3b478b5d 100644 --- a/src/components/cl/hier/cl_hier_team.c +++ b/src/components/cl/hier/cl_hier_team.c @@ -159,8 +159,9 @@ ucc_status_t ucc_cl_hier_team_destroy(ucc_base_team_t *cl_team) ucc_cl_hier_team_t *team = ucc_derived_of(cl_team, ucc_cl_hier_team_t); ucc_cl_hier_context_t *ctx = UCC_CL_HIER_TEAM_CTX(team); ucc_status_t status = UCC_OK; - int i, j; - ucc_hier_sbgp_t *hs; + int i, j; + ucc_hier_sbgp_t *hs; + struct ucc_team_team_desc *d; if (NULL == team->team_create_req) { status = 
ucc_team_multiple_req_alloc(&team->team_create_req, @@ -180,9 +181,10 @@ ucc_status_t ucc_cl_hier_team_destroy(ucc_base_team_t *cl_team) for (j = 0; j < hs->n_tls; j++) { if (hs->tl_teams[j]) { ucc_tl_context_put(hs->tl_ctxs[j]); - team->team_create_req - ->descs[team->team_create_req->n_teams++] - .team = hs->tl_teams[j]; + d = &team->team_create_req->descs[ + team->team_create_req->n_teams++]; + d->team = hs->tl_teams[j]; + d->param.params.oob = d->team->super.params.params.oob; } } } @@ -193,6 +195,8 @@ ucc_status_t ucc_cl_hier_team_destroy(ucc_base_team_t *cl_team) return status; } for (i = 0; i < team->team_create_req->n_teams; i++) { + ucc_internal_oob_finalize(&team->team_create_req-> + descs[i].param.params.oob); if (team->team_create_req->descs[i].status != UCC_OK) { cl_error(ctx->super.super.lib, "tl team destroy failed (%d)", status); diff --git a/src/core/ucc_team.h b/src/core/ucc_team.h index 73b434ca91..d32eb222c5 100644 --- a/src/core/ucc_team.h +++ b/src/core/ucc_team.h @@ -105,6 +105,11 @@ static inline ucc_rank_t ucc_get_ctx_rank(ucc_team_t *team, ucc_rank_t team_rank return ucc_ep_map_eval(team->ctx_map, team_rank); } +static inline ucc_host_id_t ucc_team_rank_host_id(ucc_rank_t rank, ucc_team_t *team) +{ + return team->topo->topo->procs[ucc_get_ctx_rank(team, rank)].host_id; +} + static inline int ucc_team_ranks_on_same_node(ucc_rank_t rank1, ucc_rank_t rank2, ucc_team_t *team) { diff --git a/test/gtest/coll/test_allreduce.cc b/test/gtest/coll/test_allreduce.cc index 0b3bdd5410..c43e202d71 100644 --- a/test/gtest/coll/test_allreduce.cc +++ b/test/gtest/coll/test_allreduce.cc @@ -300,6 +300,40 @@ TYPED_TEST(test_allreduce_alg, sra_knomial_pipelined) { } } +TYPED_TEST(test_allreduce_alg, rab) { + int n_procs = 15; + ucc_job_env_t env = {{"UCC_CL_HIER_TUNE", "allreduce:@rab:0-inf:inf"}, + {"UCC_CLS", "all"}}; + UccJob job(n_procs, UccJob::UCC_JOB_CTX_GLOBAL, env); + UccTeam_h team = job.create_team(n_procs); + int repeat = 3; + UccCollCtxVec ctxs; 
+ std::vector mt = {UCC_MEMORY_TYPE_HOST}; + + if (UCC_OK == ucc_mc_available(UCC_MEMORY_TYPE_CUDA)) { + mt.push_back(UCC_MEMORY_TYPE_CUDA); + } + + for (auto count : {8, 65536, 123567}) { + for (auto inplace : {TEST_NO_INPLACE, TEST_INPLACE}) { + for (auto m : mt) { + SET_MEM_TYPE(m); + this->set_inplace(inplace); + this->data_init(n_procs, TypeParam::dt, count, ctxs, true); + UccReq req(team, ctxs); + + for (auto i = 0; i < repeat; i++) { + req.start(); + req.wait(); + EXPECT_EQ(true, this->data_validate(ctxs)); + this->reset(ctxs); + } + this->data_fini(ctxs); + } + } + } +} + template class test_allreduce_avg_order : public test_allreduce { }; diff --git a/test/gtest/coll/test_bcast.cc b/test/gtest/coll/test_bcast.cc index 31092bead6..73ccdfeae0 100644 --- a/test/gtest/coll/test_bcast.cc +++ b/test/gtest/coll/test_bcast.cc @@ -58,10 +58,16 @@ class test_bcast : public UccCollArgs, public ucc::test { for (auto r = 0; r < ctxs.size(); r++) { ucc_coll_args_t *coll = ctxs[r]->args; - size_t count = coll->dst.info.count; - ucc_datatype_t dtype = coll->dst.info.datatype; - clear_buffer(coll->dst.info.buffer, count * ucc_dt_size(dtype), - mem_type, 0); + size_t count = coll->src.info.count; + ucc_datatype_t dtype = coll->src.info.datatype; + if (r != root) { + clear_buffer(coll->src.info.buffer, count * ucc_dt_size(dtype), + mem_type, 0); + } else { + UCC_CHECK(ucc_mc_memcpy(coll->src.info.buffer, ctxs[r]->init_buf, + ctxs[r]->rbuf_size, mem_type, + UCC_MEMORY_TYPE_HOST)); + } } } @@ -232,3 +238,40 @@ INSTANTIATE_TEST_CASE_P( #endif ::testing::Values(1,3,65536), // count ::testing::Values(0,1))); // root + +class test_bcast_alg : public test_bcast +{}; + +UCC_TEST_F(test_bcast_alg, 2step) { + int n_procs = 15; + ucc_job_env_t env = {{"UCC_CL_HIER_TUNE", "bcast:@2step:0-inf:inf"}, + {"UCC_CLS", "all"}}; + UccJob job(n_procs, UccJob::UCC_JOB_CTX_GLOBAL, env); + UccTeam_h team = job.create_team(n_procs); + int repeat = 1; + UccCollCtxVec ctxs; + std::vector mt = 
{UCC_MEMORY_TYPE_HOST}; + + if (UCC_OK == ucc_mc_available(UCC_MEMORY_TYPE_CUDA)) { + mt.push_back(UCC_MEMORY_TYPE_CUDA); + } + + for (auto count : {8, 65536}) { + for (int root = 0; root < n_procs; root++) { + for (auto m : mt) { + this->set_root(root); + SET_MEM_TYPE(m); + this->data_init(n_procs, UCC_DT_INT8, count, ctxs, false); + UccReq req(team, ctxs); + + for (auto i = 0; i < repeat; i++) { + req.start(); + req.wait(); + EXPECT_EQ(true, this->data_validate(ctxs)); + this->reset(ctxs); + } + this->data_fini(ctxs); + } + } + } +}