Skip to content

Commit

Permalink
[pten] add split kernel (#39060)
Browse files Browse the repository at this point in the history
* add split kernel

* add split kernel signature

* fix split bug

* modify MakePtenScalarArrayFromVarList

* modify MakePtenScalarArrayFromVarList

* fix split windows register error

* add test case for split kernel

* replace raw split kernel with pten kernel

* fix makeScalar/ScalarArray bug

* remove debug log

* remove int64_t type in buildPtcontext

* update by code review

* fix split dev test failed

* change DenseTensorMeta to MetaTensor

* change split api code from auto gen to manual

* split cuda kernel support bfloat16 type

* fix conflict

* rm raw split kernel

* merge develop branch

* change to pten::errors
  • Loading branch information
MingMingShangTian authored Feb 14, 2022
1 parent d12c363 commit d0df563
Show file tree
Hide file tree
Showing 24 changed files with 750 additions and 255 deletions.
12 changes: 7 additions & 5 deletions paddle/fluid/framework/custom_kernel_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@ limitations under the License. */
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/extension.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_kernel_info_helper.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/api/lib/utils/storage.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_context.h"
#include "paddle/pten/core/kernel_factory.h"
Expand Down Expand Up @@ -183,14 +186,14 @@ TEST(CustomKernel, custom_kernel_dot) {
paddle::platform::CPUPlace());
auto dense_x = std::make_shared<pten::DenseTensor>(
alloc.get(), pten::DenseTensorMeta(pten::DataType::UINT8,
paddle::framework::make_ddim({2, 3}),
pten::framework::make_ddim({2, 3}),
pten::DataLayout::NCHW));
auto* dense_x_data =
dense_x->mutable_data<uint8_t>(paddle::platform::CPUPlace());

auto dense_y = std::make_shared<pten::DenseTensor>(
alloc.get(), pten::DenseTensorMeta(pten::DataType::UINT8,
paddle::framework::make_ddim({2, 3}),
pten::framework::make_ddim({2, 3}),
pten::DataLayout::NCHW));
auto* dense_y_data =
dense_y->mutable_data<uint8_t>(paddle::platform::CPUPlace());
Expand Down Expand Up @@ -231,8 +234,7 @@ TEST(CustomKernel, custom_kernel_dot) {
pten::DataType fake_attr_dtype = pten::DataType::UINT32;
paddle::framework::LoDTensor tmp_tensor;
tmp_tensor.mutable_data<uint8_t>({1}, pten::TransToPtenPlace(backend));
pten::Scalar fake_attr_scalar =
paddle::experimental::MakePtenScalar(tmp_tensor);
pten::Scalar fake_attr_scalar{tmp_tensor};
pten::ScalarArray fake_attr_scalar_array;
std::vector<int64_t> fake_attr_int64_vec;
std::vector<int> fake_attr_int_vec;
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2099,6 +2099,10 @@ void OperatorWithKernel::BuildPtenKernelContext(
std::type_index(typeid(std::vector<int32_t>))) {
pt_kernel_context->EmplaceBackAttr(std::move(pten::ScalarArray(
BOOST_GET_CONST(std::vector<int32_t>, attr_iter->second))));
} else if (std::type_index(attr_iter->second.type()) ==
std::type_index(typeid(int32_t))) {
pt_kernel_context->EmplaceBackAttr(std::move(pten::ScalarArray(
&BOOST_GET_CONST(int32_t, attr_iter->second), 1)));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to ScalarArray when "
Expand Down
8 changes: 8 additions & 0 deletions paddle/fluid/imperative/prepared_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,14 @@ void BuildDygraphPtenKernelContext(
std::type_index(typeid(std::vector<int32_t>))) {
kernel_ctx->EmplaceBackAttr(std::move(
pten::ScalarArray(BOOST_GET_CONST(std::vector<int32_t>, attr))));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(int64_t))) {
kernel_ctx->EmplaceBackAttr(
std::move(pten::ScalarArray(&BOOST_GET_CONST(int64_t, attr), 1)));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(int32_t))) {
kernel_ctx->EmplaceBackAttr(
std::move(pten::ScalarArray(&BOOST_GET_CONST(int32_t, attr), 1)));
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int32_t>))) {
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/imperative/tests/test_prepare_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ TEST(test_prepare_op, test_prepare_data_cpu_mkldnn) {
} // namespace imperative
} // namespace paddle

