Skip to content

Commit

Permalink
TL/UCP: add ranks reordering to allgatherv
Browse files Browse the repository at this point in the history
  • Loading branch information
shimmybalsam committed Aug 8, 2023
1 parent 6bdb758 commit 11b16cf
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 17 deletions.
9 changes: 1 addition & 8 deletions src/components/tl/ucp/allgatherv/allgatherv.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,6 @@
#include "allgatherv.h"
#include "utils/ucc_coll_utils.h"

ucc_status_t ucc_tl_ucp_allgatherv_ring_start(ucc_coll_task_t *task);

void ucc_tl_ucp_allgatherv_ring_progress(ucc_coll_task_t *task);

ucc_base_coll_alg_info_t
ucc_tl_ucp_allgatherv_algs[UCC_TL_UCP_ALLGATHERV_ALG_LAST + 1] = {
[UCC_TL_UCP_ALLGATHERV_ALG_RING] =
Expand All @@ -29,8 +25,5 @@ ucc_status_t ucc_tl_ucp_allgatherv_init(ucc_tl_ucp_task_t *task)
return UCC_ERR_NOT_SUPPORTED;
}

task->super.post = ucc_tl_ucp_allgatherv_ring_start;
task->super.progress = ucc_tl_ucp_allgatherv_ring_progress;

return UCC_OK;
return ucc_tl_ucp_allgatherv_ring_init_common(task);
}
7 changes: 6 additions & 1 deletion src/components/tl/ucp/allgatherv/allgatherv.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ enum {
extern ucc_base_coll_alg_info_t
ucc_tl_ucp_allgatherv_algs[UCC_TL_UCP_ALLGATHERV_ALG_LAST + 1];

ucc_status_t ucc_tl_ucp_allgatherv_init(ucc_tl_ucp_task_t *task);
ucc_status_t ucc_tl_ucp_allgatherv_ring_start(ucc_coll_task_t *task);

void ucc_tl_ucp_allgatherv_ring_progress(ucc_coll_task_t *task);

ucc_status_t ucc_tl_ucp_allgatherv_ring_init_common(ucc_tl_ucp_task_t *task);

ucc_status_t ucc_tl_ucp_allgatherv_init(ucc_tl_ucp_task_t *task);
#endif
45 changes: 37 additions & 8 deletions src/components/tl/ucp/allgatherv/allgatherv_ring.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,26 @@ void ucc_tl_ucp_allgatherv_ring_progress(ucc_coll_task_t *coll_task)
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
ucc_coll_args_t *args = &TASK_ARGS(task);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_rank_t grank = UCC_TL_TEAM_RANK(team);
ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team);
ucc_rank_t trank = task->subset.myrank;
ucc_rank_t tsize = (ucc_rank_t)task->subset.map.ep_num;
ptrdiff_t rbuf = (ptrdiff_t)args->dst.info_v.buffer;
ucc_memory_type_t rmem = args->dst.info_v.mem_type;
size_t rdt_size = ucc_dt_size(args->dst.info_v.datatype);
ucc_rank_t sendto = (grank + 1) % gsize;
ucc_rank_t recvfrom = (grank - 1 + gsize) % gsize;
ucc_rank_t send_idx, recv_idx;
ucc_rank_t send_idx, recv_idx, sendto, recvfrom;
size_t data_size, data_displ;

if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) {
return;
}
while (task->tagged.send_posted < gsize) {
send_idx = (grank - task->tagged.send_posted + 1 + gsize) % gsize;

sendto = ucc_ep_map_eval(task->subset.map, (trank + 1) % tsize);
recvfrom = ucc_ep_map_eval(task->subset.map, (trank - 1 + tsize) % tsize);

while (task->tagged.send_posted < tsize) {
send_idx =
ucc_ep_map_eval(task->subset.map,(trank -
task->tagged.send_posted + 1 +
tsize) % tsize);
data_displ = ucc_coll_args_get_displacement(
args, args->dst.info_v.displacements, send_idx) *
rdt_size;
Expand All @@ -41,7 +46,9 @@ void ucc_tl_ucp_allgatherv_ring_progress(ucc_coll_task_t *coll_task)
UCPCHECK_GOTO(ucc_tl_ucp_send_nb((void *)(rbuf + data_displ), data_size,
rmem, sendto, team, task),
task, out);
recv_idx = (grank - task->tagged.recv_posted + gsize) % gsize;
recv_idx =
ucc_ep_map_eval(task->subset.map,(trank -
task->tagged.recv_posted + tsize) % tsize);
data_displ = ucc_coll_args_get_displacement(
args, args->dst.info_v.displacements, recv_idx) *
rdt_size;
Expand Down Expand Up @@ -98,3 +105,25 @@ ucc_status_t ucc_tl_ucp_allgatherv_ring_start(ucc_coll_task_t *coll_task)
error:
return task->super.status;
}

ucc_status_t ucc_tl_ucp_allgatherv_ring_init_common(ucc_tl_ucp_task_t *task)
{
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_sbgp_t *sbgp;

if (!ucc_coll_args_is_predefined_dt(&TASK_ARGS(task), UCC_RANK_INVALID)) {
tl_error(UCC_TASK_LIB(task), "user defined datatype is not supported");
return UCC_ERR_NOT_SUPPORTED;
}

if (team->cfg.use_reordering) {
sbgp = ucc_topo_get_sbgp(team->topo, UCC_SBGP_FULL_HOST_ORDERED);
task->subset.myrank = sbgp->group_rank;
task->subset.map = sbgp->map;
}

task->super.post = ucc_tl_ucp_allgatherv_ring_start;
task->super.progress = ucc_tl_ucp_allgatherv_ring_progress;

return UCC_OK;
}

0 comments on commit 11b16cf

Please sign in to comment.