Skip to content

Commit

Permalink
V0.1.x: Adjust allgather, alltoall count (#281)
Browse files Browse the repository at this point in the history
* TL/UCP: Adjust allgather, alltoall count

* TEST: build fix

Co-authored-by: Lior Paz <liorpa@mellanox.com>
  • Loading branch information
valentin petrov and Lior Paz authored Jul 30, 2021
1 parent 05ef08d commit ca30502
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 127 deletions.
10 changes: 6 additions & 4 deletions src/components/tl/nccl/tl_nccl_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,9 @@ ucc_status_t ucc_tl_nccl_alltoall_start(ucc_coll_task_t *coll_task)
ucc_rank_t peer;

task->super.super.status = UCC_INPROGRESS;
data_size = (size_t)task->args.src.info.count *
data_size = (size_t)(task->args.src.info.count / gsize) *
ucc_dt_size(task->args.src.info.datatype);
ucc_assert(task->args.src.info.count % gsize == 0);
if (data_size == 0) {
task->super.super.status = UCC_OK;
return UCC_OK;
Expand Down Expand Up @@ -256,11 +257,12 @@ ucc_status_t ucc_tl_nccl_allgather_start(ucc_coll_task_t *coll_task)
size_t count = task->args.dst.info.count;

if (UCC_IS_INPLACE(task->args)) {
src = (void*)((ptrdiff_t)task->args.dst.info.buffer +
count * ucc_dt_size(task->args.dst.info.datatype) * team->rank);
src = (void*)((ptrdiff_t)task->args.dst.info.buffer + (count / team->size) *
ucc_dt_size(task->args.dst.info.datatype) * team->rank);
}
task->super.super.status = UCC_INPROGRESS;
NCCLCHECK_GOTO(ncclAllGather(src, dst, count, dt, team->nccl_comm, stream),
NCCLCHECK_GOTO(ncclAllGather(src, dst, count / team->size, dt,
team->nccl_comm, stream),
exit_coll, status, UCC_TL_TEAM_LIB(team));
status = ucc_mc_ee_event_post(stream, task->completed, UCC_EE_CUDA_STREAM);
exit_coll:
Expand Down
4 changes: 2 additions & 2 deletions src/components/tl/ucp/allgather/allgather_ring.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ ucc_status_t ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *coll_task)
ucc_memory_type_t rmem = task->args.dst.info.mem_type;
size_t count = task->args.dst.info.count;
ucc_datatype_t dt = task->args.dst.info.datatype;
size_t data_size = count * ucc_dt_size(dt);
size_t data_size = (count / team->size) * ucc_dt_size(dt);
ucc_rank_t sendto = (group_rank + 1) % group_size;
ucc_rank_t recvfrom = (group_rank - 1 + group_size) % group_size;
int step;
Expand Down Expand Up @@ -66,7 +66,7 @@ ucc_status_t ucc_tl_ucp_allgather_ring_start(ucc_coll_task_t *coll_task)
ucc_memory_type_t smem = task->args.src.info.mem_type;
ucc_memory_type_t rmem = task->args.dst.info.mem_type;
ucc_datatype_t dt = task->args.dst.info.datatype;
size_t data_size = count * ucc_dt_size(dt);
size_t data_size = (count / team->size) * ucc_dt_size(dt);
ucc_status_t status;

task->super.super.status = UCC_INPROGRESS;
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/ucp/alltoall/alltoall_pairwise.c
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ ucc_status_t ucc_tl_ucp_alltoall_pairwise_progress(ucc_coll_task_t *coll_task)

posts = UCC_TL_UCP_TEAM_LIB(team)->cfg.alltoall_pairwise_num_posts;
nreqs = (posts > gsize || posts == 0) ? gsize : posts;
data_size = (size_t)task->args.src.info.count *
data_size = (size_t)(task->args.src.info.count / gsize) *
ucc_dt_size(task->args.src.info.datatype);
while ((task->send_posted < gsize || task->recv_posted < gsize) &&
(polls++ < task->n_polls)) {
Expand Down
95 changes: 49 additions & 46 deletions test/gtest/core/test_allgather.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,59 +8,62 @@ extern "C" {
}
#include "common/test_ucc.h"
#include "utils/ucc_math.h"
#include "utils/ucc_malloc.h"

using Param_0 = std::tuple<int, int, ucc_memory_type_t, int, gtest_ucc_inplace_t>;
using Param_1 = std::tuple<int, ucc_memory_type_t, int, gtest_ucc_inplace_t>;

class test_allgather : public UccCollArgs, public ucc::test
{
public:
void data_init(int nprocs, ucc_datatype_t dtype, size_t count,
UccCollCtxVec &ctxs)
{
ctxs.resize(nprocs);
for (auto r = 0; r < nprocs; r++) {
ucc_coll_args_t *coll = (ucc_coll_args_t*)
calloc(1, sizeof(ucc_coll_args_t));
ctxs[r] = (gtest_ucc_coll_ctx_t*)calloc(1, sizeof(gtest_ucc_coll_ctx_t));
ctxs[r]->args = coll;

coll->mask = 0;
coll->flags = 0;
coll->coll_type = UCC_COLL_TYPE_ALLGATHER;
coll->src.info.mem_type = mem_type;
coll->src.info.count = (ucc_count_t)count;
coll->src.info.datatype = dtype;
coll->dst.info.mem_type = mem_type;
coll->dst.info.count = (ucc_count_t)count;
coll->dst.info.datatype = dtype;

UCC_CHECK(ucc_mc_alloc(&ctxs[r]->init_buf,
ucc_dt_size(dtype) * count,
UCC_MEMORY_TYPE_HOST));
uint8_t *sbuf = (uint8_t*)ctxs[r]->init_buf;
for (int i = 0; i < ucc_dt_size(dtype) * count; i++) {
sbuf[i] = r;
}

ctxs[r]->rbuf_size = ucc_dt_size(dtype) * count * nprocs;
UCC_CHECK(ucc_mc_alloc(&coll->dst.info.buffer, ctxs[r]->rbuf_size,
mem_type));
if (TEST_INPLACE == inplace) {
coll->mask |= UCC_COLL_ARGS_FIELD_FLAGS;
coll->flags |= UCC_COLL_ARGS_FLAG_IN_PLACE;
UCC_CHECK(ucc_mc_memcpy((void*)((ptrdiff_t)coll->dst.info.buffer +
r * count * ucc_dt_size(dtype)),
ctxs[r]->init_buf, ucc_dt_size(dtype) * count,
mem_type, UCC_MEMORY_TYPE_HOST));
} else {
UCC_CHECK(ucc_mc_alloc(&coll->src.info.buffer,
ucc_dt_size(dtype) * count, mem_type));
UCC_CHECK(ucc_mc_memcpy(coll->src.info.buffer, ctxs[r]->init_buf,
ucc_dt_size(dtype) * count, mem_type,
UCC_MEMORY_TYPE_HOST));
}
}
void data_init(int nprocs, ucc_datatype_t dtype, size_t single_rank_count,
UccCollCtxVec &ctxs)
{
ctxs.resize(nprocs);
for (auto r = 0; r < nprocs; r++) {
ucc_coll_args_t *coll =
(ucc_coll_args_t *)calloc(1, sizeof(ucc_coll_args_t));
ctxs[r] =
(gtest_ucc_coll_ctx_t *)calloc(1, sizeof(gtest_ucc_coll_ctx_t));
ctxs[r]->args = coll;

coll->mask = 0;
coll->flags = 0;
coll->coll_type = UCC_COLL_TYPE_ALLGATHER;
coll->src.info.mem_type = mem_type;
coll->src.info.count = (ucc_count_t)single_rank_count;
coll->src.info.datatype = dtype;
coll->dst.info.mem_type = mem_type;
coll->dst.info.count = (ucc_count_t)single_rank_count * nprocs;
coll->dst.info.datatype = dtype;

ctxs[r]->init_buf =
ucc_malloc(ucc_dt_size(dtype) * single_rank_count, "init buf");
EXPECT_NE(ctxs[r]->init_buf, nullptr);
uint8_t *sbuf = (uint8_t *)ctxs[r]->init_buf;
for (int i = 0; i < ucc_dt_size(dtype) * single_rank_count; i++) {
sbuf[i] = r;
}

ctxs[r]->rbuf_size = ucc_dt_size(dtype) * single_rank_count * nprocs;
UCC_CHECK(ucc_mc_alloc(&coll->dst.info.buffer, ctxs[r]->rbuf_size,
mem_type));
if (TEST_INPLACE == inplace) {
coll->mask |= UCC_COLL_ARGS_FIELD_FLAGS;
coll->flags |= UCC_COLL_ARGS_FLAG_IN_PLACE;
UCC_CHECK(ucc_mc_memcpy(
(void *)((ptrdiff_t)coll->dst.info.buffer +
r * single_rank_count * ucc_dt_size(dtype)),
ctxs[r]->init_buf, ucc_dt_size(dtype) * single_rank_count,
mem_type, UCC_MEMORY_TYPE_HOST));
} else {
UCC_CHECK(ucc_mc_alloc(&coll->src.info.buffer,
ucc_dt_size(dtype) * single_rank_count, mem_type));
UCC_CHECK(ucc_mc_memcpy(coll->src.info.buffer, ctxs[r]->init_buf,
ucc_dt_size(dtype) * single_rank_count,
mem_type, UCC_MEMORY_TYPE_HOST));
}
}
}
void data_fini(UccCollCtxVec ctxs)
{
Expand Down
105 changes: 55 additions & 50 deletions test/gtest/core/test_alltoall.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,60 +8,65 @@ extern "C" {
}
#include "common/test_ucc.h"
#include "utils/ucc_math.h"
#include "utils/ucc_malloc.h"

using Param_0 = std::tuple<int, int, ucc_memory_type_t, gtest_ucc_inplace_t, int>;
using Param_1 = std::tuple<int, ucc_memory_type_t, gtest_ucc_inplace_t, int>;

class test_alltoall : public UccCollArgs, public ucc::test
{
public:
void data_init(int nprocs, ucc_datatype_t dtype, size_t count,
UccCollCtxVec &ctxs)
{
ctxs.resize(nprocs);
for (auto i = 0; i < nprocs; i++) {
ucc_coll_args_t *coll = (ucc_coll_args_t*)
calloc(1, sizeof(ucc_coll_args_t));

ctxs[i] = (gtest_ucc_coll_ctx_t*)calloc(1, sizeof(gtest_ucc_coll_ctx_t));
ctxs[i]->args = coll;

coll->mask = 0;
coll->coll_type = UCC_COLL_TYPE_ALLTOALL;
coll->src.info.mem_type = mem_type;
coll->src.info.count = (ucc_count_t)count;
coll->src.info.datatype = dtype;
coll->dst.info.mem_type = mem_type;
coll->dst.info.count = (ucc_count_t)count;
coll->dst.info.datatype = dtype;

UCC_CHECK(ucc_mc_alloc(&ctxs[i]->init_buf,
ucc_dt_size(dtype) * count * nprocs,
UCC_MEMORY_TYPE_HOST));
for (int r = 0; r < nprocs; r++) {
size_t rank_size = ucc_dt_size(dtype) * count;
alltoallx_init_buf(r, i,
(uint8_t*)ctxs[i]->init_buf + r * rank_size,
rank_size);
}

UCC_CHECK(ucc_mc_alloc(&coll->dst.info.buffer,
ucc_dt_size(dtype) * count * nprocs, mem_type));
if (TEST_INPLACE == inplace) {
coll->mask |= UCC_COLL_ARGS_FIELD_FLAGS;
coll->flags |= UCC_COLL_ARGS_FLAG_IN_PLACE;
UCC_CHECK(ucc_mc_memcpy(coll->dst.info.buffer, ctxs[i]->init_buf,
ucc_dt_size(dtype) * count, mem_type,
UCC_MEMORY_TYPE_HOST));
} else {
UCC_CHECK(ucc_mc_alloc(&coll->src.info.buffer,
ucc_dt_size(dtype) * count * nprocs, mem_type));
UCC_CHECK(ucc_mc_memcpy(coll->src.info.buffer, ctxs[i]->init_buf,
ucc_dt_size(dtype) * count * nprocs, mem_type,
UCC_MEMORY_TYPE_HOST));
}
}
void data_init(int nprocs, ucc_datatype_t dtype, size_t single_rank_count,
UccCollCtxVec &ctxs)
{
ctxs.resize(nprocs);
for (auto i = 0; i < nprocs; i++) {
ucc_coll_args_t *coll =
(ucc_coll_args_t *)calloc(1, sizeof(ucc_coll_args_t));

ctxs[i] =
(gtest_ucc_coll_ctx_t *)calloc(1, sizeof(gtest_ucc_coll_ctx_t));
ctxs[i]->args = coll;

coll->mask = 0;
coll->coll_type = UCC_COLL_TYPE_ALLTOALL;
coll->src.info.mem_type = mem_type;
coll->src.info.count = (ucc_count_t)single_rank_count * nprocs;
coll->src.info.datatype = dtype;
coll->dst.info.mem_type = mem_type;
coll->dst.info.count = (ucc_count_t)single_rank_count * nprocs;
coll->dst.info.datatype = dtype;

ctxs[i]->init_buf = ucc_malloc(
ucc_dt_size(dtype) * single_rank_count * nprocs, "init buf");
EXPECT_NE(ctxs[i]->init_buf, nullptr);
for (int r = 0; r < nprocs; r++) {
size_t rank_size = ucc_dt_size(dtype) * single_rank_count;
alltoallx_init_buf(r, i,
(uint8_t *)ctxs[i]->init_buf + r * rank_size,
rank_size);
}

UCC_CHECK(ucc_mc_alloc(&coll->dst.info.buffer,
ucc_dt_size(dtype) * single_rank_count * nprocs, mem_type));
if (TEST_INPLACE == inplace) {
coll->mask |= UCC_COLL_ARGS_FIELD_FLAGS;
coll->flags |= UCC_COLL_ARGS_FLAG_IN_PLACE;
UCC_CHECK(ucc_mc_memcpy(coll->dst.info.buffer, ctxs[i]->init_buf,
ucc_dt_size(dtype) * single_rank_count,
mem_type, UCC_MEMORY_TYPE_HOST));
} else {
UCC_CHECK(ucc_mc_alloc(&coll->src.info.buffer,
ucc_dt_size(dtype) * single_rank_count * nprocs, mem_type));
UCC_CHECK(
ucc_mc_memcpy(coll->src.info.buffer, ctxs[i]->init_buf,
ucc_dt_size(dtype) * single_rank_count * nprocs,
mem_type, UCC_MEMORY_TYPE_HOST));
}
}
}

void data_fini(UccCollCtxVec ctxs)
{
for (gtest_ucc_coll_ctx_t* ctx : ctxs) {
Expand All @@ -83,10 +88,10 @@ class test_alltoall : public UccCollArgs, public ucc::test
if (UCC_MEMORY_TYPE_HOST != mem_type) {
for (int r = 0; r < ctxs.size(); r++) {
size_t buf_size =
ucc_dt_size(ctxs[r]->args->dst.info.datatype) *
(size_t)ctxs[r]->args->dst.info.count * ctxs.size();
UCC_CHECK(ucc_mc_alloc((void**)&dsts[r], buf_size,
UCC_MEMORY_TYPE_HOST));
ucc_dt_size(ctxs[r]->args->dst.info.datatype) *
(size_t)ctxs[r]->args->dst.info.count;
dsts[r] = (uint8_t *) ucc_malloc(buf_size, "dsts buf");
EXPECT_NE(dsts[r], nullptr);
UCC_CHECK(ucc_mc_memcpy(dsts[r], ctxs[r]->args->dst.info.buffer,
buf_size, UCC_MEMORY_TYPE_HOST, mem_type));
}
Expand All @@ -100,7 +105,7 @@ class test_alltoall : public UccCollArgs, public ucc::test

for (int i = 0; i < ctxs.size(); i++) {
size_t rank_size = ucc_dt_size(coll->dst.info.datatype) *
(size_t)coll->dst.info.count;
(size_t)(coll->dst.info.count / ctxs.size());
ASSERT_EQ(0,
alltoallx_validate_buf(i, r,
(uint8_t*)dsts[r] + rank_size * i,
Expand Down
30 changes: 16 additions & 14 deletions test/mpi/test_allgather.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ TestAllgather::TestAllgather(size_t _msgsize, ucc_test_mpi_inplace_t _inplace,
TestCase(_team, _mt, _msgsize, _inplace, _max_size)
{
size_t dt_size = ucc_dt_size(TEST_DT);
size_t count = _msgsize/dt_size;
size_t single_rank_count = _msgsize / dt_size;
int rank, size;
MPI_Comm_rank(team.comm, &rank);
MPI_Comm_size(team.comm, &size);
Expand All @@ -33,36 +33,38 @@ TestAllgather::TestAllgather(size_t _msgsize, ucc_test_mpi_inplace_t _inplace,
UCC_CHECK(ucc_mc_alloc(&check_rbuf, _msgsize*size, UCC_MEMORY_TYPE_HOST));
if (TEST_NO_INPLACE == inplace) {
UCC_CHECK(ucc_mc_alloc(&sbuf, _msgsize, _mt));
init_buffer(sbuf, count, TEST_DT, _mt, rank);
init_buffer(sbuf, single_rank_count, TEST_DT, _mt, rank);
UCC_ALLOC_COPY_BUF(check_sbuf, UCC_MEMORY_TYPE_HOST, sbuf, _mt, _msgsize);
} else {
args.mask = UCC_COLL_ARGS_FIELD_FLAGS;
args.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
init_buffer((void*)((ptrdiff_t)rbuf + rank*count*dt_size),
count, TEST_DT, _mt, rank);
init_buffer((void*)((ptrdiff_t)check_rbuf + rank*count*dt_size),
count, TEST_DT, UCC_MEMORY_TYPE_HOST, rank);
init_buffer(
(void *)((ptrdiff_t)rbuf + rank * single_rank_count * dt_size),
single_rank_count, TEST_DT, _mt, rank);
init_buffer((void *)((ptrdiff_t)check_rbuf +
rank * single_rank_count * dt_size),
single_rank_count, TEST_DT, UCC_MEMORY_TYPE_HOST, rank);
}

args.src.info.buffer = sbuf;
args.src.info.count = count;
args.src.info.count = single_rank_count;
args.src.info.datatype = TEST_DT;
args.src.info.mem_type = _mt;
args.dst.info.buffer = rbuf;
args.dst.info.count = count;
args.dst.info.count = single_rank_count * size;
args.dst.info.datatype = TEST_DT;
args.dst.info.mem_type = _mt;
UCC_CHECK(ucc_collective_init(&args, &req, team.team));
}

ucc_status_t TestAllgather::check()
{
size_t count = args.dst.info.count;
int size;
MPI_Comm_size(team.comm, &size);
size_t single_rank_count = args.dst.info.count / size;
MPI_Datatype dt = ucc_dt_to_mpi(TEST_DT);
int size;

MPI_Comm_size(team.comm, &size);
MPI_Allgather(inplace ? MPI_IN_PLACE : check_sbuf, count, dt,
check_rbuf, count, dt, team.comm);
return compare_buffers(rbuf, check_rbuf, count*size, TEST_DT, mem_type);
MPI_Allgather(inplace ? MPI_IN_PLACE : check_sbuf, single_rank_count, dt,
check_rbuf, single_rank_count, dt, team.comm);
return compare_buffers(rbuf, check_rbuf, single_rank_count * size, TEST_DT, mem_type);
}
Loading

0 comments on commit ca30502

Please sign in to comment.