USE_OP(split);
USE_OP_ITSELF(split);
USE_OP(relu);
#ifdef PADDLE_WITH_MKLDNN
USE_OP_DEVICE_KERNEL(relu, MKLDNN);
Expand Down
8 changes: 0 additions & 8 deletions paddle/fluid/operators/split_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,11 +172,3 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(split, ops::SplitOp, ops::SplitOpMaker,
ops::SplitGradMaker<paddle::framework::OpDesc>,
ops::SplitGradMaker<paddle::imperative::OpBase>);
namespace plat = paddle::platform;
REGISTER_OP_CPU_KERNEL(
split, ops::SplitOpKernel<plat::CPUDeviceContext, double>,
ops::SplitOpKernel<plat::CPUDeviceContext, float>,
ops::SplitOpKernel<plat::CPUDeviceContext, int64_t>,
ops::SplitOpKernel<plat::CPUDeviceContext, int>,
ops::SplitOpKernel<plat::CPUDeviceContext, bool>,
ops::SplitOpKernel<plat::CPUDeviceContext, plat::float16>);
25 changes: 0 additions & 25 deletions paddle/fluid/operators/split_op.cu.cc

This file was deleted.

54 changes: 1 addition & 53 deletions paddle/fluid/operators/split_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/strided_memcpy.h"
#include "paddle/fluid/operators/utils.h"

#include "paddle/pten/kernels/split_kernel.h"
namespace paddle {
namespace operators {
static inline std::vector<framework::DDim> UpdateOutsDims(
Expand Down Expand Up @@ -108,56 +106,6 @@ static inline std::vector<framework::DDim> UpdateOutsDims(
}
return outs_dims;
}
template <typename DeviceContext, typename T>
class SplitOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto outs = ctx.MultiOutput<framework::Tensor>("Out");
int num = ctx.Attr<int>("num");
std::vector<int> sections = ctx.Attr<std::vector<int>>("sections");
int axis = ctx.Attr<int>("axis");

auto in_dims = in->dims();
auto outs_number = outs.size();

bool need_resize_outs_dims = false;
if (ctx.HasInput("AxisTensor")) {
auto* axis_tensor = ctx.Input<framework::Tensor>("AxisTensor");
axis = GetDataFromTensor(axis_tensor)[0];
need_resize_outs_dims = true;
}
auto sections_tensor_list =
ctx.MultiInput<framework::Tensor>("SectionsTensorList");
if (sections_tensor_list.size() > 0) {
sections = GetDataFromTensorList(sections_tensor_list);
need_resize_outs_dims = true;
}

if (need_resize_outs_dims) {
std::vector<framework::DDim> outs_dims =
UpdateOutsDims(true, true, in_dims, num, sections, axis, outs_number);
for (size_t j = 0; j < outs.size(); ++j) {
outs[j]->Resize(outs_dims[j]);
}
}

std::vector<const framework::Tensor*> shape_refer;
for (size_t j = 0; j < outs.size(); ++j) {
outs[j]->mutable_data<T>(ctx.GetPlace());
shape_refer.emplace_back(outs[j]);
}

auto& dev_ctx = ctx.template device_context<DeviceContext>();
// Sometimes direct copies will be faster, this maybe need deeply analysis.
if (axis == 0 && outs.size() < 10) {
StridedMemcpyWithAxis0<T>(dev_ctx, *in, shape_refer, &outs);
} else {
math::SplitFunctor<DeviceContext, T> functor;
functor(dev_ctx, *in, shape_refer, axis, &outs);
}
}
};

