Skip to content

Commit

Permalink
support r to s unbalanced split (#56149)
Browse files Browse the repository at this point in the history
  • Loading branch information
LiYuRio authored Aug 14, 2023
1 parent 476bc13 commit 8fc9a81
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,18 +78,18 @@ std::shared_ptr<DistTensor> RToSReshardFunction::Eval(
<< " There will have " << num_of_process
<< " process participate in.";

// TODO(liyurui): Consider the tensor can not be balanced split,
// for example, the shape of tensor is {6} but want to split it by 4
// process.
IntArray sections(std::vector<int64_t>(
num_of_process, in.dims()[split_axis] / num_of_process));
std::vector<int64_t> split_num_vec =
BalancedSplit(in.dims()[split_axis], num_of_process);
IntArray sections(split_num_vec);

std::vector<DenseTensor> split_out_vec = ReshardSplitFunctor(
*dev_ctx, in_physical_tensor_cur_rank, sections, split_axis);

VLOG(3) << "The current process will remain the idx "
<< coord_in_mesh[mesh_axis] << " piece of tensor";
out_physical_tensor_cur_rank = split_out_vec[coord_in_mesh[mesh_axis]];
VLOG(3) << "The shape of physical tensor after split is "
<< out_physical_tensor_cur_rank.dims();

return std::make_shared<DistTensor>(
std::make_shared<DenseTensor>(out_physical_tensor_cur_rank),
Expand Down
45 changes: 13 additions & 32 deletions paddle/phi/core/distributed/auto_parallel/reshard_split_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,45 +27,26 @@ std::vector<DenseTensor> ReshardSplitFunctor(const DeviceContext& dev_ctx,
const DenseTensor& input,
const IntArray& sections,
int64_t axis) {
size_t out_number = sections.size();
std::vector<DenseTensor> result(out_number);

std::vector<MetaTensor> out_meta;
std::vector<MetaTensor*> out_meta_ptr;

out_meta.reserve(out_number);
out_meta_ptr.reserve(out_number);
for (size_t i = 0; i < out_number; ++i) {
out_meta.emplace_back(result[i]);
out_meta_ptr.emplace_back(&out_meta.back());
}
SplitInferMeta(phi::MetaTensor(input), sections, axis, out_meta_ptr);

std::vector<DenseTensor*> outs;
for (size_t i = 0; i < out_number; ++i) {
outs.emplace_back(&result[i]);
}
std::vector<DenseTensor> result;

if (phi::CPUContext::classof(&dev_ctx)) {
PD_VISIT_ALL_TYPES(input.dtype(), "SplitKernel", ([&] {
SplitKernel<data_t>(
static_cast<const CPUContext&>(dev_ctx),
input,
sections,
axis,
outs);
PD_VISIT_ALL_TYPES(input.dtype(), "Split", ([&] {
Split<data_t>(static_cast<const CPUContext&>(dev_ctx),
input,
sections,
axis,
&result);
}));
return result;
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (phi::GPUContext::classof(&dev_ctx)) {
PD_VISIT_ALL_TYPES(input.dtype(), "SplitKernel", ([&] {
SplitKernel<data_t>(
static_cast<const GPUContext&>(dev_ctx),
input,
sections,
axis,
outs);
PD_VISIT_ALL_TYPES(input.dtype(), "Split", ([&] {
Split<data_t>(static_cast<const GPUContext&>(dev_ctx),
input,
sections,
axis,
&result);
}));
return result;
}
Expand Down
9 changes: 9 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/reshard_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,5 +189,14 @@ CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx,
return comm_context;
}

// Split `total_nums` into `num_of_pieces` parts as evenly as possible.
// Each part gets the base quotient; the first `total_nums % num_of_pieces`
// parts receive one extra element, e.g. (12, 5) -> {3, 3, 2, 2, 2}.
std::vector<int64_t> BalancedSplit(int64_t total_nums, int64_t num_of_pieces) {
  const int64_t base_size = total_nums / num_of_pieces;
  const int64_t num_larger_pieces = total_nums % num_of_pieces;
  std::vector<int64_t> pieces;
  pieces.reserve(num_of_pieces);
  for (int64_t idx = 0; idx < num_of_pieces; ++idx) {
    pieces.push_back(idx < num_larger_pieces ? base_size + 1 : base_size);
  }
  return pieces;
}

} // namespace distributed
} // namespace phi
5 changes: 5 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/reshard_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,10 @@ uint16_t GetMasterPort();

std::shared_ptr<TCPStore> CreateOrGetGlobalTCPStore();

// Split a number into multiple pieces as evenly as possible.
// For example, splitting 12 into 5 pieces returns
// {3, 3, 2, 2, 2}.
std::vector<int64_t> BalancedSplit(int64_t total_nums, int64_t num_of_pieces);

} // namespace distributed
} // namespace phi
30 changes: 21 additions & 9 deletions paddle/phi/kernels/split_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,31 +50,43 @@ void SplitWithNumStridedKernel(const Context& dev_ctx,
std::vector<DenseTensor*> out);

// Split `x` along `axis` into as many pieces as `sections` has entries,
// writing the output tensors into `*result` (which is resized to fit).
// `T` selects the kernel's element type and is not deducible from the
// arguments, so callers must spell it explicitly.
template <typename T, typename Context>
void Split(const Context& dev_ctx,
           const DenseTensor& x,
           const IntArray& sections,
           const Scalar& axis,
           std::vector<DenseTensor>* result) {
  size_t out_number = sections.GetData().size();

  std::vector<MetaTensor> out_meta;
  std::vector<MetaTensor*> out_meta_ptr;
  // Reserving up front is required for correctness, not just speed:
  // we take `&out_meta.back()` inside the loop, and a reallocation of
  // `out_meta` would invalidate every pointer already stored.
  out_meta.reserve(out_number);
  out_meta_ptr.reserve(out_number);
  // Resize before wrapping elements in MetaTensor so the DenseTensor
  // addresses handed to the meta layer stay stable.
  result->resize(out_number);

  for (size_t i = 0; i < out_number; ++i) {
    out_meta.emplace_back(&result->at(i));
    out_meta_ptr.push_back(&out_meta.back());
  }
  // Infer dims/dtype of each output slice before running the kernel.
  SplitInferMeta(x, sections, axis, out_meta_ptr);
  std::vector<DenseTensor*> outs;
  outs.reserve(out_meta.size());
  for (size_t i = 0; i < out_meta.size(); ++i) {
    outs.push_back(&result->at(i));
  }

  SplitKernel<T, Context>(dev_ctx, x, sections, axis, outs);
}

// Convenience overload of Split that returns the output tensors by value.
// Splits `x` along `axis` into `sections.GetData().size()` pieces.
template <typename T, typename Context>
std::vector<DenseTensor> Split(const Context& dev_ctx,
                               const DenseTensor& x,
                               const IntArray& sections,
                               const Scalar& axis) {
  // No need to pre-size `result`: the out-parameter overload resizes it.
  std::vector<DenseTensor> result;

  // `T` only selects the kernel's element type and appears nowhere in the
  // out-parameter overload's argument list, so it cannot be deduced — it
  // must be passed explicitly or the call fails to compile.
  Split<T, Context>(dev_ctx, x, sections, axis, &result);

  return result;
}

Expand Down
30 changes: 30 additions & 0 deletions test/cpp/auto_parallel/test_reshard_r_to_s.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,36 @@ TEST(reshard_r_to_s, r_to_s_same_placement_cpu_1d_mesh) {
CHECK_EQ(output->dims(), DDim({6, 2}));
}

TEST(reshard_r_to_s, r_to_s_same_placement_cpu_1d_mesh_unbalance_split) {
  // Run as rank 1 so Eval keeps the shard that belongs to process 1.
  setenv("PADDLE_TRAINER_ID", "1", 1);

  // A {6, 8} tensor sharded on dim 0 across 4 processes cannot be split
  // evenly: the row counts become {2, 2, 1, 1}, so rank 1 keeps 2 rows.
  const std::vector<int64_t> global_shape = {6, 8};
  phi::DeviceContextPool& dev_pool = phi::DeviceContextPool::Instance();
  auto* cpu_ctx =
      reinterpret_cast<phi::CPUContext*>(dev_pool.Get(phi::CPUPlace()));

  const std::vector<int64_t> mesh_dims = {4};
  const std::vector<int64_t> ranks = {0, 1, 2, 3};
  const std::vector<std::string> axis_names = {"x"};
  ProcessMesh process_mesh(mesh_dims, ranks, axis_names);

  std::shared_ptr<DistTensor> replicated_in =
      ConstructReplicatedDistCPU(cpu_ctx, global_shape, process_mesh);

  // Target placement: shard dim 0 over mesh axis 0, replicate dim 1.
  std::shared_ptr<TensorDistAttr> sharded_attr =
      std::make_shared<TensorDistAttr>(global_shape);
  sharded_attr->set_dims_mapping({0, -1});
  sharded_attr->set_process_mesh(process_mesh);

  RToSReshardFunction reshard_func;
  std::shared_ptr<DistTensor> sharded_out =
      reshard_func.Eval(cpu_ctx, *replicated_in, sharded_attr);

  CHECK_EQ(reshard_func.IsSuitable(*replicated_in, sharded_attr), true);
  CHECK_EQ(sharded_out->numel(), 16);
  CHECK_EQ(sharded_out->dims(), DDim({2, 8}));
}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
TEST(reshard_r_to_s, r_to_s_same_placement_gpu_1d_mesh) {
setenv("PADDLE_TRAINER_ID", "0", 0);
Expand Down

0 comments on commit 8fc9a81

Please sign in to comment.