Skip to content

Commit

Permalink
TL/UCP: remove memcpy in last SRA step (#743)
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergei-Lebedev authored Aug 11, 2023
1 parent 6bdb758 commit cfad103
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 96 deletions.
6 changes: 6 additions & 0 deletions src/coll_patterns/recursive_knomial.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,12 @@ ucc_knomial_pattern_next_iteration(ucc_knomial_pattern_t *p)
p->radix_pow *= p->radix;
}

static inline void ucc_knomial_pattern_prev_iteration(ucc_knomial_pattern_t *p)
{
p->iteration--;
p->radix_pow /= p->radix;
}

static inline void
ucc_knomial_pattern_next_iteration_backward(ucc_knomial_pattern_t *p)
{
Expand Down
59 changes: 59 additions & 0 deletions src/coll_patterns/sra_knomial.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,4 +265,63 @@ static inline void ucc_kn_ag_pattern_next_iter(ucc_knomial_pattern_t *p)
}
}

static inline void ucc_kn_rsx_pattern_init(ucc_rank_t size, ucc_rank_t rank,
ucc_kn_radix_t radix, size_t count,
ucc_knomial_pattern_t *p)
{
ucc_knomial_pattern_init(size, rank, radix, p);
p->type = KN_PATTERN_REDUCE_SCATTERX;
p->count = count;
p->block_size_counts = count;
p->block_size = size - p->n_extra;
}

static inline void
ucc_kn_rs_pattern_peer_seg(ucc_rank_t peer, ucc_knomial_pattern_t *p,
size_t *peer_seg_count, size_t *peer_seg_offset)
{
ucc_rank_t step_radix, seg_index;

*peer_seg_count = 0;
*peer_seg_offset = 0;

switch (p->type) {
case KN_PATTERN_REDUCE_SCATTERX:
step_radix = ucc_kn_compute_step_radix(p);
seg_index = ucc_kn_compute_seg_index(peer, p->radix_pow, p);
*peer_seg_offset = ucc_buffer_block_offset(p->block_size_counts,
step_radix, seg_index);
*peer_seg_count = ucc_buffer_block_count(p->block_size_counts,
step_radix, seg_index);
return;
case KN_PATTERN_REDUCE_SCATTER:
case KN_PATTERN_REDUCE_SCATTERV:
/* not implemented */
ucc_assert(0);
default:
ucc_assert(0);
}
}

static inline void ucc_kn_rs_pattern_next_iter(ucc_knomial_pattern_t *p)
{
size_t offset, bs;

ucc_kn_rs_pattern_peer_seg(p->rank, p, &bs, &offset);
p->block_size_counts = bs;

switch (p->type) {
case KN_PATTERN_REDUCE_SCATTERX:
p->block_offset += offset;
ucc_knomial_pattern_next_iteration(p);
return;
case KN_PATTERN_REDUCE_SCATTER:
case KN_PATTERN_REDUCE_SCATTERV:
/* not implemented */
ucc_assert(0);
default:
ucc_assert(0);
}
}

#endif
203 changes: 109 additions & 94 deletions src/components/tl/ucp/reduce_scatter/reduce_scatter_knomial.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand All @@ -15,6 +15,23 @@
task->reduce_scatter_kn.phase = _phase; \
} while (0)

static inline void get_sbuf_rbuf(ucc_knomial_pattern_t *p, ucc_coll_args_t *args,
void *scratch, size_t block_count,
void **sbuf, void **rbuf)
{
uint8_t node_type = p->node_type;
size_t dt_size = ucc_dt_size(args->dst.info.datatype);

if (ucc_knomial_pattern_loop_first_iteration(p)) {
*sbuf = (KN_NODE_PROXY == node_type || UCC_IS_INPLACE(*args))
? args->dst.info.buffer: args->src.info.buffer;
*rbuf = scratch;
} else {
*sbuf = scratch;
*rbuf = PTR_OFFSET(*sbuf, block_count * dt_size);
}
}

