diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 5ed86fb6b717f..3993ae842cb32 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1171,8 +1171,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope, std::string pt_kernel_name; if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(type_)) { if (pt_kernel_signature_ == nullptr || pt_kernel_ == nullptr) { - pt_kernel_signature_.reset(new KernelSignature( - std::move(this->GetExpectedPtenKernelArgs(exe_ctx)))); + pt_kernel_signature_.reset( + new KernelSignature(std::move(GetExpectedPtenKernelArgs(exe_ctx)))); VLOG(6) << *pt_kernel_signature_.get(); kernel_type_.reset( @@ -1359,7 +1359,7 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType( pten::KernelKey OperatorWithKernel::ChoosePtenKernel( const ExecutionContext& ctx) const { pt_kernel_signature_.reset( - new KernelSignature(std::move(this->GetExpectedPtenKernelArgs(ctx)))); + new KernelSignature(std::move(GetExpectedPtenKernelArgs(ctx)))); VLOG(6) << *pt_kernel_signature_.get(); kernel_type_.reset( diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index b6600796baf67..2294d67fbf2f3 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -606,7 +606,7 @@ class OperatorWithKernel : public OperatorBase { * When selecting Kernel during Op execution, select the arguments of the * original Op according to the GetExpectedPtenKernelArgs returned arguments. */ - virtual pten::KernelSignature GetExpectedPtenKernelArgs( + pten::KernelSignature GetExpectedPtenKernelArgs( const ExecutionContext& ctx) const; /* member functions for adapting to pten lib */ diff --git a/paddle/fluid/operators/digamma_op.cc b/paddle/fluid/operators/digamma_op.cc index eb0471fec1206..fef2b91b111c5 100644 --- a/paddle/fluid/operators/digamma_op.cc +++ b/paddle/fluid/operators/digamma_op.cc @@ -64,13 +64,6 @@ class DigammaGradOp : public framework::OperatorWithKernel { ctx->SetOutputDim(framework::GradVarName("X"), dout_dims); ctx->ShareLoD(framework::GradVarName("Out"), framework::GradVarName("X")); } - - framework::KernelSignature GetExpectedPtenKernelArgs( - const framework::ExecutionContext &ctx) const override { - return framework::KernelSignature("digamma_grad", - {framework::GradVarName("Out"), "X"}, {}, - {framework::GradVarName("X")}); - } }; template diff --git a/paddle/fluid/operators/dot_op.cc b/paddle/fluid/operators/dot_op.cc index e1463c8ccb58e..31acd9718115c 100644 --- a/paddle/fluid/operators/dot_op.cc +++ b/paddle/fluid/operators/dot_op.cc @@ -117,13 +117,6 @@ class DotGradOp : public framework::OperatorWithKernel { ctx, framework::GradVarName("Out")), ctx.GetPlace()); } - - framework::KernelSignature GetExpectedPtenKernelArgs( - const framework::ExecutionContext& ctx) const override { - return framework::KernelSignature( - "dot_grad", {"X", "Y", framework::GradVarName("Out")}, {}, - {framework::GradVarName("X"), framework::GradVarName("Y")}); - } }; template diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index d726bf0d0b5ab..0c04f7b360e30 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -353,18 +353,6 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { tensor.place(), tensor.layout()); } } - framework::KernelSignature GetExpectedPtenKernelArgs( - const framework::ExecutionContext &ctx) const override { - if (Type() == "elementwise_add_grad") { - if (ctx.InputVar("X")->IsType()) { - return framework::KernelSignature( - "add_grad", {"X", "Y", framework::GradVarName("Out")}, {"axis"}, - {framework::GradVarName("X"), framework::GradVarName("Y")}); - } - } - - return framework::KernelSignature("None", {"X"}, {}, {"Out"}); - } }; class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel { diff --git a/paddle/fluid/operators/flatten_op.cc b/paddle/fluid/operators/flatten_op.cc index 03ee25accc67d..5f9471cbb3f05 100644 --- a/paddle/fluid/operators/flatten_op.cc +++ b/paddle/fluid/operators/flatten_op.cc @@ -421,12 +421,6 @@ class FlattenContiguousRangeGradOp : public framework::OperatorWithKernel { ctx, framework::GradVarName("Out")), ctx.device_context()); } - framework::KernelSignature GetExpectedPtenKernelArgs( - const framework::ExecutionContext &ctx) const override { - return framework::KernelSignature("flatten_grad", - {framework::GradVarName("Out"), "XShape"}, - {}, {framework::GradVarName("X")}); - } }; DECLARE_INPLACE_OP_INFERER(FlattenOpInplaceInferer, {"X", "Out"}); DECLARE_INPLACE_OP_INFERER(FlattenGradInplaceInferer, diff --git a/paddle/fluid/operators/matmul_v2_op.cc b/paddle/fluid/operators/matmul_v2_op.cc index a5eca7b225558..5add86f5b3c74 100644 --- a/paddle/fluid/operators/matmul_v2_op.cc +++ b/paddle/fluid/operators/matmul_v2_op.cc @@ -389,14 +389,6 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel { tensor.place(), tensor.layout()); } } - - framework::KernelSignature GetExpectedPtenKernelArgs( - const framework::ExecutionContext& ctx) const override { - return framework::KernelSignature( - "matmul_grad", {"X", "Y", framework::GradVarName("Out")}, - {"trans_x", "trans_y"}, - {framework::GradVarName("X"), framework::GradVarName("Y")}); - } }; template @@ -439,13 +431,6 @@ class MatMulV2OpDoubleGrad : public framework::OperatorWithKernel { context->ShareDim("DOut", "DDOut"); } } - - framework::KernelSignature GetExpectedPtenKernelArgs( - const framework::ExecutionContext& ctx) const override { - return framework::KernelSignature( - "matmul_double_grad", {"X", "Y", "DOut", "DDX", "DDY"}, - {"trans_x", "trans_y"}, {"DX", "DY", "DDOut"}); - } }; template @@ -515,15 +500,6 @@ class MatMulV2OpTripleGrad : public framework::OperatorWithKernel { context->ShareDim("Y", "D_DDY_out"); } } - - framework::KernelSignature GetExpectedPtenKernelArgs( - const framework::ExecutionContext& ctx) const override { - return framework::KernelSignature( - "matmul_triple_grad", - {"X", "Y", "DOut", "DDX", "DDY", "D_DX", "D_DY", "D_DDOut"}, - {"trans_x", "trans_y"}, - {"D_X_out", "D_Y_out", "D_DOut_out", "D_DDX_out", "D_DDY_out"}); - } }; template diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index 3b8ea60963d62..667ffabbf4044 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -547,34 +547,6 @@ class ReduceOp : public framework::OperatorWithKernel { } return framework::OpKernelType(input_data_type, ctx.GetPlace()); } - - framework::KernelSignature GetExpectedPtenKernelArgs( - const framework::ExecutionContext& ctx) const override { - bool reduce_all = ctx.Attr("reduce_all"); - if (Type() == "reduce_sum") { - if (ctx.InputVar("X")->IsType()) { - if (!reduce_all) { - return framework::KernelSignature( - "sum", {"X"}, {"dim", "out_dtype", "keep_dim"}, {"Out"}); - } - return framework::KernelSignature( - "sum_raw", {"X"}, {"dim", "keep_dim", "reduce_all", "out_dtype"}, - {"Out"}); - } - } - if (Type() == "reduce_mean") { - if (ctx.InputVar("X")->IsType()) { - if (!reduce_all) { - return framework::KernelSignature("mean", {"X"}, {"dim", "keep_dim"}, - {"Out"}); - } - return framework::KernelSignature( - "mean_raw", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"}); - } - } - // TODO(chentianyu03): support other cases after selected rows added - return framework::KernelSignature("reduce.unregistered", {}, {}, {}); - } }; class ReduceOpUseInputPlace : public ReduceOp { diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index 1ef90ff2b732e..77c4a2005e3bf 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -579,13 +579,6 @@ class Reshape2GradOp : public framework::OperatorWithKernel { return framework::OpKernelType(expected_kernel_type.data_type_, tensor.place(), tensor.layout()); } - - framework::KernelSignature GetExpectedPtenKernelArgs( - const framework::ExecutionContext &ctx) const override { - return framework::KernelSignature("reshape_grad", - {framework::GradVarName("Out")}, {}, - {framework::GradVarName("X")}); - } }; class Reshape2DoubleGradOp : public framework::OperatorWithKernel { @@ -622,11 +615,6 @@ class Reshape2DoubleGradOp : public framework::OperatorWithKernel { return framework::OpKernelType(expected_kernel_type.data_type_, tensor.place(), tensor.layout()); } - framework::KernelSignature GetExpectedPtenKernelArgs( - const framework::ExecutionContext &ctx) const override { - return framework::KernelSignature("reshape_double_grad", {"DDX"}, {}, - {"DDOut"}); - } }; DECLARE_INPLACE_OP_INFERER(ReshapeOpInplaceInferer, {"X", "Out"}); diff --git a/paddle/pten/ops/compat/digamma_sig.cc b/paddle/pten/ops/compat/digamma_sig.cc new file mode 100644 index 0000000000000..d437133b592fd --- /dev/null +++ b/paddle/pten/ops/compat/digamma_sig.cc @@ -0,0 +1,27 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/pten/core/compat/op_utils.h" + +namespace pten { + +KernelSignature DigammaGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "digamma_grad", {GradVarName("Out"), "X"}, {}, {GradVarName("X")}); +} + +} // namespace pten + +PT_REGISTER_ARG_MAPPING_FN(digamma_grad, pten::DigammaGradOpArgumentMapping); diff --git a/paddle/pten/ops/compat/dot_sig.cc b/paddle/pten/ops/compat/dot_sig.cc new file mode 100644 index 0000000000000..5e2b0bd0e543d --- /dev/null +++ b/paddle/pten/ops/compat/dot_sig.cc @@ -0,0 +1,28 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/pten/core/compat/op_utils.h" + +namespace pten { + +KernelSignature DotGradOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("dot_grad", + {"X", "Y", GradVarName("Out")}, + {}, + {GradVarName("X"), GradVarName("Y")}); +} + +} // namespace pten + +PT_REGISTER_ARG_MAPPING_FN(dot_grad, pten::DotGradOpArgumentMapping); diff --git a/paddle/pten/ops/compat/elementwise_sig.cc b/paddle/pten/ops/compat/elementwise_sig.cc index 57bd03f8a21d5..c1941f6dde30b 100644 --- a/paddle/pten/ops/compat/elementwise_sig.cc +++ b/paddle/pten/ops/compat/elementwise_sig.cc @@ -64,6 +64,17 @@ KernelSignature ElementwiseDivOpArgumentMapping( return KernelSignature("unregistered", {}, {}, {}); } +KernelSignature ElementwiseAddGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + if (ctx.IsDenseTensorInput("X")) { + return KernelSignature("add_grad", + {"X", "Y", GradVarName("Out")}, + {"axis"}, + {GradVarName("X"), GradVarName("Y")}); + } + return KernelSignature("unregistered", {}, {}, {}); +} + } // namespace pten PT_REGISTER_BASE_KERNEL_NAME(elementwise_add, add); @@ -71,7 +82,6 @@ PT_REGISTER_BASE_KERNEL_NAME(elementwise_sub, subtract); PT_REGISTER_BASE_KERNEL_NAME(elementwise_mul, multiply); PT_REGISTER_BASE_KERNEL_NAME(elementwise_div, divide); PT_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad, add_grad); -PT_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad, subtract_grad); PT_REGISTER_ARG_MAPPING_FN(elementwise_add, pten::ElementwiseAddOpArgumentMapping); @@ -81,3 +91,5 @@ PT_REGISTER_ARG_MAPPING_FN(elementwise_mul, pten::ElementwiseMulOpArgumentMapping); PT_REGISTER_ARG_MAPPING_FN(elementwise_div, pten::ElementwiseDivOpArgumentMapping); +PT_REGISTER_ARG_MAPPING_FN(elementwise_add_grad, + pten::ElementwiseAddGradOpArgumentMapping); diff --git a/paddle/pten/ops/compat/flatten_sig.cc b/paddle/pten/ops/compat/flatten_sig.cc index 1ef2977bf88d7..711a7a733cefe 100644 --- a/paddle/pten/ops/compat/flatten_sig.cc +++ b/paddle/pten/ops/compat/flatten_sig.cc @@ -28,6 +28,12 @@ KernelSignature FlattenOpArgumentMapping(const ArgumentMappingContext& ctx) { } } +KernelSignature FlattenGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "flatten_grad", {GradVarName("Out"), "XShape"}, {}, {GradVarName("X")}); +} + } // namespace pten PT_REGISTER_BASE_KERNEL_NAME(flatten_contiguous_range, flatten); @@ -35,3 +41,5 @@ PT_REGISTER_BASE_KERNEL_NAME(flatten_contiguous_range_grad, flatten_grad); PT_REGISTER_ARG_MAPPING_FN(flatten_contiguous_range, pten::FlattenOpArgumentMapping); +PT_REGISTER_ARG_MAPPING_FN(flatten_contiguous_range_grad, + pten::FlattenGradOpArgumentMapping); diff --git a/paddle/pten/ops/compat/matmul_sig.cc b/paddle/pten/ops/compat/matmul_sig.cc index 67ef91b429e36..963d5d6656b04 100644 --- a/paddle/pten/ops/compat/matmul_sig.cc +++ b/paddle/pten/ops/compat/matmul_sig.cc @@ -14,9 +14,41 @@ limitations under the License. */ #include "paddle/pten/core/compat/op_utils.h" -namespace pten {} // namespace pten +namespace pten { + +KernelSignature MatmulGradOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("matmul_grad", + {"X", "Y", GradVarName("Out")}, + {"trans_x", "trans_y"}, + {GradVarName("X"), GradVarName("Y")}); +} + +KernelSignature MatmulDoubleGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("matmul_double_grad", + {"X", "Y", "DOut", "DDX", "DDY"}, + {"trans_x", "trans_y"}, + {"DX", "DY", "DDOut"}); +} + +KernelSignature MatmulTripleGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "matmul_triple_grad", + {"X", "Y", "DOut", "DDX", "DDY", "D_DX", "D_DY", "D_DDOut"}, + {"trans_x", "trans_y"}, + {"D_X_out", "D_Y_out", "D_DOut_out", "D_DDX_out", "D_DDY_out"}); +} + +} // namespace pten PT_REGISTER_BASE_KERNEL_NAME(matmul_v2, matmul); PT_REGISTER_BASE_KERNEL_NAME(matmul_v2_grad, matmul_grad); PT_REGISTER_BASE_KERNEL_NAME(matmul_v2_grad_grad, matmul_double_grad); PT_REGISTER_BASE_KERNEL_NAME(matmul_v2_triple_grad, matmul_triple_grad); + +PT_REGISTER_ARG_MAPPING_FN(matmul_v2_grad, pten::MatmulGradOpArgumentMapping); +PT_REGISTER_ARG_MAPPING_FN(matmul_v2_grad_grad, + pten::MatmulDoubleGradOpArgumentMapping); +PT_REGISTER_ARG_MAPPING_FN(matmul_v2_triple_grad, + pten::MatmulTripleGradOpArgumentMapping); diff --git a/paddle/pten/ops/compat/reduce_sig.cc b/paddle/pten/ops/compat/reduce_sig.cc index 10f73d8122e4e..f07b05cec40a2 100644 --- a/paddle/pten/ops/compat/reduce_sig.cc +++ b/paddle/pten/ops/compat/reduce_sig.cc @@ -21,7 +21,7 @@ KernelSignature ReduceSumOpArgumentMapping(const ArgumentMappingContext& ctx) { if (ctx.IsDenseTensorInput("X")) { if (!reduce_all) { return KernelSignature( - "sum", {"X"}, {"dim", "keep_dim", "out_dtype"}, {"Out"}); + "sum", {"X"}, {"dim", "out_dtype", "keep_dim"}, {"Out"}); } return KernelSignature("sum_raw", {"X"}, diff --git a/paddle/pten/ops/compat/reshape_sig.cc b/paddle/pten/ops/compat/reshape_sig.cc index 031b6875867a5..823fb5d3cdd41 100644 --- a/paddle/pten/ops/compat/reshape_sig.cc +++ b/paddle/pten/ops/compat/reshape_sig.cc @@ -26,6 +26,17 @@ KernelSignature ReshapeOpArgumentMapping(const ArgumentMappingContext& ctx) { } } +KernelSignature ReshapeGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "reshape_grad", {GradVarName("Out")}, {}, {GradVarName("X")}); +} + +KernelSignature ReshapeDoubleGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("reshape_double_grad", {"DDX"}, {}, {"DDOut"}); +} + } // namespace pten PT_REGISTER_BASE_KERNEL_NAME(reshape2, reshape); @@ -33,3 +44,6 @@ PT_REGISTER_BASE_KERNEL_NAME(reshape2_grad, reshape_grad); PT_REGISTER_BASE_KERNEL_NAME(reshape2_grad_grad, reshape_double_grad); PT_REGISTER_ARG_MAPPING_FN(reshape2, pten::ReshapeOpArgumentMapping); +PT_REGISTER_ARG_MAPPING_FN(reshape2_grad, pten::ReshapeGradOpArgumentMapping); +PT_REGISTER_ARG_MAPPING_FN(reshape2_grad_grad, + pten::ReshapeDoubleGradOpArgumentMapping);