template <typename T>
class SplitGradMaker : public framework::SingleGradOpMaker<T> {
Expand Down
8 changes: 8 additions & 0 deletions paddle/pten/api/include/manual_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ limitations under the License. */

#include "paddle/pten/api/include/tensor.h"
#include "paddle/pten/common/backend.h"
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h"

/**
* This file stores some special APIs that are implemented manually
Expand All @@ -28,5 +30,11 @@ namespace experimental {
// TODO(chenweihang): Replace backend by place when place is ready
PADDLE_API Tensor copy_to(const Tensor& x, Backend backend, bool blocking);

// TODO(chentianyu03): Split API has extra logic to calculate the outputs size,
// api_gen do not support
PADDLE_API std::vector<Tensor> split(const Tensor& x,
const ScalarArray& num_or_sections,
const Scalar& axis);

} // namespace experimental
} // namespace paddle
68 changes: 68 additions & 0 deletions paddle/pten/api/lib/manual_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@ limitations under the License. */
#include "glog/logging.h"

#include "paddle/pten/api/lib/api_registry.h"
#include "paddle/pten/api/lib/api_utils.h"
#include "paddle/pten/api/lib/data_transform.h"
#include "paddle/pten/api/lib/kernel_dispatch.h"
#include "paddle/pten/api/lib/utils/storage.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/core/meta_tensor.h"
#include "paddle/pten/infermeta/unary.h"

PT_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT);
Expand Down Expand Up @@ -75,6 +78,71 @@ PADDLE_API Tensor copy_to(const Tensor& x, Backend backend, bool blocking) {
return out;
}

PADDLE_API std::vector<Tensor> split(const Tensor& x,
const ScalarArray& num_or_sections,
const Scalar& axis) {
Backend kernel_backend = Backend::UNDEFINED;
DataLayout kernel_layout = DataLayout::UNDEFINED;
DataType kernel_data_type = DataType::UNDEFINED;

if (kernel_backend == Backend::UNDEFINED ||
kernel_layout == DataLayout::UNDEFINED ||
kernel_data_type == DataType::UNDEFINED) {
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {
kernel_backend = kernel_key.backend();
}
if (kernel_layout == DataLayout::UNDEFINED) {
kernel_layout = kernel_key.layout();
}
if (kernel_data_type == DataType::UNDEFINED) {
kernel_data_type = kernel_key.dtype();
}
}

auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
"split", {kernel_backend, kernel_layout, kernel_data_type});
VLOG(6) << "split API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
VLOG(6) << "split API kernel: " << kernel;

auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);

auto dense_x = PrepareData(x, kernel.InputAt(0), {});

// Calculate the number of out tensors
size_t out_number;
if (num_or_sections.GetData().size() == 1) {
out_number = num_or_sections.GetData()[0];
} else {
out_number = num_or_sections.GetData().size();
}

std::vector<Tensor> out;
auto dense_outs = SetKernelOutput(out_number, kernel_backend, &out);
std::vector<pten::MetaTensor> meta_outs;
for (size_t i = 0; i < out_number; ++i) {
meta_outs.push_back(dense_outs[i]);
}

pten::SplitInferMeta(
MakeMetaTensor(*dense_x), num_or_sections, axis, &meta_outs);

using kernel_signature = void (*)(const platform::DeviceContext&,
const pten::DenseTensor&,
const pten::ScalarArray&,
const pten::Scalar&,
std::vector<pten::DenseTensor*>&);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx,
*dense_x,
pten::ScalarArray(num_or_sections),
pten::Scalar(axis),
dense_outs);

return out;
}
} // namespace experimental
} // namespace paddle

Expand Down
Loading

0 comments on commit d0df563

Please sign in to comment.