Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TL/UCP: add ranks reordering to allgatherv #819

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions src/components/tl/ucp/allgatherv/allgatherv.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,6 @@
#include "allgatherv.h"
#include "utils/ucc_coll_utils.h"

ucc_status_t ucc_tl_ucp_allgatherv_ring_start(ucc_coll_task_t *task);

void ucc_tl_ucp_allgatherv_ring_progress(ucc_coll_task_t *task);

ucc_base_coll_alg_info_t
ucc_tl_ucp_allgatherv_algs[UCC_TL_UCP_ALLGATHERV_ALG_LAST + 1] = {
[UCC_TL_UCP_ALLGATHERV_ALG_RING] =
Expand All @@ -29,8 +25,5 @@ ucc_status_t ucc_tl_ucp_allgatherv_init(ucc_tl_ucp_task_t *task)
return UCC_ERR_NOT_SUPPORTED;
}

task->super.post = ucc_tl_ucp_allgatherv_ring_start;
task->super.progress = ucc_tl_ucp_allgatherv_ring_progress;

return UCC_OK;
return ucc_tl_ucp_allgatherv_ring_init_common(task);
}
3 changes: 2 additions & 1 deletion src/components/tl/ucp/allgatherv/allgatherv.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ enum {
extern ucc_base_coll_alg_info_t
ucc_tl_ucp_allgatherv_algs[UCC_TL_UCP_ALLGATHERV_ALG_LAST + 1];

ucc_status_t ucc_tl_ucp_allgatherv_init(ucc_tl_ucp_task_t *task);
ucc_status_t ucc_tl_ucp_allgatherv_ring_init_common(ucc_tl_ucp_task_t *task);

ucc_status_t ucc_tl_ucp_allgatherv_init(ucc_tl_ucp_task_t *task);
#endif
46 changes: 38 additions & 8 deletions src/components/tl/ucp/allgatherv/allgatherv_ring.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,26 @@ void ucc_tl_ucp_allgatherv_ring_progress(ucc_coll_task_t *coll_task)
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
ucc_coll_args_t *args = &TASK_ARGS(task);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_rank_t grank = UCC_TL_TEAM_RANK(team);
ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team);
ucc_rank_t trank = task->subset.myrank;
ucc_rank_t tsize = (ucc_rank_t)task->subset.map.ep_num;
ptrdiff_t rbuf = (ptrdiff_t)args->dst.info_v.buffer;
ucc_memory_type_t rmem = args->dst.info_v.mem_type;
size_t rdt_size = ucc_dt_size(args->dst.info_v.datatype);
ucc_rank_t sendto = (grank + 1) % gsize;
ucc_rank_t recvfrom = (grank - 1 + gsize) % gsize;
ucc_rank_t send_idx, recv_idx;
ucc_rank_t send_idx, recv_idx, sendto, recvfrom;
size_t data_size, data_displ;

if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) {
return;
}
while (task->tagged.send_posted < gsize) {
send_idx = (grank - task->tagged.send_posted + 1 + gsize) % gsize;

sendto = ucc_ep_map_eval(task->subset.map, (trank + 1) % tsize);
recvfrom = ucc_ep_map_eval(task->subset.map, (trank - 1 + tsize) % tsize);

while (task->tagged.send_posted < tsize) {
send_idx =
ucc_ep_map_eval(task->subset.map, (trank -
task->tagged.send_posted + 1 +
tsize) % tsize);
data_displ = ucc_coll_args_get_displacement(
args, args->dst.info_v.displacements, send_idx) *
rdt_size;
Expand All @@ -41,7 +46,10 @@ void ucc_tl_ucp_allgatherv_ring_progress(ucc_coll_task_t *coll_task)
UCPCHECK_GOTO(ucc_tl_ucp_send_nb((void *)(rbuf + data_displ), data_size,
rmem, sendto, team, task),
task, out);
recv_idx = (grank - task->tagged.recv_posted + gsize) % gsize;
recv_idx =
ucc_ep_map_eval(task->subset.map, (trank -
task->tagged.recv_posted +
tsize) % tsize);
data_displ = ucc_coll_args_get_displacement(
args, args->dst.info_v.displacements, recv_idx) *
rdt_size;
Expand Down Expand Up @@ -98,3 +106,25 @@ ucc_status_t ucc_tl_ucp_allgatherv_ring_start(ucc_coll_task_t *coll_task)
error:
return task->super.status;
}

ucc_status_t ucc_tl_ucp_allgatherv_ring_init_common(ucc_tl_ucp_task_t *task)
{
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_sbgp_t *sbgp;

if (!ucc_coll_args_is_predefined_dt(&TASK_ARGS(task), UCC_RANK_INVALID)) {
tl_error(UCC_TASK_LIB(task), "user defined datatype is not supported");
return UCC_ERR_NOT_SUPPORTED;
}

if (team->cfg.use_reordering) {
sbgp = ucc_topo_get_sbgp(team->topo, UCC_SBGP_FULL_HOST_ORDERED);
task->subset.myrank = sbgp->group_rank;
task->subset.map = sbgp->map;
}

task->super.post = ucc_tl_ucp_allgatherv_ring_start;
task->super.progress = ucc_tl_ucp_allgatherv_ring_progress;

return UCC_OK;
}