diff --git a/src/components/tl/ucp/allgatherv/allgatherv.c b/src/components/tl/ucp/allgatherv/allgatherv.c index a193cb9695..39fbc5472d 100644 --- a/src/components/tl/ucp/allgatherv/allgatherv.c +++ b/src/components/tl/ucp/allgatherv/allgatherv.c @@ -9,10 +9,6 @@ #include "allgatherv.h" #include "utils/ucc_coll_utils.h" -ucc_status_t ucc_tl_ucp_allgatherv_ring_start(ucc_coll_task_t *task); - -void ucc_tl_ucp_allgatherv_ring_progress(ucc_coll_task_t *task); - ucc_base_coll_alg_info_t ucc_tl_ucp_allgatherv_algs[UCC_TL_UCP_ALLGATHERV_ALG_LAST + 1] = { [UCC_TL_UCP_ALLGATHERV_ALG_RING] = @@ -29,8 +25,5 @@ ucc_status_t ucc_tl_ucp_allgatherv_init(ucc_tl_ucp_task_t *task) return UCC_ERR_NOT_SUPPORTED; } - task->super.post = ucc_tl_ucp_allgatherv_ring_start; - task->super.progress = ucc_tl_ucp_allgatherv_ring_progress; - - return UCC_OK; + return ucc_tl_ucp_allgatherv_ring_init_common(task); } diff --git a/src/components/tl/ucp/allgatherv/allgatherv.h b/src/components/tl/ucp/allgatherv/allgatherv.h index 4790f9cdc0..e9faf27ed1 100644 --- a/src/components/tl/ucp/allgatherv/allgatherv.h +++ b/src/components/tl/ucp/allgatherv/allgatherv.h @@ -18,6 +18,7 @@ enum { extern ucc_base_coll_alg_info_t ucc_tl_ucp_allgatherv_algs[UCC_TL_UCP_ALLGATHERV_ALG_LAST + 1]; -ucc_status_t ucc_tl_ucp_allgatherv_init(ucc_tl_ucp_task_t *task); +ucc_status_t ucc_tl_ucp_allgatherv_ring_init_common(ucc_tl_ucp_task_t *task); +ucc_status_t ucc_tl_ucp_allgatherv_init(ucc_tl_ucp_task_t *task); #endif diff --git a/src/components/tl/ucp/allgatherv/allgatherv_ring.c b/src/components/tl/ucp/allgatherv/allgatherv_ring.c index 3a0e607f53..efc3a06099 100644 --- a/src/components/tl/ucp/allgatherv/allgatherv_ring.c +++ b/src/components/tl/ucp/allgatherv/allgatherv_ring.c @@ -17,21 +17,26 @@ void ucc_tl_ucp_allgatherv_ring_progress(ucc_coll_task_t *coll_task) ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); ucc_coll_args_t *args = &TASK_ARGS(task); ucc_tl_ucp_team_t *team = TASK_TEAM(task); - ucc_rank_t grank = UCC_TL_TEAM_RANK(team); - ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team); + ucc_rank_t trank = task->subset.myrank; + ucc_rank_t tsize = (ucc_rank_t)task->subset.map.ep_num; ptrdiff_t rbuf = (ptrdiff_t)args->dst.info_v.buffer; ucc_memory_type_t rmem = args->dst.info_v.mem_type; size_t rdt_size = ucc_dt_size(args->dst.info_v.datatype); - ucc_rank_t sendto = (grank + 1) % gsize; - ucc_rank_t recvfrom = (grank - 1 + gsize) % gsize; - ucc_rank_t send_idx, recv_idx; + ucc_rank_t send_idx, recv_idx, sendto, recvfrom; size_t data_size, data_displ; if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) { return; } - while (task->tagged.send_posted < gsize) { - send_idx = (grank - task->tagged.send_posted + 1 + gsize) % gsize; + + sendto = ucc_ep_map_eval(task->subset.map, (trank + 1) % tsize); + recvfrom = ucc_ep_map_eval(task->subset.map, (trank - 1 + tsize) % tsize); + + while (task->tagged.send_posted < tsize) { + send_idx = + ucc_ep_map_eval(task->subset.map, (trank - + task->tagged.send_posted + 1 + + tsize) % tsize); data_displ = ucc_coll_args_get_displacement( args, args->dst.info_v.displacements, send_idx) * rdt_size; @@ -41,7 +46,10 @@ void ucc_tl_ucp_allgatherv_ring_progress(ucc_coll_task_t *coll_task) UCPCHECK_GOTO(ucc_tl_ucp_send_nb((void *)(rbuf + data_displ), data_size, rmem, sendto, team, task), task, out); - recv_idx = (grank - task->tagged.recv_posted + gsize) % gsize; + recv_idx = + ucc_ep_map_eval(task->subset.map, (trank - + task->tagged.recv_posted + + tsize) % tsize); data_displ = ucc_coll_args_get_displacement( args, args->dst.info_v.displacements, recv_idx) * rdt_size; @@ -98,3 +106,25 @@ ucc_status_t ucc_tl_ucp_allgatherv_ring_start(ucc_coll_task_t *coll_task) error: return task->super.status; } + +ucc_status_t ucc_tl_ucp_allgatherv_ring_init_common(ucc_tl_ucp_task_t *task) +{ + ucc_tl_ucp_team_t *team = TASK_TEAM(task); + ucc_sbgp_t *sbgp; + + if (!ucc_coll_args_is_predefined_dt(&TASK_ARGS(task), UCC_RANK_INVALID)) { + tl_error(UCC_TASK_LIB(task), "user defined datatype is not supported"); + return UCC_ERR_NOT_SUPPORTED; + } + + if (team->cfg.use_reordering) { + sbgp = ucc_topo_get_sbgp(team->topo, UCC_SBGP_FULL_HOST_ORDERED); + task->subset.myrank = sbgp->group_rank; + task->subset.map = sbgp->map; + } + + task->super.post = ucc_tl_ucp_allgatherv_ring_start; + task->super.progress = ucc_tl_ucp_allgatherv_ring_progress; + + return UCC_OK; +}