From 8fc9a817af279bac183d3bf3717d513428d3c22b Mon Sep 17 00:00:00 2001
From: LiYuRio <63526175+LiYuRio@users.noreply.github.com>
Date: Mon, 14 Aug 2023 16:35:11 +0800
Subject: [PATCH] support r to s unbalanced split (#56149)

---
 .../auto_parallel/r_to_s_reshard_function.cc  | 10 ++---
 .../auto_parallel/reshard_split_functor.cc    | 45 ++++++-------------
 .../auto_parallel/reshard_utils.cc            |  9 ++++
 .../distributed/auto_parallel/reshard_utils.h |  5 +++
 paddle/phi/kernels/split_kernel.h             | 30 +++++++++----
 test/cpp/auto_parallel/test_reshard_r_to_s.cc | 30 +++++++++++++
 6 files changed, 83 insertions(+), 46 deletions(-)

diff --git a/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc
index 9556d3912fc82..5ddd238883cf4 100644
--- a/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc
+++ b/paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc
@@ -78,11 +78,9 @@ std::shared_ptr<DistTensor> RToSReshardFunction::Eval(
           << " There will have " << num_of_process
           << " process participate in.";
 
-  // TODO(liyurui): Consider the tensor can not be balanced split,
-  // for example, the shape of tensor is {6} but want to split it by 4
-  // process.
-  IntArray sections(std::vector<int64_t>(
-      num_of_process, in.dims()[split_axis] / num_of_process));
+  std::vector<int64_t> split_num_vec =
+      BalancedSplit(in.dims()[split_axis], num_of_process);
+  IntArray sections(split_num_vec);
 
   std::vector<DenseTensor> split_out_vec = ReshardSplitFunctor(
       *dev_ctx, in_physical_tensor_cur_rank, sections, split_axis);
@@ -90,6 +88,8 @@ std::shared_ptr<DistTensor> RToSReshardFunction::Eval(
   VLOG(3) << "The current process will remain the idx "
           << coord_in_mesh[mesh_axis] << " piece of tensor";
   out_physical_tensor_cur_rank = split_out_vec[coord_in_mesh[mesh_axis]];
+  VLOG(3) << "The shape of physical tensor after split is "
+          << out_physical_tensor_cur_rank.dims();
 
   return std::make_shared<DistTensor>(
       std::make_shared<DenseTensor>(out_physical_tensor_cur_rank),
diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_split_functor.cc b/paddle/phi/core/distributed/auto_parallel/reshard_split_functor.cc
index 189738b81367f..4d0818eed4c0a 100644
--- a/paddle/phi/core/distributed/auto_parallel/reshard_split_functor.cc
+++ b/paddle/phi/core/distributed/auto_parallel/reshard_split_functor.cc
@@ -27,45 +27,26 @@ std::vector<DenseTensor> ReshardSplitFunctor(const DeviceContext& dev_ctx,
                                              const DenseTensor& input,
                                              const IntArray& sections,
                                              int64_t axis) {
-  size_t out_number = sections.size();
-  std::vector<DenseTensor> result(out_number);
-
-  std::vector<phi::MetaTensor> out_meta;
-  std::vector<phi::MetaTensor*> out_meta_ptr;
-
-  out_meta.reserve(out_number);
-  out_meta_ptr.reserve(out_number);
-  for (size_t i = 0; i < out_number; ++i) {
-    out_meta.emplace_back(result[i]);
-    out_meta_ptr.emplace_back(&out_meta.back());
-  }
-  SplitInferMeta(phi::MetaTensor(input), sections, axis, out_meta_ptr);
-
-  std::vector<DenseTensor*> outs;
-  for (size_t i = 0; i < out_number; ++i) {
-    outs.emplace_back(&result[i]);
-  }
+  std::vector<DenseTensor> result;
 
   if (phi::CPUContext::classof(&dev_ctx)) {
-    PD_VISIT_ALL_TYPES(input.dtype(), "SplitKernel", ([&] {
-                         SplitKernel<data_t>(
-                             static_cast<const CPUContext&>(dev_ctx),
-                             input,
-                             sections,
-                             axis,
-                             outs);
+    PD_VISIT_ALL_TYPES(input.dtype(), "Split", ([&] {
+                         Split<data_t>(static_cast<const CPUContext&>(dev_ctx),
+                                       input,
+                                       sections,
+                                       axis,
+                                       &result);
                        }));
     return result;
   }
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   if (phi::GPUContext::classof(&dev_ctx)) {
-    PD_VISIT_ALL_TYPES(input.dtype(), "SplitKernel", ([&] {
-                         SplitKernel<data_t>(
-                             static_cast<const GPUContext&>(dev_ctx),
-                             input,
-                             sections,
-                             axis,
-                             outs);
+    PD_VISIT_ALL_TYPES(input.dtype(), "Split", ([&] {
+                         Split<data_t>(static_cast<const GPUContext&>(dev_ctx),
+                                       input,
+                                       sections,
+                                       axis,
+                                       &result);
                        }));
     return result;
   }
diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc b/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc
index 88f75ee2cbe0a..10d0272209ddf 100644
--- a/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc
+++ b/paddle/phi/core/distributed/auto_parallel/reshard_utils.cc
@@ -189,5 +189,14 @@ CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx,
   return comm_context;
 }
 
+std::vector<int64_t> BalancedSplit(int64_t total_nums, int64_t num_of_pieces) {
+  std::vector<int64_t> result(num_of_pieces, total_nums / num_of_pieces);
+  int64_t remain_nums = total_nums % num_of_pieces;
+  for (int64_t i = 0; i < remain_nums; ++i) {
+    result[i] += 1;
+  }
+  return result;
+}
+
 }  // namespace distributed
 }  // namespace phi
diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_utils.h b/paddle/phi/core/distributed/auto_parallel/reshard_utils.h
index 7a78bac03140b..8628e8ec516d8 100644
--- a/paddle/phi/core/distributed/auto_parallel/reshard_utils.h
+++ b/paddle/phi/core/distributed/auto_parallel/reshard_utils.h
@@ -69,5 +69,10 @@ uint16_t GetMasterPort();
 
 std::shared_ptr<TCPStore> CreateOrGetGlobalTCPStore();
 
+// If given a number, balance split it to multiple pieces.
+// For example, the input value is 12, split it to 5 pieces, then return
+// {3, 3, 2, 2, 2}.
+std::vector<int64_t> BalancedSplit(int64_t total_nums, int64_t num_of_pieces);
+
 }  // namespace distributed
 }  // namespace phi
diff --git a/paddle/phi/kernels/split_kernel.h b/paddle/phi/kernels/split_kernel.h
index 688a71e7c74de..7a6b7173961ac 100644
--- a/paddle/phi/kernels/split_kernel.h
+++ b/paddle/phi/kernels/split_kernel.h
@@ -50,31 +50,43 @@ void SplitWithNumStridedKernel(const Context& dev_ctx,
                                std::vector<DenseTensor*> out);
 
 template <typename T, typename Context>
-std::vector<DenseTensor> Split(const Context& dev_ctx,
-                               const DenseTensor& x,
-                               const IntArray& sections,
-                               const Scalar& axis) {
-  size_t out_number;
-  out_number = sections.GetData().size();
+void Split(const Context& dev_ctx,
+           const DenseTensor& x,
+           const IntArray& sections,
+           const Scalar& axis,
+           std::vector<DenseTensor>* result) {
+  size_t out_number = sections.GetData().size();
 
   std::vector<MetaTensor> out_meta;
   std::vector<MetaTensor*> out_meta_ptr;
 
   out_meta.reserve(out_number);
   out_meta_ptr.reserve(out_number);
-  std::vector<DenseTensor> result(out_number);
+  result->resize(out_number);
 
   for (size_t i = 0; i < out_number; ++i) {
-    out_meta.emplace_back(&result[i]);
+    out_meta.emplace_back(&result->at(i));
     out_meta_ptr.push_back(&out_meta.back());
   }
   SplitInferMeta(x, sections, axis, out_meta_ptr);
   std::vector<DenseTensor*> outs;
   outs.reserve(out_meta.size());
   for (size_t i = 0; i < out_meta.size(); ++i) {
-    outs.push_back(&result[i]);
+    outs.push_back(&result->at(i));
   }
 
   SplitKernel<T, Context>(dev_ctx, x, sections, axis, outs);
+}
+
+template <typename T, typename Context>
+std::vector<DenseTensor> Split(const Context& dev_ctx,
+                               const DenseTensor& x,
+                               const IntArray& sections,
+                               const Scalar& axis) {
+  size_t out_number = sections.GetData().size();
+  std::vector<DenseTensor> result(out_number);
+
+  Split<T, Context>(dev_ctx, x, sections, axis, &result);
+
   return result;
 }
diff --git a/test/cpp/auto_parallel/test_reshard_r_to_s.cc b/test/cpp/auto_parallel/test_reshard_r_to_s.cc
index 32c697369fc2f..dd354484a4f2c 100644
--- a/test/cpp/auto_parallel/test_reshard_r_to_s.cc
+++ b/test/cpp/auto_parallel/test_reshard_r_to_s.cc
@@ -121,6 +121,36 @@ TEST(reshard_r_to_s, r_to_s_same_placement_cpu_1d_mesh) {
   CHECK_EQ(output->dims(), DDim({6, 2}));
 }
 
+TEST(reshard_r_to_s, r_to_s_same_placement_cpu_1d_mesh_unbalance_split) {
+  setenv("PADDLE_TRAINER_ID", "1", 1);
+
+  std::vector<int64_t> tensor_shape = {6, 8};
+  phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
+  auto* context = reinterpret_cast<phi::CPUContext*>(pool.Get(phi::CPUPlace()));
+
+  std::vector<int64_t> mesh_shape = {4};
+  std::vector<int64_t> process_ids = {0, 1, 2, 3};
+  std::vector<std::string> dim_names = {"x"};
+  ProcessMesh mesh(mesh_shape, process_ids, dim_names);
+
+  std::shared_ptr<DistTensor> input =
+      ConstructReplicatedDistCPU(context, tensor_shape, mesh);
+
+  std::shared_ptr<TensorDistAttr> out_dist_attr =
+      std::make_shared<TensorDistAttr>(tensor_shape);
+  std::vector<int64_t> out_dims_mapping = {0, -1};
+  out_dist_attr->set_dims_mapping(out_dims_mapping);
+  out_dist_attr->set_process_mesh(mesh);
+
+  RToSReshardFunction r_to_s_func;
+  std::shared_ptr<DistTensor> output =
+      r_to_s_func.Eval(context, *input, out_dist_attr);
+
+  CHECK_EQ(r_to_s_func.IsSuitable(*input, out_dist_attr), true);
+  CHECK_EQ(output->numel(), 16);
+  CHECK_EQ(output->dims(), DDim({2, 8}));
+}
+
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 TEST(reshard_r_to_s, r_to_s_same_placement_gpu_1d_mesh) {
   setenv("PADDLE_TRAINER_ID", "0", 0);