Skip to content

Commit

Permalink
TL/UCP: use knomial pattern in gather (#1044)
Browse files Browse the repository at this point in the history
* TL/UCP: use knomial pattern in gather

* REVIEW: fix review comments
  • Loading branch information
Sergei-Lebedev authored Dec 13, 2024
1 parent 4f67436 commit 73651ea
Show file tree
Hide file tree
Showing 6 changed files with 418 additions and 177 deletions.
41 changes: 35 additions & 6 deletions src/coll_patterns/recursive_knomial.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ enum {
KN_PATTERN_ALLGATHER,
KN_PATTERN_ALLGATHERV,
KN_PATTERN_ALLGATHERX,
KN_PATTERN_GATHER,
KN_PATTERN_GATHERX,
};

typedef struct ucc_knomial_pattern {
Expand Down Expand Up @@ -83,7 +85,7 @@ static inline ucc_rank_t ucc_kn_pattern_radix_pow_init(ucc_knomial_pattern_t *p,
static inline void
ucc_knomial_pattern_init_impl(ucc_rank_t size, ucc_rank_t rank,
ucc_kn_radix_t radix, ucc_knomial_pattern_t *p,
int backward)
int backward, int has_extra)
{
ucc_rank_t fs = radix;
ucc_rank_t n_full_subtrees;
Expand All @@ -100,7 +102,7 @@ ucc_knomial_pattern_init_impl(ucc_rank_t size, ucc_rank_t rank,
p->backward = backward;
p->iteration = 0;
n_full_subtrees = ucc_kn_pattern_n_full(p);
p->n_extra = size - n_full_subtrees * p->full_pow_size;
p->n_extra = has_extra ? size - n_full_subtrees * p->full_pow_size : 0;
p->n_iters = (p->n_extra && n_full_subtrees == 1) ?
p->pow_radix_sup - 1 : p->pow_radix_sup;
p->radix_pow = ucc_kn_pattern_radix_pow_init(p, backward);
Expand All @@ -115,14 +117,22 @@ ucc_knomial_pattern_init_backward(ucc_rank_t size, ucc_rank_t rank,
ucc_kn_radix_t radix,
ucc_knomial_pattern_t *p)
{
ucc_knomial_pattern_init_impl(size, rank, radix, p, 1);
ucc_knomial_pattern_init_impl(size, rank, radix, p, 1, 1);
}

static inline void
ucc_knomial_pattern_init(ucc_rank_t size, ucc_rank_t rank, ucc_kn_radix_t radix,
ucc_knomial_pattern_t *p)
{
ucc_knomial_pattern_init_impl(size, rank, radix, p, 0);
ucc_knomial_pattern_init_impl(size, rank, radix, p, 0, 1);
}

static inline void
ucc_knomial_pattern_init_no_extra(ucc_rank_t size, ucc_rank_t rank,
ucc_kn_radix_t radix,
ucc_knomial_pattern_t *p)
{
ucc_knomial_pattern_init_impl(size, rank, radix, p, 0, 0);
}

static inline ucc_rank_t
Expand Down Expand Up @@ -186,6 +196,23 @@ ucc_knomial_pattern_get_loop_peer(ucc_knomial_pattern_t *p, ucc_rank_t rank,
ucc_knomial_pattern_loop_rank_inv(p, peer);
}

static inline ucc_rank_t
ucc_knomial_pattern_get_base_rank(ucc_knomial_pattern_t *p, ucc_rank_t rank)
{
ucc_rank_t step_size = p->radix_pow * p->radix;
ucc_rank_t lrank;
ucc_kn_radix_t s;

lrank = ucc_knomial_pattern_loop_rank(p, rank);
s = ucc_div_round_up(step_size - (lrank % step_size), p->radix_pow);

if (s == p->radix) {
return rank;
} else {
return ucc_knomial_pattern_get_loop_peer(p, rank, s);
}
}

static inline void
ucc_knomial_pattern_next_iteration(ucc_knomial_pattern_t *p)
{
Expand Down Expand Up @@ -224,11 +251,13 @@ static inline ucc_rank_t
ucc_knomial_calc_recv_dist(ucc_rank_t team_size, ucc_rank_t rank,
ucc_rank_t radix, ucc_rank_t root)
{
ucc_rank_t root_base = 0;
ucc_rank_t dist = 1;

if (rank == root) {
return 0;
}
ucc_rank_t root_base = 0 ;
ucc_rank_t dist = 1;

while (dist <= team_size) {
if (rank < root_base + radix * dist) {
break;
Expand Down
76 changes: 76 additions & 0 deletions src/coll_patterns/sra_knomial.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,82 @@ ucc_knx_block(ucc_rank_t rank, ucc_rank_t size, ucc_kn_radix_t radix,
*b_offset = offset;
}

static inline void
ucc_kn_g_pattern_init(ucc_rank_t size, ucc_rank_t rank, ucc_kn_radix_t radix,
size_t count, ucc_knomial_pattern_t *p)
{
ucc_knomial_pattern_init_no_extra(size, rank, radix, p);
p->type = KN_PATTERN_GATHER;
p->count = count;
p->block_size = p->radix_pow * radix;
p->block_offset = ucc_knomial_pattern_loop_rank(p, rank) / p->block_size *
p->block_size;
}

static inline void
ucc_kn_gx_pattern_init(ucc_rank_t size, ucc_rank_t rank, ucc_kn_radix_t radix,
size_t count, ucc_knomial_pattern_t *p)
{
ucc_knomial_pattern_init_backward(size, rank, radix, p);
p->type = KN_PATTERN_GATHERX;
p->count = count;
if (p->node_type != KN_NODE_EXTRA) {
p->block_size = ucc_kn_compute_step_radix(p);
ucc_knx_block(rank, size, radix, count, p->n_iters - 1,
&p->block_size_counts, &p->block_offset);

}

}

static inline void
ucc_kn_g_pattern_peer_seg(ucc_rank_t peer, ucc_knomial_pattern_t *p,
size_t *seg_count, ptrdiff_t *seg_offset)
{
ucc_rank_t step_radix, seg_index;

*seg_count = 0;
*seg_offset = 0;
switch (p->type) {
case KN_PATTERN_GATHER:
*seg_count = ucc_min(p->radix_pow, p->size - peer) * (p->count / p->size);
*seg_offset = peer * (p->count / p->size);
return;
case KN_PATTERN_GATHERX:
step_radix = ucc_kn_compute_step_radix(p);
seg_index = ucc_kn_compute_seg_index(peer, p->radix_pow, p);
*seg_offset = ucc_buffer_block_offset(p->block_size_counts, step_radix,
seg_index) + p->block_offset;
*seg_count = ucc_buffer_block_count(p->block_size_counts, step_radix,
seg_index);
return;
default:
ucc_assert(0);
}
}

static inline void ucc_kn_g_pattern_next_iter(ucc_knomial_pattern_t *p)
{
ucc_rank_t rank;
if (p->type == KN_PATTERN_GATHERX) {
ucc_knomial_pattern_next_iteration_backward(p);

if (!ucc_knomial_pattern_loop_done(p)) {
ucc_knx_block(p->rank, p->size, p->radix, p->count,
p->n_iters - 1 - p->iteration,
&p->block_size_counts, &p->block_offset);
}
} else {
rank = ucc_knomial_pattern_loop_rank(p, p->rank);
ucc_knomial_pattern_next_iteration(p);

if (!ucc_knomial_pattern_loop_done(p)) {
p->block_size *= ucc_kn_compute_step_radix(p);
p->block_offset = rank / p->block_size * p->block_size;
}
}
}

static inline void
ucc_kn_ag_pattern_init(ucc_rank_t size, ucc_rank_t rank, ucc_kn_radix_t radix,
size_t count, ucc_knomial_pattern_t *p)
Expand Down
61 changes: 6 additions & 55 deletions src/components/tl/ucp/gather/gather.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand All @@ -17,62 +17,13 @@ ucc_base_coll_alg_info_t
[UCC_TL_UCP_GATHER_ALG_LAST] = {
.id = 0, .name = NULL, .desc = NULL}};

static inline uint32_t calc_buffer_size(ucc_rank_t rank, uint32_t radix, ucc_rank_t team_size)
{
uint32_t radix_valuation;

if (rank == 0) {
return team_size;
}
radix_valuation = calc_valuation(rank, radix);
return (uint32_t)ucc_min(pow(radix, radix_valuation), team_size - rank);
}

ucc_status_t ucc_tl_ucp_gather_init(ucc_tl_ucp_task_t *task)
{
ucc_coll_args_t * args = &TASK_ARGS(task);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_rank_t myrank = UCC_TL_TEAM_RANK(team);
ucc_rank_t team_size = UCC_TL_TEAM_SIZE(team);
ucc_rank_t root = args->root;
ucc_rank_t vrank = (myrank - root + team_size) % team_size;
ucc_status_t status = UCC_OK;
ucc_memory_type_t mtype;
ucc_datatype_t dt;
size_t count, data_size;
uint32_t buffer_size;
int isleaf;

if (root == myrank) {
count = args->dst.info.count;
dt = args->dst.info.datatype;
mtype = args->dst.info.mem_type;
} else {
count = args->src.info.count;
dt = args->src.info.datatype;
mtype = args->src.info.mem_type;
}
data_size = count * ucc_dt_size(dt);
task->super.post = ucc_tl_ucp_gather_knomial_start;
task->super.progress = ucc_tl_ucp_gather_knomial_progress;
task->super.finalize = ucc_tl_ucp_gather_knomial_finalize;
task->gather_kn.radix =
ucc_min(UCC_TL_UCP_TEAM_LIB(team)->cfg.gather_kn_radix, team_size);
CALC_KN_TREE_DIST(team_size, task->gather_kn.radix,
task->gather_kn.max_dist);
isleaf = (vrank % task->gather_kn.radix != 0 || vrank == team_size - 1);
task->gather_kn.scratch_mc_header = NULL;
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_rank_t size = UCC_TL_TEAM_SIZE(team);
ucc_kn_radix_t radix;

if (vrank == 0) {
task->gather_kn.scratch = args->dst.info.buffer;
} else if (isleaf) {
task->gather_kn.scratch = args->src.info.buffer;
} else {
buffer_size = calc_buffer_size(vrank, task->gather_kn.radix, team_size);
status = ucc_mc_alloc(&task->gather_kn.scratch_mc_header,
buffer_size * data_size, mtype);
task->gather_kn.scratch = task->gather_kn.scratch_mc_header->addr;
}
radix = ucc_min(UCC_TL_UCP_TEAM_LIB(team)->cfg.gather_kn_radix, size);

return status;
return ucc_tl_ucp_gather_knomial_init_common(task, radix);
}
10 changes: 9 additions & 1 deletion src/components/tl/ucp/gather/gather.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down Expand Up @@ -45,4 +45,12 @@ void ucc_tl_ucp_gather_knomial_progress(ucc_coll_task_t *task);

ucc_status_t ucc_tl_ucp_gather_knomial_finalize(ucc_coll_task_t *task);

ucc_status_t ucc_tl_ucp_gather_knomial_init_common(ucc_tl_ucp_task_t *task,
ucc_kn_radix_t radix);

/* Internal interface with custom radix */
ucc_status_t ucc_tl_ucp_gather_knomial_init_r(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h,
ucc_kn_radix_t radix);
#endif
Loading

0 comments on commit 73651ea

Please sign in to comment.