Skip to content

Commit

Permalink
[Pten] Support optional param for C++ API (#39760)
Browse files Browse the repository at this point in the history
* fix selected_rows bug in C++ API

* add optional for C++ API

* data transform support optional

* remove data transform for optional vector<Tensor>

* adjust some format of function

* fix empty bug
  • Loading branch information
zyfncg authored Feb 28, 2022
1 parent bd9b946 commit aceb25e
Show file tree
Hide file tree
Showing 17 changed files with 240 additions and 51 deletions.
4 changes: 2 additions & 2 deletions paddle/phi/api/lib/api_custom_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ namespace experimental {
Tensor copy_to_impl(const Tensor& x, Backend backend, bool blocking) {
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(backend);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"copy", kernel_key);

Expand Down Expand Up @@ -67,7 +67,7 @@ std::vector<Tensor> split_impl(const Tensor& x,
const ScalarArray& num_or_sections,
const Scalar& axis) {
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();

Backend kernel_backend = kernel_key.backend();
DataLayout kernel_layout = kernel_key.layout();
Expand Down
32 changes: 32 additions & 0 deletions paddle/phi/api/lib/api_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ inline std::shared_ptr<phi::DenseTensor> TensorToDenseTensor(
return std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl());
}

inline std::shared_ptr<phi::DenseTensor> TensorToDenseTensor(
    const paddle::optional<Tensor>& tensor) {
  // Unwrap the optional: forward to the impl cast when a value is present,
  // otherwise hand back an empty shared_ptr.
  return tensor ? std::dynamic_pointer_cast<phi::DenseTensor>(tensor->impl())
                : nullptr;
}

inline std::unique_ptr<std::vector<phi::DenseTensor>> TensorToDenseTensor(
const std::vector<Tensor>& tensors) {
auto pt_tensors = std::make_unique<std::vector<phi::DenseTensor>>();
Expand All @@ -49,12 +57,28 @@ inline std::shared_ptr<phi::SelectedRows> TensorToSelectedRows(
return std::dynamic_pointer_cast<phi::SelectedRows>(tensor.impl());
}

inline std::shared_ptr<phi::SelectedRows> TensorToSelectedRows(
    const paddle::optional<Tensor>& tensor) {
  // Unwrap the optional: cast the underlying impl when present,
  // otherwise hand back an empty shared_ptr.
  return tensor ? std::dynamic_pointer_cast<phi::SelectedRows>(tensor->impl())
                : nullptr;
}

/* ----------------- for infer_meta --------------------- */

// Wrap a DenseTensor in a MetaTensor view for use by infer_meta functions.
inline phi::MetaTensor MakeMetaTensor(const phi::DenseTensor& tensor) {
  return phi::MetaTensor(tensor);
}

inline paddle::optional<phi::MetaTensor> MakeMetaTensor(
    const paddle::optional<const phi::DenseTensor&>& tensor) {
  // An absent input maps to an empty optional so infer_meta can skip it.
  if (!tensor) {
    return {paddle::none};
  }
  return {phi::MetaTensor(*tensor)};
}

inline std::vector<phi::MetaTensor> MakeMetaTensor(
const std::vector<phi::DenseTensor>& tensors) {
std::vector<phi::MetaTensor> meta_tensors;
Expand All @@ -69,6 +93,14 @@ inline phi::MetaTensor MakeMetaTensor(const phi::SelectedRows& tensor) {
return phi::MetaTensor(tensor);
}

inline paddle::optional<phi::MetaTensor> MakeMetaTensor(
    const paddle::optional<const phi::SelectedRows&>& tensor) {
  // An absent input maps to an empty optional so infer_meta can skip it.
  if (!tensor) {
    return {paddle::none};
  }
  return {phi::MetaTensor(*tensor)};
}

/* ------------------ for output ----------------------- */

inline phi::DenseTensor* SetKernelOutput(Backend backend, Tensor* out) {
Expand Down
10 changes: 10 additions & 0 deletions paddle/phi/api/lib/data_transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,16 @@ std::shared_ptr<phi::DenseTensor> PrepareData(
return std::make_shared<phi::DenseTensor>(out);
}

std::shared_ptr<phi::DenseTensor> PrepareData(
    const paddle::optional<Tensor>& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag) {
  // No value -> nothing to transform; report an empty tensor pointer.
  if (!input) {
    return {nullptr};
  }
  return PrepareData(*input, target_args_def, transform_flag);
}

std::unique_ptr<std::vector<phi::DenseTensor>> PrepareData(
const std::vector<Tensor>& inputs,
const phi::TensorArgDef& target_args_def,
Expand Down
5 changes: 5 additions & 0 deletions paddle/phi/api/lib/data_transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ std::shared_ptr<phi::DenseTensor> PrepareData(
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag);

std::shared_ptr<phi::DenseTensor> PrepareData(
const paddle::optional<Tensor>& input,
const phi::TensorArgDef& target_args_def,
const TransformFlag& transform_flag);

std::unique_ptr<std::vector<phi::DenseTensor>> PrepareData(
const std::vector<Tensor>& inputs,
const phi::TensorArgDef& target_args_def,
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/api/lib/kernel_dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ struct KernelKeySet {
DataType dtype{DataType::UNDEFINED};

// TODO(chenweihang): iterate all kernelkey for kernel selection
phi::KernelKey GetHigestPriorityKernelKey() {
phi::KernelKey GetHighestPriorityKernelKey() {
return phi::KernelKey(static_cast<Backend>(64 - detail::CountLeadingZeros(
backend_set.bitset())),
layout,
Expand Down
6 changes: 3 additions & 3 deletions paddle/phi/api/lib/sparse_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ PADDLE_API Tensor to_sparse_coo(const Tensor& x,
// 1. Get kernel signature and kernel
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(backend);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
std::string kernel_name = "dense_to_sparse_coo";
if (x.layout() == phi::DataLayout::SPARSE_CSR) {
kernel_name = "sparse_csr_to_coo";
Expand Down Expand Up @@ -112,7 +112,7 @@ PADDLE_API Tensor to_sparse_csr(const Tensor& x, Backend backend) {
// 1. Get kernel signature and kernel
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(backend);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
std::string kernel_name = "dense_to_sparse_csr";
if (x.layout() == phi::DataLayout::SPARSE_COO) {
kernel_name = "sparse_coo_to_csr";
Expand Down Expand Up @@ -179,7 +179,7 @@ PADDLE_API Tensor to_dense(const Tensor& x, Backend backend) {
// 1. Get kernel signature and kernel
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(backend);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
std::string kernel_name = "sparse_coo_to_dense";
if (x.layout() == phi::DataLayout::SPARSE_CSR) {
kernel_name = "sparse_csr_to_dense";
Expand Down
17 changes: 17 additions & 0 deletions paddle/phi/infermeta/backward.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,23 @@ void GeneralBinaryGradInferMeta(const MetaTensor& x,
}
}

void GeneralTernaryGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& z,
MetaTensor* dx,
MetaTensor* dy,
MetaTensor* dz) {
if (dx) {
dx->share_meta(x);
}
if (dy) {
dy->share_meta(y);
}
if (dz) {
dz->share_meta(z);
}
}

void GumbelSoftmaxGradInferMeta(const MetaTensor& out,
const MetaTensor& dout,
int axis,
Expand Down
7 changes: 7 additions & 0 deletions paddle/phi/infermeta/backward.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ void GeneralBinaryGradInferMeta(const MetaTensor& x,
MetaTensor* dx,
MetaTensor* dy);

void GeneralTernaryGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& z,
MetaTensor* dx,
MetaTensor* dy,
MetaTensor* dz);

void GumbelSoftmaxGradInferMeta(const MetaTensor& out,
const MetaTensor& dout,
int axis,
Expand Down
3 changes: 2 additions & 1 deletion paddle/phi/kernels/empty_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ void EmptyKernel(const Context& dev_ctx,
const ScalarArray& shape,
DataType dtype,
DenseTensor* out) {
out->ResizeAndAllocate(phi::make_ddim(shape.GetData()));
out->Resize(phi::make_ddim(shape.GetData()));
dev_ctx.template Alloc<T>(out);
}

template <typename T, typename Context>
Expand Down
1 change: 0 additions & 1 deletion paddle/phi/kernels/impl/matmul_grad_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,6 @@ void MatmulDoubleGradKernel(const Context& dev_ctx,
ddout_flag = true;
}
}

if (ddy) {
auto ddy_mat = ddy.get();
if (ddy_mat.dims() != y_help.dims()) {
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/tests/api/scale_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ PADDLE_API Tensor scale_kernel_context(const Tensor& x,
kernel_layout == DataLayout::UNDEFINED ||
kernel_data_type == DataType::UNDEFINED) {
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {
kernel_backend = kernel_key.backend();
}
Expand Down Expand Up @@ -215,7 +215,7 @@ Tensor scale_switch_case(const Tensor& x,
kernel_layout == DataLayout::UNDEFINED ||
kernel_data_type == DataType::UNDEFINED) {
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {
kernel_backend = kernel_key.backend();
}
Expand Down
27 changes: 27 additions & 0 deletions paddle/phi/tests/api/test_matmul_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License. */
#include <gtest/gtest.h>
#include <memory>

#include "paddle/phi/api/backward/backward_api.h"
#include "paddle/phi/api/include/api.h"

#include "paddle/phi/api/lib/utils/allocator.h"
Expand Down Expand Up @@ -161,5 +162,31 @@ TEST(API, matmul_cuda) {

#endif

TEST(API, matmul_double_grad) {
  // Prepare 3x3 inputs for the double-grad computation.
  auto input_x = paddle::experimental::full({3, 3}, 1.0);
  auto input_y = paddle::experimental::full({3, 3}, 2.0);
  auto grad_out = paddle::experimental::full({3, 3}, 2.0);
  auto grad_dx = paddle::experimental::full({3, 3}, 2.0);

  // Invoke the double-grad API (the last tensor input omitted via {}).
  const auto result = paddle::experimental::matmul_double_grad(
      input_x, input_y, grad_out, grad_dx, {}, false, false);

  // Verify: three gradient slots, each holding one 3x3 (9-element) tensor.
  ASSERT_EQ(result.size(), 3UL);
  for (size_t i = 0; i < 3UL; ++i) {
    ASSERT_EQ(result[i].size(), 1UL);
    ASSERT_EQ(result[i][0].numel(), 9);
  }
  ASSERT_EQ(result[0][0].dims()[1], 3);
  ASSERT_EQ(result[0][0].type(), phi::DataType::FLOAT32);
  ASSERT_EQ(result[0][0].layout(), phi::DataLayout::NCHW);
  ASSERT_EQ(result[1][0].initialized(), true);
  ASSERT_EQ(result[2][0].initialized(), true);
}

} // namespace tests
} // namespace paddle
1 change: 1 addition & 0 deletions paddle/utils/optional.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#pragma once

#include <algorithm>
#include <cassert>
#include <functional>
#include <new>
#include <type_traits>
Expand Down
Loading

0 comments on commit aceb25e

Please sign in to comment.