diff --git a/src/components/tl/nccl/tl_nccl_coll.c b/src/components/tl/nccl/tl_nccl_coll.c index d663598786..e6e0c6b04b 100644 --- a/src/components/tl/nccl/tl_nccl_coll.c +++ b/src/components/tl/nccl/tl_nccl_coll.c @@ -80,8 +80,9 @@ ucc_status_t ucc_tl_nccl_alltoall_start(ucc_coll_task_t *coll_task) ucc_rank_t peer; task->super.super.status = UCC_INPROGRESS; - data_size = (size_t)task->args.src.info.count * + data_size = (size_t)(task->args.src.info.count / gsize) * ucc_dt_size(task->args.src.info.datatype); + ucc_assert(task->args.src.info.count % gsize == 0); if (data_size == 0) { task->super.super.status = UCC_OK; return UCC_OK; @@ -256,11 +257,12 @@ ucc_status_t ucc_tl_nccl_allgather_start(ucc_coll_task_t *coll_task) size_t count = task->args.dst.info.count; if (UCC_IS_INPLACE(task->args)) { - src = (void*)((ptrdiff_t)task->args.dst.info.buffer + - count * ucc_dt_size(task->args.dst.info.datatype) * team->rank); + src = (void*)((ptrdiff_t)task->args.dst.info.buffer + (count / team->size) * + ucc_dt_size(task->args.dst.info.datatype) * team->rank); } task->super.super.status = UCC_INPROGRESS; - NCCLCHECK_GOTO(ncclAllGather(src, dst, count, dt, team->nccl_comm, stream), + NCCLCHECK_GOTO(ncclAllGather(src, dst, count / team->size, dt, + team->nccl_comm, stream), exit_coll, status, UCC_TL_TEAM_LIB(team)); status = ucc_mc_ee_event_post(stream, task->completed, UCC_EE_CUDA_STREAM); exit_coll: diff --git a/src/components/tl/ucp/allgather/allgather_ring.c b/src/components/tl/ucp/allgather/allgather_ring.c index bb4ee6320c..c50ab21420 100644 --- a/src/components/tl/ucp/allgather/allgather_ring.c +++ b/src/components/tl/ucp/allgather/allgather_ring.c @@ -23,7 +23,7 @@ ucc_status_t ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *coll_task) ucc_memory_type_t rmem = task->args.dst.info.mem_type; size_t count = task->args.dst.info.count; ucc_datatype_t dt = task->args.dst.info.datatype; - size_t data_size = count * ucc_dt_size(dt); + size_t data_size = (count / team->size) * ucc_dt_size(dt); ucc_rank_t sendto = (group_rank + 1) % group_size; ucc_rank_t recvfrom = (group_rank - 1 + group_size) % group_size; int step; @@ -66,7 +66,7 @@ ucc_status_t ucc_tl_ucp_allgather_ring_start(ucc_coll_task_t *coll_task) ucc_memory_type_t smem = task->args.src.info.mem_type; ucc_memory_type_t rmem = task->args.dst.info.mem_type; ucc_datatype_t dt = task->args.dst.info.datatype; - size_t data_size = count * ucc_dt_size(dt); + size_t data_size = (count / team->size) * ucc_dt_size(dt); ucc_status_t status; task->super.super.status = UCC_INPROGRESS; diff --git a/src/components/tl/ucp/alltoall/alltoall_pairwise.c b/src/components/tl/ucp/alltoall/alltoall_pairwise.c index 8bb869efa5..57eee8f128 100644 --- a/src/components/tl/ucp/alltoall/alltoall_pairwise.c +++ b/src/components/tl/ucp/alltoall/alltoall_pairwise.c @@ -40,7 +40,7 @@ ucc_status_t ucc_tl_ucp_alltoall_pairwise_progress(ucc_coll_task_t *coll_task) posts = UCC_TL_UCP_TEAM_LIB(team)->cfg.alltoall_pairwise_num_posts; nreqs = (posts > gsize || posts == 0) ? gsize : posts; - data_size = (size_t)task->args.src.info.count * + data_size = (size_t)(task->args.src.info.count / gsize) * ucc_dt_size(task->args.src.info.datatype); while ((task->send_posted < gsize || task->recv_posted < gsize) && (polls++ < task->n_polls)) { diff --git a/test/gtest/core/test_allgather.cc b/test/gtest/core/test_allgather.cc index 0645019db7..07d87b60b4 100644 --- a/test/gtest/core/test_allgather.cc +++ b/test/gtest/core/test_allgather.cc @@ -8,6 +8,7 @@ extern "C" { } #include "common/test_ucc.h" #include "utils/ucc_math.h" +#include "utils/ucc_malloc.h" using Param_0 = std::tuple; using Param_1 = std::tuple; @@ -15,52 +16,54 @@ using Param_1 = std::tuple; class test_allgather : public UccCollArgs, public ucc::test { public: - void data_init(int nprocs, ucc_datatype_t dtype, size_t count, - UccCollCtxVec &ctxs) - { - ctxs.resize(nprocs); - for (auto r = 0; r < nprocs; r++) { - ucc_coll_args_t *coll = (ucc_coll_args_t*) - calloc(1, sizeof(ucc_coll_args_t)); - ctxs[r] = (gtest_ucc_coll_ctx_t*)calloc(1, sizeof(gtest_ucc_coll_ctx_t)); - ctxs[r]->args = coll; - - coll->mask = 0; - coll->flags = 0; - coll->coll_type = UCC_COLL_TYPE_ALLGATHER; - coll->src.info.mem_type = mem_type; - coll->src.info.count = (ucc_count_t)count; - coll->src.info.datatype = dtype; - coll->dst.info.mem_type = mem_type; - coll->dst.info.count = (ucc_count_t)count; - coll->dst.info.datatype = dtype; - - UCC_CHECK(ucc_mc_alloc(&ctxs[r]->init_buf, - ucc_dt_size(dtype) * count, - UCC_MEMORY_TYPE_HOST)); - uint8_t *sbuf = (uint8_t*)ctxs[r]->init_buf; - for (int i = 0; i < ucc_dt_size(dtype) * count; i++) { - sbuf[i] = r; - } - - ctxs[r]->rbuf_size = ucc_dt_size(dtype) * count * nprocs; - UCC_CHECK(ucc_mc_alloc(&coll->dst.info.buffer, ctxs[r]->rbuf_size, - mem_type)); - if (TEST_INPLACE == inplace) { - coll->mask |= UCC_COLL_ARGS_FIELD_FLAGS; - coll->flags |= UCC_COLL_ARGS_FLAG_IN_PLACE; - UCC_CHECK(ucc_mc_memcpy((void*)((ptrdiff_t)coll->dst.info.buffer + - r * count * ucc_dt_size(dtype)), - ctxs[r]->init_buf, ucc_dt_size(dtype) * count, - mem_type, UCC_MEMORY_TYPE_HOST)); - } else { - UCC_CHECK(ucc_mc_alloc(&coll->src.info.buffer, - ucc_dt_size(dtype) * count, mem_type)); - UCC_CHECK(ucc_mc_memcpy(coll->src.info.buffer, ctxs[r]->init_buf, - ucc_dt_size(dtype) * count, mem_type, - UCC_MEMORY_TYPE_HOST)); - } - } + void data_init(int nprocs, ucc_datatype_t dtype, size_t single_rank_count, + UccCollCtxVec &ctxs) + { + ctxs.resize(nprocs); + for (auto r = 0; r < nprocs; r++) { + ucc_coll_args_t *coll = + (ucc_coll_args_t *)calloc(1, sizeof(ucc_coll_args_t)); + ctxs[r] = + (gtest_ucc_coll_ctx_t *)calloc(1, sizeof(gtest_ucc_coll_ctx_t)); + ctxs[r]->args = coll; + + coll->mask = 0; + coll->flags = 0; + coll->coll_type = UCC_COLL_TYPE_ALLGATHER; + coll->src.info.mem_type = mem_type; + coll->src.info.count = (ucc_count_t)single_rank_count; + coll->src.info.datatype = dtype; + coll->dst.info.mem_type = mem_type; + coll->dst.info.count = (ucc_count_t)single_rank_count * nprocs; + coll->dst.info.datatype = dtype; + + ctxs[r]->init_buf = + ucc_malloc(ucc_dt_size(dtype) * single_rank_count, "init buf"); + EXPECT_NE(ctxs[r]->init_buf, nullptr); + uint8_t *sbuf = (uint8_t *)ctxs[r]->init_buf; + for (int i = 0; i < ucc_dt_size(dtype) * single_rank_count; i++) { + sbuf[i] = r; + } + + ctxs[r]->rbuf_size = ucc_dt_size(dtype) * single_rank_count * nprocs; + UCC_CHECK(ucc_mc_alloc(&coll->dst.info.buffer, ctxs[r]->rbuf_size, + mem_type)); + if (TEST_INPLACE == inplace) { + coll->mask |= UCC_COLL_ARGS_FIELD_FLAGS; + coll->flags |= UCC_COLL_ARGS_FLAG_IN_PLACE; + UCC_CHECK(ucc_mc_memcpy( + (void *)((ptrdiff_t)coll->dst.info.buffer + + r * single_rank_count * ucc_dt_size(dtype)), + ctxs[r]->init_buf, ucc_dt_size(dtype) * single_rank_count, + mem_type, UCC_MEMORY_TYPE_HOST)); + } else { + UCC_CHECK(ucc_mc_alloc(&coll->src.info.buffer, + ucc_dt_size(dtype) * single_rank_count, mem_type)); + UCC_CHECK(ucc_mc_memcpy(coll->src.info.buffer, ctxs[r]->init_buf, + ucc_dt_size(dtype) * single_rank_count, + mem_type, UCC_MEMORY_TYPE_HOST)); + } + } } void data_fini(UccCollCtxVec ctxs) { diff --git a/test/gtest/core/test_alltoall.cc b/test/gtest/core/test_alltoall.cc index a62ba16834..ea7620c4bf 100644 --- a/test/gtest/core/test_alltoall.cc +++ b/test/gtest/core/test_alltoall.cc @@ -8,6 +8,7 @@ extern "C" { } #include "common/test_ucc.h" #include "utils/ucc_math.h" +#include "utils/ucc_malloc.h" using Param_0 = std::tuple; using Param_1 = std::tuple; @@ -15,53 +16,57 @@ using Param_1 = std::tuple; class test_alltoall : public UccCollArgs, public ucc::test { public: - void data_init(int nprocs, ucc_datatype_t dtype, size_t count, - UccCollCtxVec &ctxs) - { - ctxs.resize(nprocs); - for (auto i = 0; i < nprocs; i++) { - ucc_coll_args_t *coll = (ucc_coll_args_t*) - calloc(1, sizeof(ucc_coll_args_t)); - - ctxs[i] = (gtest_ucc_coll_ctx_t*)calloc(1, sizeof(gtest_ucc_coll_ctx_t)); - ctxs[i]->args = coll; - - coll->mask = 0; - coll->coll_type = UCC_COLL_TYPE_ALLTOALL; - coll->src.info.mem_type = mem_type; - coll->src.info.count = (ucc_count_t)count; - coll->src.info.datatype = dtype; - coll->dst.info.mem_type = mem_type; - coll->dst.info.count = (ucc_count_t)count; - coll->dst.info.datatype = dtype; - - UCC_CHECK(ucc_mc_alloc(&ctxs[i]->init_buf, - ucc_dt_size(dtype) * count * nprocs, - UCC_MEMORY_TYPE_HOST)); - for (int r = 0; r < nprocs; r++) { - size_t rank_size = ucc_dt_size(dtype) * count; - alltoallx_init_buf(r, i, - (uint8_t*)ctxs[i]->init_buf + r * rank_size, - rank_size); - } - UCC_CHECK(ucc_mc_alloc(&coll->dst.info.buffer, - ucc_dt_size(dtype) * count * nprocs, mem_type)); - if (TEST_INPLACE == inplace) { - coll->mask |= UCC_COLL_ARGS_FIELD_FLAGS; - coll->flags |= UCC_COLL_ARGS_FLAG_IN_PLACE; - UCC_CHECK(ucc_mc_memcpy(coll->dst.info.buffer, ctxs[i]->init_buf, - ucc_dt_size(dtype) * count, mem_type, - UCC_MEMORY_TYPE_HOST)); - } else { - UCC_CHECK(ucc_mc_alloc(&coll->src.info.buffer, - ucc_dt_size(dtype) * count * nprocs, mem_type)); - UCC_CHECK(ucc_mc_memcpy(coll->src.info.buffer, ctxs[i]->init_buf, - ucc_dt_size(dtype) * count * nprocs, mem_type, - UCC_MEMORY_TYPE_HOST)); - } - } + void data_init(int nprocs, ucc_datatype_t dtype, size_t single_rank_count, + UccCollCtxVec &ctxs) + { + ctxs.resize(nprocs); + for (auto i = 0; i < nprocs; i++) { + ucc_coll_args_t *coll = + (ucc_coll_args_t *)calloc(1, sizeof(ucc_coll_args_t)); + + ctxs[i] = + (gtest_ucc_coll_ctx_t *)calloc(1, sizeof(gtest_ucc_coll_ctx_t)); + ctxs[i]->args = coll; + + coll->mask = 0; + coll->coll_type = UCC_COLL_TYPE_ALLTOALL; + coll->src.info.mem_type = mem_type; + coll->src.info.count = (ucc_count_t)single_rank_count * nprocs; + coll->src.info.datatype = dtype; + coll->dst.info.mem_type = mem_type; + coll->dst.info.count = (ucc_count_t)single_rank_count * nprocs; + coll->dst.info.datatype = dtype; + + ctxs[i]->init_buf = ucc_malloc( + ucc_dt_size(dtype) * single_rank_count * nprocs, "init buf"); + EXPECT_NE(ctxs[i]->init_buf, nullptr); + for (int r = 0; r < nprocs; r++) { + size_t rank_size = ucc_dt_size(dtype) * single_rank_count; + alltoallx_init_buf(r, i, + (uint8_t *)ctxs[i]->init_buf + r * rank_size, + rank_size); + } + + UCC_CHECK(ucc_mc_alloc(&coll->dst.info.buffer, + ucc_dt_size(dtype) * single_rank_count * nprocs, mem_type)); + if (TEST_INPLACE == inplace) { + coll->mask |= UCC_COLL_ARGS_FIELD_FLAGS; + coll->flags |= UCC_COLL_ARGS_FLAG_IN_PLACE; + UCC_CHECK(ucc_mc_memcpy(coll->dst.info.buffer, ctxs[i]->init_buf, + ucc_dt_size(dtype) * single_rank_count, + mem_type, UCC_MEMORY_TYPE_HOST)); + } else { + UCC_CHECK(ucc_mc_alloc(&coll->src.info.buffer, + ucc_dt_size(dtype) * single_rank_count * nprocs, mem_type)); + UCC_CHECK( + ucc_mc_memcpy(coll->src.info.buffer, ctxs[i]->init_buf, + ucc_dt_size(dtype) * single_rank_count * nprocs, + mem_type, UCC_MEMORY_TYPE_HOST)); + } + } } + void data_fini(UccCollCtxVec ctxs) { for (gtest_ucc_coll_ctx_t* ctx : ctxs) { @@ -83,10 +88,10 @@ class test_alltoall : public UccCollArgs, public ucc::test if (UCC_MEMORY_TYPE_HOST != mem_type) { for (int r = 0; r < ctxs.size(); r++) { size_t buf_size = - ucc_dt_size(ctxs[r]->args->dst.info.datatype) * - (size_t)ctxs[r]->args->dst.info.count * ctxs.size(); - UCC_CHECK(ucc_mc_alloc((void**)&dsts[r], buf_size, - UCC_MEMORY_TYPE_HOST)); + ucc_dt_size(ctxs[r]->args->dst.info.datatype) * + (size_t)ctxs[r]->args->dst.info.count; + dsts[r] = (uint8_t *) ucc_malloc(buf_size, "dsts buf"); + EXPECT_NE(dsts[r], nullptr); UCC_CHECK(ucc_mc_memcpy(dsts[r], ctxs[r]->args->dst.info.buffer, buf_size, UCC_MEMORY_TYPE_HOST, mem_type)); } @@ -100,7 +105,7 @@ class test_alltoall : public UccCollArgs, public ucc::test for (int i = 0; i < ctxs.size(); i++) { size_t rank_size = ucc_dt_size(coll->dst.info.datatype) * - (size_t)coll->dst.info.count; + (size_t)(coll->dst.info.count / ctxs.size()); ASSERT_EQ(0, alltoallx_validate_buf(i, r, (uint8_t*)dsts[r] + rank_size * i, diff --git a/test/mpi/test_allgather.cc b/test/mpi/test_allgather.cc index 8991f8454c..daee3e05be 100644 --- a/test/mpi/test_allgather.cc +++ b/test/mpi/test_allgather.cc @@ -15,7 +15,7 @@ TestAllgather::TestAllgather(size_t _msgsize, ucc_test_mpi_inplace_t _inplace, TestCase(_team, _mt, _msgsize, _inplace, _max_size) { size_t dt_size = ucc_dt_size(TEST_DT); - size_t count = _msgsize/dt_size; + size_t single_rank_count = _msgsize / dt_size; int rank, size; MPI_Comm_rank(team.comm, &rank); MPI_Comm_size(team.comm, &size); @@ -33,23 +33,25 @@ TestAllgather::TestAllgather(size_t _msgsize, ucc_test_mpi_inplace_t _inplace, UCC_CHECK(ucc_mc_alloc(&check_rbuf, _msgsize*size, UCC_MEMORY_TYPE_HOST)); if (TEST_NO_INPLACE == inplace) { UCC_CHECK(ucc_mc_alloc(&sbuf, _msgsize, _mt)); - init_buffer(sbuf, count, TEST_DT, _mt, rank); + init_buffer(sbuf, single_rank_count, TEST_DT, _mt, rank); UCC_ALLOC_COPY_BUF(check_sbuf, UCC_MEMORY_TYPE_HOST, sbuf, _mt, _msgsize); } else { args.mask = UCC_COLL_ARGS_FIELD_FLAGS; args.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; - init_buffer((void*)((ptrdiff_t)rbuf + rank*count*dt_size), - count, TEST_DT, _mt, rank); - init_buffer((void*)((ptrdiff_t)check_rbuf + rank*count*dt_size), - count, TEST_DT, UCC_MEMORY_TYPE_HOST, rank); + init_buffer( + (void *)((ptrdiff_t)rbuf + rank * single_rank_count * dt_size), + single_rank_count, TEST_DT, _mt, rank); + init_buffer((void *)((ptrdiff_t)check_rbuf + + rank * single_rank_count * dt_size), + single_rank_count, TEST_DT, UCC_MEMORY_TYPE_HOST, rank); } args.src.info.buffer = sbuf; - args.src.info.count = count; + args.src.info.count = single_rank_count; args.src.info.datatype = TEST_DT; args.src.info.mem_type = _mt; args.dst.info.buffer = rbuf; - args.dst.info.count = count; + args.dst.info.count = single_rank_count * size; args.dst.info.datatype = TEST_DT; args.dst.info.mem_type = _mt; UCC_CHECK(ucc_collective_init(&args, &req, team.team)); @@ -57,12 +59,12 @@ TestAllgather::TestAllgather(size_t _msgsize, ucc_test_mpi_inplace_t _inplace, ucc_status_t TestAllgather::check() { - size_t count = args.dst.info.count; + int size; + MPI_Comm_size(team.comm, &size); + size_t single_rank_count = args.dst.info.count / size; MPI_Datatype dt = ucc_dt_to_mpi(TEST_DT); - int size; - MPI_Comm_size(team.comm, &size); - MPI_Allgather(inplace ? MPI_IN_PLACE : check_sbuf, count, dt, - check_rbuf, count, dt, team.comm); - return compare_buffers(rbuf, check_rbuf, count*size, TEST_DT, mem_type); + MPI_Allgather(inplace ? MPI_IN_PLACE : check_sbuf, single_rank_count, dt, + check_rbuf, single_rank_count, dt, team.comm); + return compare_buffers(rbuf, check_rbuf, single_rank_count * size, TEST_DT, mem_type); } diff --git a/test/mpi/test_alltoall.cc b/test/mpi/test_alltoall.cc index 0ba59c45ec..7019e6881a 100644 --- a/test/mpi/test_alltoall.cc +++ b/test/mpi/test_alltoall.cc @@ -15,7 +15,7 @@ TestAlltoall::TestAlltoall(size_t _msgsize, ucc_test_mpi_inplace_t _inplace, TestCase(_team, _mt, _msgsize, _inplace, _max_size) { size_t dt_size = ucc_dt_size(TEST_DT); - size_t count = _msgsize/dt_size; + size_t single_rank_count = _msgsize / dt_size; int rank; int nprocs; @@ -38,23 +38,24 @@ TestAlltoall::TestAlltoall(size_t _msgsize, ucc_test_mpi_inplace_t _inplace, UCC_CHECK(ucc_mc_alloc(&check_rbuf, _msgsize * nprocs, UCC_MEMORY_TYPE_HOST)); if (TEST_NO_INPLACE == inplace) { UCC_CHECK(ucc_mc_alloc(&sbuf, _msgsize * nprocs, _mt)); - init_buffer(sbuf, count * nprocs, TEST_DT, _mt, rank); + init_buffer(sbuf, single_rank_count * nprocs, TEST_DT, _mt, rank); UCC_ALLOC_COPY_BUF(check_sbuf, UCC_MEMORY_TYPE_HOST, sbuf, _mt, _msgsize * nprocs); } else { args.mask = UCC_COLL_ARGS_FIELD_FLAGS; args.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; - init_buffer(rbuf, count * nprocs, TEST_DT, _mt, rank); - init_buffer(check_rbuf, count * nprocs, TEST_DT, UCC_MEMORY_TYPE_HOST, rank); + init_buffer(rbuf, single_rank_count * nprocs, TEST_DT, _mt, rank); + init_buffer(check_rbuf, single_rank_count * nprocs, TEST_DT, + UCC_MEMORY_TYPE_HOST, rank); } args.src.info.buffer = sbuf; - args.src.info.count = count; + args.src.info.count = single_rank_count * nprocs; args.src.info.datatype = TEST_DT; args.src.info.mem_type = _mt; args.dst.info.buffer = rbuf; - args.dst.info.count = count; + args.dst.info.count = single_rank_count * nprocs; args.dst.info.datatype = TEST_DT; args.dst.info.mem_type = _mt; UCC_CHECK(ucc_collective_init(&args, &req, team.team)); @@ -62,8 +63,10 @@ TestAlltoall::TestAlltoall(size_t _msgsize, ucc_test_mpi_inplace_t _inplace, ucc_status_t TestAlltoall::check() { - size_t count = args.src.info.count; - MPI_Alltoall(inplace ? MPI_IN_PLACE : check_sbuf, count, ucc_dt_to_mpi(TEST_DT), - check_rbuf, count, ucc_dt_to_mpi(TEST_DT), team.comm); - return compare_buffers(rbuf, check_rbuf, count, TEST_DT, mem_type); + int size; + MPI_Comm_size(team.comm, &size); + size_t single_rank_count = args.src.info.count / size; + MPI_Alltoall(inplace ? MPI_IN_PLACE : check_sbuf, single_rank_count, ucc_dt_to_mpi(TEST_DT), + check_rbuf, single_rank_count, ucc_dt_to_mpi(TEST_DT), team.comm); + return compare_buffers(rbuf, check_rbuf, single_rank_count * size, TEST_DT, mem_type); }