void ucc_tl_ucp_reduce_scatter_knomial_progress(ucc_coll_task_t *coll_task)
{
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task,
Expand All @@ -23,7 +40,7 @@ void ucc_tl_ucp_reduce_scatter_knomial_progress(ucc_coll_task_t *coll_task)
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_kn_radix_t radix = task->reduce_scatter_kn.p.radix;
int avg_pre_op =
UCC_TL_UCP_TEAM_LIB(team)->cfg.reduce_avg_pre_op;
UCC_TL_UCP_TEAM_LIB(team)->cfg.reduce_avg_pre_op;
uint8_t node_type =
task->reduce_scatter_kn.p.node_type;
ucc_knomial_pattern_t *p = &task->reduce_scatter_kn.p;
Expand All @@ -33,24 +50,22 @@ void ucc_tl_ucp_reduce_scatter_knomial_progress(ucc_coll_task_t *coll_task)
size_t count = args->dst.info.count;
ucc_datatype_t dt = args->dst.info.datatype;
void *sbuf = UCC_IS_INPLACE(*args) ?
rbuf : args->src.info.buffer;
rbuf : args->src.info.buffer;
size_t dt_size = ucc_dt_size(dt);
size_t data_size = count * dt_size;
ucc_rank_t rank = UCC_TL_TEAM_RANK(team);
ucc_rank_t size = UCC_TL_TEAM_SIZE(team);
ucc_ee_executor_task_args_t eargs = {0};
ptrdiff_t peer_seg_offset, local_seg_offset, offset;
ucc_rank_t peer, step_radix, peer_seg_index, local_seg_index;
ucc_status_t status;
ptrdiff_t peer_seg_offset, local_seg_offset, offset;
ucc_rank_t peer, step_radix, local_seg_index;
ucc_status_t status;
ucc_kn_radix_t loop_step;
size_t block_count, peer_seg_count, local_seg_count;
void *reduce_data, *local_data;
int is_avg;
size_t block_count, peer_seg_count, local_seg_count;
void *reduce_data, *local_data;
int is_avg;

local_seg_count = 0;
block_count = ucc_sra_kn_compute_block_count(count, rank, p);
UCC_KN_REDUCE_GOTO_PHASE(task->reduce_scatter_kn.phase);

if (KN_NODE_EXTRA == node_type) {
peer = ucc_knomial_pattern_get_proxy(p, rank);
UCPCHECK_GOTO(
Expand All @@ -73,105 +88,87 @@ void ucc_tl_ucp_reduce_scatter_knomial_progress(ucc_coll_task_t *coll_task)
}
if (KN_NODE_EXTRA == node_type) {
goto out;
} else {
status = ucc_dt_reduce(sbuf, scratch, rbuf, count, dt, args, 0, 0,
task->reduce_scatter_kn.executor,
&task->reduce_scatter_kn.etask);
if (ucc_unlikely(status != UCC_OK)) {
tl_error(UCC_TASK_LIB(task), "failed to perform dt reduction");
task->super.status = status;
return;
}
UCC_KN_PHASE_EXTRA_REDUCE:
EXEC_TASK_TEST(UCC_KN_PHASE_EXTRA_REDUCE,
"failed to perform dt reduction",
task->reduce_scatter_kn.etask);
}
status = ucc_dt_reduce(sbuf, scratch, rbuf, count, dt, args, 0, 0,
task->reduce_scatter_kn.executor,
&task->reduce_scatter_kn.etask);
if (ucc_unlikely(status != UCC_OK)) {
tl_error(UCC_TASK_LIB(task), "failed to perform dt reduction");
task->super.status = status;
return;
}
UCC_KN_PHASE_EXTRA_REDUCE:
EXEC_TASK_TEST(UCC_KN_PHASE_EXTRA_REDUCE,
"failed to perform dt reduction",
task->reduce_scatter_kn.etask);

}
while (!ucc_knomial_pattern_loop_done(p)) {
step_radix = ucc_kn_compute_step_radix(p);
block_count = ucc_sra_kn_compute_block_count(count, rank, p);
sbuf = (ucc_knomial_pattern_loop_first_iteration(p))
? ((KN_NODE_PROXY == node_type || UCC_IS_INPLACE(*args))
? args->dst.info.buffer
: args->src.info.buffer)
: task->reduce_scatter_kn.scratch;
get_sbuf_rbuf(p, args, task->reduce_scatter_kn.scratch, block_count,
&sbuf, &rbuf);
ucc_kn_rs_pattern_peer_seg(rank, p, &local_seg_count,
&local_seg_offset);
for (loop_step = radix - 1; loop_step > 0; loop_step--) {
peer = ucc_knomial_pattern_get_loop_peer(p, rank, loop_step);
if (peer == UCC_KN_PEER_NULL)
if (peer == UCC_KN_PEER_NULL) {
continue;

peer_seg_index =
ucc_kn_compute_seg_index(peer, p->radix_pow, p);
peer_seg_count = ucc_sra_kn_compute_seg_size(
block_count, step_radix, peer_seg_index);
peer_seg_offset = ucc_sra_kn_compute_seg_offset(
block_count, step_radix, peer_seg_index);
}
ucc_kn_rs_pattern_peer_seg(peer, p, &peer_seg_count,
&peer_seg_offset);
UCPCHECK_GOTO(
ucc_tl_ucp_send_nb(PTR_OFFSET(sbuf, peer_seg_offset * dt_size),
peer_seg_count * dt_size, mem_type, peer,
team, task),
task, out);
}

local_seg_index = ucc_kn_compute_seg_index(rank, p->radix_pow, p);
local_seg_count = ucc_sra_kn_compute_seg_size(block_count, step_radix,
local_seg_index);

rbuf = task->reduce_scatter_kn.scratch;
if (!ucc_knomial_pattern_loop_first_iteration(p)) {
rbuf = PTR_OFFSET(rbuf, block_count * dt_size);
}
for (loop_step = 1; loop_step < radix; loop_step++) {
peer = ucc_knomial_pattern_get_loop_peer(p, rank, loop_step);
if (peer == UCC_KN_PEER_NULL)
continue;
UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(rbuf, local_seg_count * dt_size,
mem_type, peer, team, task),
task, out);
UCPCHECK_GOTO(
ucc_tl_ucp_recv_nb(rbuf, local_seg_count * dt_size, mem_type,
peer, team, task),
task, out);
rbuf = PTR_OFFSET(rbuf, local_seg_count * dt_size);
}
UCC_KN_PHASE_LOOP:
UCC_KN_PHASE_LOOP:
if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) {
SAVE_STATE(UCC_KN_PHASE_LOOP);
return;
}
if (task->tagged.send_posted > p->iteration * (radix - 1)) {
sbuf = (ucc_knomial_pattern_loop_first_iteration(p))
? ((KN_NODE_PROXY == node_type || UCC_IS_INPLACE(*args))
? args->dst.info.buffer
: args->src.info.buffer)
: task->reduce_scatter_kn.scratch;
rbuf = (!ucc_knomial_pattern_loop_first_iteration(p))
? PTR_OFFSET(task->reduce_scatter_kn.scratch,
block_count * dt_size)
: task->reduce_scatter_kn.scratch;
step_radix = ucc_kn_compute_step_radix(p);
local_seg_index =
ucc_kn_compute_seg_index(rank, p->radix_pow, p);
local_seg_count = ucc_sra_kn_compute_seg_size(
block_count, step_radix, local_seg_index);
local_seg_offset = ucc_sra_kn_compute_seg_offset(
block_count, step_radix, local_seg_index);
ucc_kn_rs_pattern_peer_seg(rank, p, &local_seg_count,
&local_seg_offset);
get_sbuf_rbuf(p, args, task->reduce_scatter_kn.scratch, block_count,
&sbuf, &rbuf);
local_data = PTR_OFFSET(sbuf, local_seg_offset * dt_size);
reduce_data = task->reduce_scatter_kn.scratch;
is_avg = args->op == UCC_OP_AVG &&
(avg_pre_op ? ucc_knomial_pattern_loop_first_iteration(p)
: ucc_knomial_pattern_loop_last_iteration(p));
ucc_assert((step_radix - 1) ==
(task->tagged.send_posted - p->iteration * (radix - 1)));

if (task->reduce_scatter_kn.scratch_mc_header &&
if (!task->reduce_scatter_kn.scratch_mc_header &&
ucc_knomial_pattern_loop_last_iteration(p)) {
ucc_sra_kn_get_offset_and_seglen(count, dt_size, rank, size, radix,
&offset, &local_seg_count);
reduce_data = PTR_OFFSET(args->dst.info.buffer, offset);
status = ucc_dt_reduce_strided(
rbuf, PTR_OFFSET(rbuf, local_seg_count * dt_size), rbuf,
step_radix - 2, local_seg_count, local_seg_count * dt_size,
dt, args, 0, 0, task->reduce_scatter_kn.executor,
&task->reduce_scatter_kn.etask);

} else {
if (task->reduce_scatter_kn.scratch_mc_header &&
ucc_knomial_pattern_loop_last_iteration(p)) {
ucc_sra_kn_get_offset_and_seglen(count, dt_size, rank, size,
radix, &offset,
&local_seg_count);
reduce_data = PTR_OFFSET(args->dst.info.buffer, offset);
}
status = ucc_dt_reduce_strided(
local_data, rbuf, reduce_data, step_radix - 1,
local_seg_count, local_seg_count * dt_size, dt, args,
is_avg ? UCC_EEE_TASK_FLAG_REDUCE_WITH_ALPHA : 0,
AVG_ALPHA(task), task->reduce_scatter_kn.executor,
&task->reduce_scatter_kn.etask);
}
status = ucc_dt_reduce_strided(
local_data, rbuf, reduce_data,
task->tagged.send_posted - p->iteration * (radix - 1),
local_seg_count, local_seg_count * dt_size, dt, args,
is_avg ? UCC_EEE_TASK_FLAG_REDUCE_WITH_ALPHA : 0,
AVG_ALPHA(task), task->reduce_scatter_kn.executor,
&task->reduce_scatter_kn.etask);
if (ucc_unlikely(UCC_OK != status)) {
tl_error(UCC_TASK_LIB(task), "failed to perform dt reduction");
task->super.status = status;
Expand All @@ -182,25 +179,42 @@ void ucc_tl_ucp_reduce_scatter_knomial_progress(ucc_coll_task_t *coll_task)
"failed to perform dt reduction",
task->reduce_scatter_kn.etask);
}
ucc_knomial_pattern_next_iteration(p);
ucc_kn_rs_pattern_next_iter(p);
}

