
support non-zero shard to replicated
LiYuRio committed Aug 23, 2023
1 parent 48afda1 commit 227a9ea
Showing 17 changed files with 215 additions and 63 deletions.
21 changes: 21 additions & 0 deletions paddle/fluid/pybind/eager_properties.cc
@@ -388,6 +388,22 @@ PyObject* tensor_properties_get_dist_attr(TensorObject* self, void* closure) {
EAGER_CATCH_AND_THROW_RETURN_NULL
}

PyObject* tensor_properties_get_local_shape(TensorObject* self, void* closure) {
EAGER_TRY
if (self->tensor.is_dist_tensor()) {
#ifdef PADDLE_WITH_DISTRIBUTE
phi::distributed::DistTensor* dist_tensor =
static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get());
return ToPyObject(phi::vectorize<int64_t>(dist_tensor->local_dims()));
#else
RETURN_PY_NONE
#endif
} else {
RETURN_PY_NONE
}
EAGER_CATCH_AND_THROW_RETURN_NULL
}

PyDoc_STRVAR(tensor_shape__doc__,
R"DOC(shape
@@ -716,6 +732,11 @@ struct PyGetSetDef variable_properties[] = { // NOLINT
(setter)tensor_properties_set_persistable,
tensor_persistable__doc__,
nullptr},
{"_local_shape",
(getter)tensor_properties_get_local_shape,
nullptr,
nullptr,
nullptr},
{"shape",
(getter)tensor_properties_get_shape,
nullptr,
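A note on what the new `_local_shape` property returns (a minimal NumPy sketch, not Paddle API; shapes and process count are illustrative): for a dist tensor sharded on one axis, local_dims() is the global shape with the sharded axis divided by the number of processes on that mesh axis.

import numpy as np

# Global tensor of shape [8, 4], sharded on axis 0 across 4 processes.
global_value = np.arange(32).reshape(8, 4)
num_of_process = 4
split_axis = 0  # the axis marked as sharded in dims_mapping

# Each rank holds one contiguous piece; its shape is what _local_shape reports.
local_values = np.split(global_value, num_of_process, axis=split_axis)
assert local_values[0].shape == (2, 4)  # while tensor.shape stays [8, 4]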
1 change: 1 addition & 0 deletions paddle/phi/core/distributed/auto_parallel/CMakeLists.txt
@@ -9,6 +9,7 @@ if(WITH_DISTRIBUTE)
dist_tensor.cc
reshard_function.cc
reshard_split_functor.cc
reshard_concat_functor.cc
reshard_all_gather_functor.cc
r_to_s_reshard_function.cc
s_to_r_reshard_function.cc)
7 changes: 5 additions & 2 deletions paddle/phi/core/distributed/auto_parallel/dist_tensor.cc
@@ -43,8 +43,11 @@ DistTensor::DistTensor(const phi::DenseTensor& global_value,
// 2. reshard from replicated to other state
auto* func = ChooseProperReshardFunction(*this, dist_attr);
auto* dev_ctx = DeviceContextPool::Instance().Get(global_value.place());
DistTensor out = func->Eval(dev_ctx, *this, dist_attr);
value_ = out.value();
auto out = func->Eval(dev_ctx, *this, dist_attr);

// 3. reset dist attr and value
dist_attr_.set_dims_mapping(dist_attr.dims_mapping());
value_ = out->value();
}
}

paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.cc
@@ -46,9 +46,10 @@ bool RToSReshardFunction::IsSuitable(const DistTensor& in,
return flag;
}

DistTensor RToSReshardFunction::Eval(phi::DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr) {
std::shared_ptr<DistTensor> RToSReshardFunction::Eval(
phi::DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr) {
const auto& out_dims_mapping = out_dist_attr.dims_mapping();
const auto& out_process_mesh = out_dist_attr.process_mesh();
const DenseTensor& in_physical_tensor_cur_rank = in.value();
@@ -62,14 +63,6 @@ DistTensor RToSReshardFunction::Eval(phi::DeviceContext* dev_ctx,
int64_t split_axis = split_axis_to_mesh_axis.begin()->first;
int64_t mesh_axis = split_axis_to_mesh_axis.begin()->second;

PADDLE_ENFORCE_LT(
mesh_axis,
out_process_mesh.ndim(),
phi::errors::OutOfRange(
"The mesh axis %lld exceed the size of process mesh %lld.",
mesh_axis,
out_process_mesh.ndim()));

int64_t num_of_process = out_process_mesh.shape()[mesh_axis];
VLOG(3) << "RToSReshard: Tensor will be split on axis " << split_axis
<< ". Split will use axis " << mesh_axis << " of process_mesh."
@@ -89,7 +82,8 @@ DistTensor RToSReshardFunction::Eval(phi::DeviceContext* dev_ctx,
VLOG(3) << "The shape of physical tensor after split is "
<< out_physical_tensor_cur_rank.dims();

return DistTensor(out_physical_tensor_cur_rank, in.dims(), out_dist_attr);
return std::make_shared<DistTensor>(
out_physical_tensor_cur_rank, in.dims(), out_dist_attr);
}

REGISTER_RESHARD_FUNC(RToSReshardFunction);
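A sketch of what the R-to-S path above computes, in plain NumPy (no communication involved; the rank coordinate and shapes are illustrative): the replicated tensor is split into num_of_process pieces along split_axis, and each rank keeps the piece addressed by its coordinate on mesh_axis.

import numpy as np

global_t = np.arange(24).reshape(6, 4)   # replicated on every rank
split_axis, num_of_process = 0, 3
cur_rank_coord = 1                       # this rank's coordinate on mesh_axis

pieces = np.split(global_t, num_of_process, axis=split_axis)
local_t = pieces[cur_rank_coord]         # the physical tensor kept on this rank
assert local_t.shape == (2, 4)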
paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h
@@ -27,9 +27,10 @@ class RToSReshardFunction final : public ReshardFunction {
bool IsSuitable(const DistTensor& in,
const TensorDistAttr& out_dist_attr) override;

DistTensor Eval(DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr) override;
std::shared_ptr<DistTensor> Eval(
DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr) override;
};

} // namespace distributed
55 changes: 55 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/reshard_concat_functor.cc
@@ -0,0 +1,55 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/distributed/auto_parallel/reshard_concat_functor.h"

#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/concat_kernel.h"

namespace phi {
namespace distributed {

DenseTensor ReshardConcatFunctor(const DeviceContext& dev_ctx,
const std::vector<const DenseTensor*>& input,
int64_t axis) {
DenseTensor result;
auto dtype = (*input.begin())->dtype();

if (phi::CPUContext::classof(&dev_ctx)) {
PD_VISIT_ALL_TYPES(
dtype, "Concat", ([&] {
Concat<data_t>(
static_cast<const CPUContext&>(dev_ctx), input, axis, &result);
}));
return result;
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (phi::GPUContext::classof(&dev_ctx)) {
PD_VISIT_ALL_TYPES(
dtype, "Concat", ([&] {
Concat<data_t>(
static_cast<const GPUContext&>(dev_ctx), input, axis, &result);
}));
return result;
}
#endif
PADDLE_THROW(phi::errors::Unimplemented(
"The concat in reshard only supported on CPU and GPU for now."));
}

} // namespace distributed
} // namespace phi
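A Python analogue of the dispatch in ReshardConcatFunctor (a sketch, not Paddle code; the device strings are placeholders): pick the backend by the runtime device context and fail loudly on anything else, mirroring the CPUContext/GPUContext branches and the final PADDLE_THROW.

import numpy as np

def reshard_concat(device, inputs, axis):
    # Stands in for the PD_VISIT_ALL_TYPES + Concat<data_t> branch per backend.
    if device == "cpu":
        return np.concatenate(inputs, axis=axis)
    if device in ("cuda", "hip"):  # guarded by PADDLE_WITH_CUDA / PADDLE_WITH_HIP
        return np.concatenate(inputs, axis=axis)
    raise NotImplementedError("concat in reshard is only supported on CPU and GPU")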
30 changes: 30 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/reshard_concat_functor.h
@@ -0,0 +1,30 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstdint>
#include <vector>

namespace phi {
class DeviceContext;
class DenseTensor;
namespace distributed {

DenseTensor ReshardConcatFunctor(const DeviceContext& dev_ctx,
const std::vector<const DenseTensor*>& input,
int64_t axis);

} // namespace distributed
} // namespace phi
27 changes: 11 additions & 16 deletions paddle/phi/core/distributed/auto_parallel/reshard_function.h
@@ -32,26 +32,21 @@ class ReshardFunction {
virtual bool IsSuitable(const DistTensor& in,
const TensorDistAttr& out_dist_attr) = 0;

virtual DistTensor Eval(DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr) = 0;
virtual std::shared_ptr<DistTensor> Eval(
DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr) = 0;
};

std::vector<std::unique_ptr<ReshardFunction>>& GetReshardFunctionList();

template <typename RESHARD_FUNC>
std::unique_ptr<RESHARD_FUNC> CreateReshardFunction() {
return std::make_unique<RESHARD_FUNC>();
}

#define REGISTER_RESHARD_FUNC(func_type) \
class __RegisterReshard_##func_type { \
public: \
__RegisterReshard_##func_type() { \
GetReshardFunctionList().emplace_back( \
CreateReshardFunction<func_type>()); \
} \
}; \
#define REGISTER_RESHARD_FUNC(func_type) \
class __RegisterReshard_##func_type { \
public: \
__RegisterReshard_##func_type() { \
GetReshardFunctionList().emplace_back(std::make_unique<func_type>()); \
} \
}; \
static __RegisterReshard_##func_type local_reshard_func_##func_type

ReshardFunction* ChooseProperReshardFunction(
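The simplified REGISTER_RESHARD_FUNC macro above drops the CreateReshardFunction helper in favor of std::make_unique. A Python analogue of the registration pattern (a sketch; names are not Paddle's): each expansion defines a static object whose constructor appends one instance to a process-wide list, which ChooseProperReshardFunction later scans with IsSuitable.

_reshard_function_list = []

def register_reshard_func(cls):
    _reshard_function_list.append(cls())  # like the static registrar's constructor
    return cls

@register_reshard_func
class ExampleReshardFunction:
    def is_suitable(self, in_tensor, out_dist_attr):
        return False

def choose_proper_reshard_function(in_tensor, out_dist_attr):
    for func in _reshard_function_list:
        if func.is_suitable(in_tensor, out_dist_attr):
            return func
    raise RuntimeError("no suitable reshard function found")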
paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.cc
@@ -15,12 +15,13 @@
#include "paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h"

#include "glog/logging.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_all_gather_functor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_concat_functor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/store/tcp_store.h"

namespace phi {
namespace distributed {
@@ -43,17 +44,25 @@ bool SToRReshardFunction::IsSuitable(const DistTensor& in,
flag &= (out_process_mesh.ndim() == 1);
flag &= (in_process_mesh == out_process_mesh);

// Ensure the tensor is balanced split, or we need send/recv rather than
// all_gather
std::map<int64_t, int64_t> split_axis_to_mesh_axis =
GetSplitAxisWithDimsMapping(in_dims_mapping);
int64_t split_axis = split_axis_to_mesh_axis.begin()->first;
int64_t num_of_process = in_process_mesh.size();
flag &=
(in.local_dims()[split_axis] * num_of_process == in.dims()[split_axis]);

return flag;
}

DistTensor SToRReshardFunction::Eval(DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr) {
// TODO(liyurui): Only support transfer shard(0) to replicate for now.
// Concat is needed when transfer shard(x) to replicate, will be supported
// later.
std::shared_ptr<DistTensor> SToRReshardFunction::Eval(
DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr) {
const DenseTensor& in_physical_tensor_cur_rank = in.value();
const auto& in_dist_attr = in.dist_attr();
const auto& in_dims_mapping = in_dist_attr.dims_mapping();
const auto& in_process_mesh = in_dist_attr.process_mesh();
const auto& in_process_ids = in_process_mesh.process_ids();

@@ -63,7 +72,39 @@ DistTensor SToRReshardFunction::Eval(DeviceContext* dev_ctx,
DenseTensor out_all_gather = ReshardAllGatherFunctor(
dev_ctx, in_physical_tensor_cur_rank, in_process_ids);

return DistTensor(out_all_gather, out_all_gather.dims(), out_dist_attr);
std::map<int64_t, int64_t> split_axis_to_mesh_axis =
GetSplitAxisWithDimsMapping(in_dims_mapping);
int64_t split_axis = split_axis_to_mesh_axis.begin()->first;

if (split_axis == 0) {
// If the input dist tensor is shard(0), the subsequent split
// and concat is unnecessary.
return std::make_shared<DistTensor>(
out_all_gather, out_all_gather.dims(), out_dist_attr);
} else {
// Since the result of all_gather always concat the tensor on axis 0,
// first we need to split the result on axis 0,
// then we need to concat the split result on input split axis.
int64_t default_split_axis = 0;
int64_t num_of_process = in_process_ids.size();

IntArray sections(std::vector<int64_t>(
num_of_process,
in_physical_tensor_cur_rank.dims()[default_split_axis]));
std::vector<DenseTensor> split_out_vec = ReshardSplitFunctor(
*dev_ctx, out_all_gather, sections, default_split_axis);

// Concat the result after split on correct axis.
std::vector<const DenseTensor*> concat_input_vec;
for (const auto& tensor : split_out_vec) {
concat_input_vec.emplace_back(&tensor);
}
DenseTensor concat_out_tensor =
ReshardConcatFunctor(*dev_ctx, concat_input_vec, split_axis);

return std::make_shared<DistTensor>(
concat_out_tensor, concat_out_tensor.dims(), out_dist_attr);
}
}

REGISTER_RESHARD_FUNC(SToRReshardFunction);
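A NumPy sketch of why shard on a non-zero axis needs the extra split + concat after all_gather (illustrative shapes, no communication library): all_gather always stacks the per-rank pieces along axis 0, so for shard(1) the gathered buffer has the wrong layout until it is split on axis 0 and re-concatenated on the real split axis. The balanced-split check in IsSuitable is the same arithmetic, local_dims()[split_axis] * num_of_process == dims()[split_axis], e.g. 2 * 2 == 4 below.

import numpy as np

num_of_process = 2
global_t = np.arange(16).reshape(4, 4)
split_axis = 1                                      # input is shard(1)
local_pieces = np.split(global_t, num_of_process, axis=split_axis)  # (4, 2) each

# all_gather concatenates the per-rank local tensors along axis 0.
out_all_gather = np.concatenate(local_pieces, axis=0)   # shape (8, 2): wrong layout

# Split on axis 0 into num_of_process equal pieces, then concat on split_axis.
pieces = np.split(out_all_gather, num_of_process, axis=0)
replicated = np.concatenate(pieces, axis=split_axis)    # shape (4, 4)

assert np.array_equal(replicated, global_t)  # every rank ends up with the global tensor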
paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h
@@ -26,9 +26,10 @@ class SToRReshardFunction final : public ReshardFunction {
bool IsSuitable(const DistTensor& in,
const TensorDistAttr& out_dist_attr) override;

DistTensor Eval(DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr) override;
std::shared_ptr<DistTensor> Eval(
DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr) override;
};

} // namespace distributed
20 changes: 14 additions & 6 deletions paddle/phi/kernels/concat_kernel.h
@@ -27,9 +27,10 @@ void ConcatKernel(const Context& dev_ctx,
DenseTensor* out);

template <typename T, typename Context>
DenseTensor Concat(const Context& dev_ctx,
const std::vector<const DenseTensor*>& x,
const Scalar& axis) {
void Concat(const Context& dev_ctx,
const std::vector<const DenseTensor*>& x,
const Scalar& axis,
DenseTensor* dense_out) {
std::vector<MetaTensor> meta_x;
meta_x.reserve(x.size());
std::vector<const MetaTensor*> meta_x_ptr;
@@ -38,10 +39,17 @@ DenseTensor Concat(const Context& dev_ctx,
meta_x_ptr.push_back(&meta_x.back());
}

DenseTensor dense_out;
MetaTensor meta_out(&dense_out);
MetaTensor meta_out(dense_out);
ConcatInferMeta(meta_x_ptr, axis.to<int>(), &meta_out);
ConcatKernel<T, Context>(dev_ctx, x, axis, &dense_out);
ConcatKernel<T, Context>(dev_ctx, x, axis, dense_out);
}

template <typename T, typename Context>
DenseTensor Concat(const Context& dev_ctx,
const std::vector<const DenseTensor*>& x,
const Scalar& axis) {
DenseTensor dense_out;
Concat<T, Context>(dev_ctx, x, axis, &dense_out);
return dense_out;
}

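The refactor above splits Concat into an out-parameter form (used by ReshardConcatFunctor, which owns the result tensor) plus a thin by-value convenience wrapper. A Python analogue of that pattern (a sketch; np.concatenate's out= parameter stands in for writing through the caller's DenseTensor*):

import numpy as np

def concat_into(xs, axis, out):
    # The work happens here; the caller owns and provides the output buffer.
    np.concatenate(xs, axis=axis, out=out)

def concat(xs, axis):
    # By-value convenience wrapper, like the original Concat signature.
    shape = list(xs[0].shape)
    shape[axis] = sum(x.shape[axis] for x in xs)
    out = np.empty(shape, dtype=xs[0].dtype)
    concat_into(xs, axis, out)
    return out

a, b = np.ones((2, 2)), np.zeros((2, 2))
assert concat([a, b], axis=1).shape == (2, 4)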
1 change: 1 addition & 0 deletions paddle/phi/kernels/cpu/concat_kernel.cc
@@ -125,6 +125,7 @@ PD_REGISTER_KERNEL(concat,
int,
uint8_t,
int8_t,
int16_t,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
1 change: 1 addition & 0 deletions paddle/phi/kernels/gpu/concat_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ PD_REGISTER_KERNEL(concat,
int,
uint8_t,
int8_t,
int16_t,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
2 changes: 1 addition & 1 deletion paddle/phi/kernels/split_kernel.h
@@ -85,7 +85,7 @@ std::vector<DenseTensor> Split(const Context& dev_ctx,
size_t out_number = sections.GetData().size();
std::vector<DenseTensor> result(out_number);

Split(dev_ctx, x, sections, axis, &result);
Split<T, Context>(dev_ctx, x, sections, axis, &result);

return result;
}