[Auto Parallel]: Support std::vector<phi::Tensor> input and output for DistTensor. #56602

Merged · 10 commits · Sep 5, 2023
15 changes: 15 additions & 0 deletions paddle/fluid/eager/tensor_wrapper.h
@@ -32,6 +32,10 @@
#ifndef PADDLE_NO_PYTHON
#include "paddle/fluid/eager/hooks.h"
#endif
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#endif

namespace egr {
class TensorWrapper {
@@ -66,6 +70,17 @@ class TensorWrapper {
intermidiate_tensor_.set_impl(std::make_shared<phi::DenseTensor>(
std::make_shared<phi::Allocation>(nullptr, 0, tensor.place()),
dense_tensor->meta()));
#ifdef PADDLE_WITH_DISTRIBUTE
} else if (phi::distributed::DistTensor::classof(tensor.impl().get())) {
// Only Copy Meta
phi::distributed::DistTensor* dist_tensor =
    static_cast<phi::distributed::DistTensor*>(tensor.impl().get());
// TODO(jiabin): It's not a good idea to set memory size to zero, find
// another way and change this.
intermidiate_tensor_.set_impl(std::make_shared<phi::distributed::DistTensor>(
    std::make_shared<phi::Allocation>(nullptr, 0, tensor.place()),
    dist_tensor->meta(),
    dist_tensor->dist_attr()));
#endif
} else {
PADDLE_THROW(paddle::platform::errors::Fatal(
"Unrecognized tensor type for no_need_buffer feature"));
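For readers following the no_need_buffer path: the new branch keeps only the metadata and the dist_attr of a DistTensor while dropping its storage (hence the zero-size phi::Allocation). A minimal standalone sketch of that meta-only-copy idea, using stand-in types rather than Paddle's (all names below are hypothetical):

#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Stand-in types (hypothetical, not Paddle's): a tensor impl carries a buffer
// plus metadata; the distributed variant additionally carries a dist_attr.
struct Meta { std::vector<int> dims; };
struct DenseImpl {
  std::shared_ptr<std::vector<char>> buffer;
  Meta meta;
};
struct DistImpl : DenseImpl {
  std::string dist_attr;  // e.g. a serialized placement description
};

// Meta-only copy: reuse the metadata (and dist_attr) but point the buffer at
// nothing, mirroring the zero-size allocation used in TensorWrapper above.
DistImpl WrapNoNeedBuffer(const DistImpl& src) {
  DistImpl out;
  out.buffer = nullptr;           // no storage is retained
  out.meta = src.meta;            // shape-like info survives
  out.dist_attr = src.dist_attr;  // placement info survives
  return out;
}

int main() {
  DistImpl t;
  t.buffer = std::make_shared<std::vector<char>>(1024);
  t.meta.dims = {2, 3};
  t.dist_attr = "mesh=[0,1], dims_mapping=[0,-1]";
  DistImpl wrapped = WrapNoNeedBuffer(t);
  assert(wrapped.buffer == nullptr && wrapped.meta.dims == t.meta.dims);
  return 0;
}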
34 changes: 34 additions & 0 deletions paddle/phi/api/lib/api_gen_utils.cc
@@ -194,6 +194,19 @@ std::vector<phi::MetaTensor> MakeMetaTensor(
return meta_tensors;
}

// #ifdef PADDLE_WITH_DISTRIBUTE
// /* ------------------ for auto parallel ----------------------- */
// std::vector<phi::MetaTensor> MakeMetaTensor(
// const std::vector<const phi::distributed::DistTensor*>& tensors) {
// std::vector<phi::MetaTensor> meta_tensors;
// meta_tensors.reserve(tensors.size());
// for (const auto* t : tensors) {
// meta_tensors.emplace_back(*(t->impl()));
// }
// return meta_tensors;
// }
// #endif

/* ------------------ for output ----------------------- */

phi::DenseTensor* SetKernelOutput(Tensor* out) {
@@ -550,6 +563,27 @@ phi::distributed::DistTensor* SetKernelDistOutput(Tensor* out) {
}
return nullptr;
}

std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
    std::vector<Tensor*> out) {
std::vector<phi::distributed::DistTensor*> result;
for (auto tmp : out) {
if (tmp) {
// TODO(chenweihang): now all dist case are nullptr
if (tmp->impl() == nullptr) {
auto dense_t = std::make_shared<phi::DenseTensor>();
// TODO(chenweihang): polish code, dist_attr is null now
auto dist_attr = std::make_shared<phi::distributed::TensorDistAttr>();
auto dist_t = std::make_shared<phi::distributed::DistTensor>(
dense_t, phi::DenseTensorMeta(), dist_attr);
tmp->set_impl(dist_t);
}
result.emplace_back(static_cast<phi::distributed::DistTensor*>(tmp->impl().get()));
} else {
result.emplace_back(nullptr);
}
}
return result;
}
#endif

} // namespace experimental
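The new vector overload follows the same contract as the single-tensor SetKernelDistOutput above: for each non-null output Tensor it lazily creates an empty DistTensor impl when none exists yet, then collects a raw pointer the kernel can write through. A standalone sketch of that allocate-if-missing-and-collect pattern with stand-in types (hypothetical names, not Paddle's API):

#include <memory>
#include <vector>

// Stand-in output type (hypothetical): an API-level handle whose backing
// implementation may not have been created yet.
struct Impl { int data = 0; };
struct Handle { std::shared_ptr<Impl> impl; };

// For every non-null handle, create the backing impl on demand and collect a
// raw pointer the kernel can write through; null handles map to nullptr.
std::vector<Impl*> PrepareOutputs(const std::vector<Handle*>& outs) {
  std::vector<Impl*> result;
  result.reserve(outs.size());
  for (Handle* h : outs) {
    if (h) {
      if (!h->impl) {
        h->impl = std::make_shared<Impl>();  // lazily allocated, empty impl
      }
      result.push_back(h->impl.get());
    } else {
      result.push_back(nullptr);
    }
  }
  return result;
}

int main() {
  Handle a, b;
  std::vector<Handle*> outs = {&a, nullptr, &b};
  std::vector<Impl*> raw = PrepareOutputs(outs);
  // raw[0] and raw[2] now alias a.impl and b.impl; raw[1] stays nullptr.
  return raw.size() == 3 ? 0 : 1;
}

Keeping nullptr slots instead of skipping them preserves the positional mapping between API-level outputs and kernel outputs.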
1 change: 1 addition & 0 deletions paddle/phi/api/lib/api_gen_utils.h
@@ -141,6 +141,7 @@ void TransStrideLegacy(phi::DeviceContext* dev_ctx,
/* ------------------ for auto parallel ----------------------- */

phi::distributed::DistTensor* SetKernelDistOutput(Tensor* out);
std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
    std::vector<Tensor*> out);
#endif

} // namespace experimental
43 changes: 43 additions & 0 deletions paddle/phi/api/lib/data_transform.cc
@@ -610,6 +610,49 @@ std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor(
}
return nullptr;
}

std::vector<std::shared_ptr<phi::distributed::DistTensor>>
PrepareDataForDistTensor(const std::vector<Tensor>& input,
                         const phi::TensorArgDef& target_args_def,
                         const TransformFlag& transform_flag,
                         bool is_stride_kernel) {
std::vector<std::shared_ptr<phi::distributed::DistTensor>> out;
for (auto x : input) {
Review comment (Contributor): this could be changed to auto& to reduce the copy.
const auto& tensor_in = x.impl();
if (tensor_in) {
phi::distributed::DistTensor* dist_tensor =
static_cast<phi::distributed::DistTensor*>(tensor_in.get());
phi::DenseTensor& dense_tensor = *(dist_tensor->mutable_value());
if (!transform_flag.NeedTransform() || !dense_tensor.initialized() ||
(!NeedTransformPlace(
dense_tensor.place(), target_args_def.backend, transform_flag) &&
!NeedTransformDataType(
dense_tensor.dtype(), target_args_def.dtype, transform_flag) &&
!NeedTransformLayout(dense_tensor.layout(),
target_args_def.layout,
dense_tensor.place(),
transform_flag) &&
!NeedTransform2Contiguous(is_stride_kernel,
dense_tensor.meta().is_contiguous()))) {
out.push_back(std::static_pointer_cast<phi::distributed::DistTensor>(tensor_in));
continue;
Review comment (Contributor): in principle this could be written in an else branch instead of using continue.
}
phi::DenseTensor trans_in_tensor = TransformData(
&dense_tensor, target_args_def, transform_flag, is_stride_kernel);
// TODO(GhostScreaming): The global meta in DistTensor is not changed,
// but the local meta in DenseTensor maybe changed, such as layout
// change(NCHW->NHWC), so the new DistTensor's meta maybe not unified.
VLOG(6) << "PrepareDataForDistTensor return transformed dist tensor";
out.push_back(std::make_shared<phi::distributed::DistTensor>(
std::make_shared<phi::DenseTensor>(std::move(trans_in_tensor)),
dist_tensor->meta(),
dist_tensor->dist_attr()));
} else {
out.push_back(nullptr);
}
}
return out;
}
#endif

} // namespace experimental
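The vector overload of PrepareDataForDistTensor applies the same per-element policy as the scalar version: when none of the place/dtype/layout/contiguity checks require a transform, the incoming DistTensor is reused as-is; otherwise only the local DenseTensor is transformed and rewrapped with the original global meta and dist_attr. A standalone sketch of that reuse-or-transform loop, with stand-in types and a single stand-in predicate (hypothetical, not Paddle's code):

#include <memory>
#include <vector>

// Stand-in value type (hypothetical): 'layout' stands in for any property a
// kernel may require (place, dtype, layout, contiguity in the real code).
struct Value {
  int layout = 0;
};

bool NeedsTransform(const Value& v, int target_layout) {
  return v.layout != target_layout;
}

Value Transform(const Value& v, int target_layout) {
  Value out = v;
  out.layout = target_layout;  // only the local data is converted; the caller
                               // carries the "global" meta over unchanged
  return out;
}

// Reuse the shared_ptr when no conversion is needed; otherwise materialize a
// new object holding the transformed value. Null inputs stay null.
std::vector<std::shared_ptr<Value>> PrepareData(
    const std::vector<std::shared_ptr<Value>>& inputs, int target_layout) {
  std::vector<std::shared_ptr<Value>> out;
  out.reserve(inputs.size());
  for (const auto& in : inputs) {
    if (!in) {
      out.push_back(nullptr);
    } else if (!NeedsTransform(*in, target_layout)) {
      out.push_back(in);  // zero-copy reuse
    } else {
      out.push_back(std::make_shared<Value>(Transform(*in, target_layout)));
    }
  }
  return out;
}

int main() {
  auto a = std::make_shared<Value>();   // already in the target layout
  Value v;
  v.layout = 1;
  auto b = std::make_shared<Value>(v);  // needs conversion
  auto prepared = PrepareData({a, nullptr, b}, /*target_layout=*/0);
  return (prepared[0] == a && prepared[2] != b) ? 0 : 1;
}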
6 changes: 6 additions & 0 deletions paddle/phi/api/lib/data_transform.h
@@ -180,6 +180,12 @@ std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor(
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag,
bool is_stride_kernel);

std::vector<std::shared_ptr<phi::distributed::DistTensor>>
PrepareDataForDistTensor(const std::vector<Tensor>& input,
                         const phi::TensorArgDef& target_args_def,
                         const TransformFlag& transform_flag,
                         bool is_stride_kernel);
#endif

} // namespace experimental
Expand Down
10 changes: 10 additions & 0 deletions paddle/phi/api/lib/kernel_dispatch.h
@@ -195,6 +195,16 @@ struct DistTensorTypeParser : ArgsIterator<DistTensorTypeParser> {
}
}

void operator()(const paddle::optional<std::vector<Tensor>>& x) {
if (x) {
if (!(x.get_ptr()->empty())) {
for (auto& t : *(x.get_ptr())) {
result &= t.is_dist_tensor();
}
}
}
}

// skip other type args, these args aren't used in kernel selection
template <typename T>
void operator()(const T& x) {
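DistTensorTypeParser walks every kernel argument and ANDs together the answer to "is this backed by a DistTensor?"; the overload added here extends that fold to optional lists of tensors, treating an absent or empty list as neutral. A standalone sketch of the same visitor-style fold over mixed argument kinds (stand-in types, std::optional in place of paddle::optional; names are hypothetical, not Paddle's ArgsIterator):

#include <optional>
#include <vector>

// Stand-in argument type (hypothetical): the flag marks whether the tensor is
// backed by a distributed implementation.
struct Arg { bool is_dist = false; };

struct DistArgChecker {
  bool result = true;

  void operator()(const Arg& a) { result &= a.is_dist; }

  void operator()(const std::vector<Arg>& v) {
    for (const auto& a : v) result &= a.is_dist;
  }

  // Mirrors the new overload: an absent or empty optional list leaves the
  // verdict untouched, present elements are folded in.
  void operator()(const std::optional<std::vector<Arg>>& v) {
    if (v && !v->empty()) {
      for (const auto& a : *v) result &= a.is_dist;
    }
  }

  // Non-tensor arguments do not influence kernel selection.
  template <typename T>
  void operator()(const T&) {}

  template <typename... Ts>
  static bool AllDist(const Ts&... args) {
    DistArgChecker c;
    (c(args), ...);  // C++17 fold over all arguments
    return c.result;
  }
};

int main() {
  Arg dist{true}, dense{false};
  std::optional<std::vector<Arg>> maybe_list = std::vector<Arg>{dist, dist};
  bool all = DistArgChecker::AllDist(dist, maybe_list, 3.14);
  bool mixed = DistArgChecker::AllDist(dist, dense);
  return (all && !mixed) ? 0 : 1;
}

Because non-tensor arguments fall into the catch-all overload, supporting a new tensor container type only requires adding one more operator(), which is what this part of the PR does.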