if (!task->reduce_scatter_kn.scratch_mc_header) {
ucc_knomial_pattern_prev_iteration(p);
get_sbuf_rbuf(p, args, task->reduce_scatter_kn.scratch, block_count,
&sbuf, &rbuf);

step_radix = ucc_kn_compute_step_radix(p);
local_seg_index = ucc_kn_compute_seg_index(
rank, p->radix_pow, p);
local_seg_count = ucc_sra_kn_compute_seg_size(
block_count, step_radix, local_seg_index);
local_seg_offset = ucc_sra_kn_compute_seg_offset(
block_count, step_radix, local_seg_index);
local_data = PTR_OFFSET(sbuf, local_seg_offset * dt_size);
is_avg = args->op == UCC_OP_AVG &&
(avg_pre_op ? ucc_knomial_pattern_loop_first_iteration(p)
: ucc_knomial_pattern_loop_last_iteration(p));

ucc_sra_kn_get_offset_and_seglen(count, dt_size, rank, size, radix,
&offset, &local_seg_count);
eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY;
eargs.copy.dst = PTR_OFFSET(args->dst.info.buffer, offset);
eargs.copy.src = task->reduce_scatter_kn.scratch;
eargs.copy.len = local_seg_count * dt_size;
status = ucc_ee_executor_task_post(task->reduce_scatter_kn.executor, &eargs,
&task->reduce_scatter_kn.etask);
status = ucc_dt_reduce(local_data, rbuf,
PTR_OFFSET(args->dst.info.buffer, offset),
local_seg_count, dt, args,
is_avg ? UCC_EEE_TASK_FLAG_REDUCE_WITH_ALPHA : 0,
AVG_ALPHA(task),
task->reduce_scatter_kn.executor,
&task->reduce_scatter_kn.etask);
if (ucc_unlikely(status != UCC_OK)) {
tl_error(UCC_TASK_LIB(task), "failed to copy data to dst buffer");
tl_error(UCC_TASK_LIB(task), "failed to reduce data to dst buffer");
task->super.status = status;
return;
}
UCC_KN_PHASE_COMPLETE:
EXEC_TASK_TEST(UCC_KN_PHASE_COMPLETE, "failed to perform memcpy",
EXEC_TASK_TEST(UCC_KN_PHASE_COMPLETE, "failed to perform reduce",
task->reduce_scatter_kn.etask);
}
UCC_KN_PHASE_PROXY: /* unused label */
Expand All @@ -222,13 +236,14 @@ ucc_status_t ucc_tl_ucp_reduce_scatter_knomial_start(ucc_coll_task_t *coll_task)
UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_reduce_scatter_kn_start",
0);
ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);
ucc_knomial_pattern_init(size, rank, task->reduce_scatter_kn.p.radix,
&task->reduce_scatter_kn.p);
ucc_kn_rsx_pattern_init(size, rank, task->reduce_scatter_kn.p.radix,
args->dst.info.count, &task->reduce_scatter_kn.p);
if (!task->reduce_scatter_kn.scratch_mc_header) {
task->reduce_scatter_kn.scratch = args->dst.info.buffer;
}
task->reduce_scatter_kn.phase = UCC_KN_PHASE_INIT;
status = ucc_coll_task_get_executor(&task->super,

status = ucc_coll_task_get_executor(&task->super,
&task->reduce_scatter_kn.executor);
if (ucc_unlikely(status != UCC_OK)) {
return status;
Expand Down
2 changes: 1 addition & 1 deletion src/utils/ucc_dt_reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ ucc_dt_reduce_strided(void *src1, void *src2, void *dst, size_t n_vectors,
{
ucc_ee_executor_task_args_t eargs;

if (count == 0) {
if (count == 0 || n_vectors == 0) {
*task = NULL;
return UCC_OK;
}
Expand Down
2 changes: 1 addition & 1 deletion test/mpi/test_case.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ std::shared_ptr<TestCase> TestCase::init_single(ucc_test_team_t &_team,
void TestCase::run(bool triggered)
{
if (triggered) {
ucc_ee_h ee;
ucc_ee_h ee = nullptr;
ucc_ev_t comp_ev, *post_ev;
ucc_ee_type_t ee_type;

Expand Down

0 comments on commit cfad103

Please sign in to comment.