From 4558f80fff21afdf4566ef8970cbb64296008409 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Wed, 9 Feb 2022 05:20:53 +0000 Subject: [PATCH 1/4] move grad get expected pten kernel args --- paddle/fluid/framework/operator.cc | 26 ++------------ paddle/fluid/framework/operator.h | 2 +- paddle/fluid/operators/digamma_op.cc | 7 ---- paddle/fluid/operators/dot_op.cc | 7 ---- .../operators/elementwise/elementwise_op.h | 12 ------- paddle/fluid/operators/flatten_op.cc | 6 ---- paddle/fluid/operators/matmul_v2_op.cc | 24 ------------- paddle/fluid/operators/reduce_ops/reduce_op.h | 28 --------------- paddle/fluid/operators/reshape_op.cc | 12 ------- paddle/pten/ops/compat/digamma_sig.cc | 27 +++++++++++++++ paddle/pten/ops/compat/dot_sig.cc | 28 +++++++++++++++ paddle/pten/ops/compat/elementwise_sig.cc | 13 +++++++ paddle/pten/ops/compat/flatten_sig.cc | 8 +++++ paddle/pten/ops/compat/matmul_sig.cc | 34 ++++++++++++++++++- paddle/pten/ops/compat/reshape_sig.cc | 14 ++++++++ 15 files changed, 127 insertions(+), 121 deletions(-) create mode 100644 paddle/pten/ops/compat/digamma_sig.cc create mode 100644 paddle/pten/ops/compat/dot_sig.cc diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 0f558b46872a2..5e646ee62e26b 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1171,28 +1171,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope, std::string pt_kernel_name; if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(type_)) { if (pt_kernel_signature_ == nullptr || pt_kernel_ == nullptr) { - pt_kernel_signature_.reset(new KernelSignature( - std::move(this->GetExpectedPtenKernelArgs(exe_ctx)))); - VLOG(6) << *pt_kernel_signature_.get(); - - kernel_type_.reset( - new OpKernelType(std::move(InnerGetExpectedKernelType(exe_ctx)))); - dev_ctx = pool.Get(kernel_type_->place_); - - pt_kernel_name = pt_kernel_signature_->name; - pt_kernel_key = TransOpKernelTypeToPtenKernelKey(*kernel_type_.get()); - pt_kernel_.reset( - new pten::Kernel(pten::KernelFactory::Instance().SelectKernel( - pt_kernel_name, pt_kernel_key))); - - if (pt_kernel_->IsValid()) { - VLOG(6) << "Static mode ChoosePtenKernel - kernel name: " - << pt_kernel_name << " | kernel key: " << pt_kernel_key - << " | kernel: " << *pt_kernel_; - } else { - VLOG(6) << "Static mode ChoosePtenKernel - kernel `" << pt_kernel_name - << "` not found."; - } + pt_kernel_key = ChoosePtenKernel(exe_ctx); + pt_kernel_name = PtenKernelSignature()->name; } if (pt_kernel_->IsValid()) { run_pten_kernel_ = true; @@ -1360,7 +1340,7 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType( pten::KernelKey OperatorWithKernel::ChoosePtenKernel( const ExecutionContext& ctx) const { pt_kernel_signature_.reset( - new KernelSignature(std::move(this->GetExpectedPtenKernelArgs(ctx)))); + new KernelSignature(std::move(GetExpectedPtenKernelArgs(ctx)))); VLOG(6) << *pt_kernel_signature_.get(); kernel_type_.reset( diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index b6600796baf67..2294d67fbf2f3 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -606,7 +606,7 @@ class OperatorWithKernel : public OperatorBase { * When selecting Kernel during Op execution, select the arguments of the * original Op according to the GetExpectedPtenKernelArgs returned arguments. 
   */
-  virtual pten::KernelSignature GetExpectedPtenKernelArgs(
+  pten::KernelSignature GetExpectedPtenKernelArgs(
       const ExecutionContext& ctx) const;
 
   /* member functions for adapting to pten lib */
diff --git a/paddle/fluid/operators/digamma_op.cc b/paddle/fluid/operators/digamma_op.cc
index eb0471fec1206..fef2b91b111c5 100644
--- a/paddle/fluid/operators/digamma_op.cc
+++ b/paddle/fluid/operators/digamma_op.cc
@@ -64,13 +64,6 @@ class DigammaGradOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim(framework::GradVarName("X"), dout_dims);
     ctx->ShareLoD(framework::GradVarName("Out"), framework::GradVarName("X"));
   }
-
-  framework::KernelSignature GetExpectedPtenKernelArgs(
-      const framework::ExecutionContext &ctx) const override {
-    return framework::KernelSignature("digamma_grad",
-                                      {framework::GradVarName("Out"), "X"}, {},
-                                      {framework::GradVarName("X")});
-  }
 };
 
 template <typename T>
diff --git a/paddle/fluid/operators/dot_op.cc b/paddle/fluid/operators/dot_op.cc
index e1463c8ccb58e..31acd9718115c 100644
--- a/paddle/fluid/operators/dot_op.cc
+++ b/paddle/fluid/operators/dot_op.cc
@@ -117,13 +117,6 @@ class DotGradOp : public framework::OperatorWithKernel {
                    ctx, framework::GradVarName("Out")),
                ctx.GetPlace());
   }
-
-  framework::KernelSignature GetExpectedPtenKernelArgs(
-      const framework::ExecutionContext& ctx) const override {
-    return framework::KernelSignature(
-        "dot_grad", {"X", "Y", framework::GradVarName("Out")}, {},
-        {framework::GradVarName("X"), framework::GradVarName("Y")});
-  }
 };
 
 template <typename T>
diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h
index d726bf0d0b5ab..0c04f7b360e30 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op.h
@@ -353,18 +353,6 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
                                      tensor.place(), tensor.layout());
     }
   }
-  framework::KernelSignature GetExpectedPtenKernelArgs(
-      const framework::ExecutionContext &ctx) const override {
-    if (Type() == "elementwise_add_grad") {
-      if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
-        return framework::KernelSignature(
-            "add_grad", {"X", "Y", framework::GradVarName("Out")}, {"axis"},
-            {framework::GradVarName("X"), framework::GradVarName("Y")});
-      }
-    }
-
-    return framework::KernelSignature("None", {"X"}, {}, {"Out"});
-  }
 };
 
 class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel {
diff --git a/paddle/fluid/operators/flatten_op.cc b/paddle/fluid/operators/flatten_op.cc
index 03ee25accc67d..5f9471cbb3f05 100644
--- a/paddle/fluid/operators/flatten_op.cc
+++ b/paddle/fluid/operators/flatten_op.cc
@@ -421,12 +421,6 @@ class FlattenContiguousRangeGradOp : public framework::OperatorWithKernel {
                    ctx, framework::GradVarName("Out")),
                ctx.device_context());
   }
-  framework::KernelSignature GetExpectedPtenKernelArgs(
-      const framework::ExecutionContext &ctx) const override {
-    return framework::KernelSignature("flatten_grad",
-                                      {framework::GradVarName("Out"), "XShape"},
-                                      {}, {framework::GradVarName("X")});
-  }
 };
 DECLARE_INPLACE_OP_INFERER(FlattenOpInplaceInferer, {"X", "Out"});
 DECLARE_INPLACE_OP_INFERER(FlattenGradInplaceInferer,
diff --git a/paddle/fluid/operators/matmul_v2_op.cc b/paddle/fluid/operators/matmul_v2_op.cc
index a5eca7b225558..5add86f5b3c74 100644
--- a/paddle/fluid/operators/matmul_v2_op.cc
+++ b/paddle/fluid/operators/matmul_v2_op.cc
@@ -389,14 +389,6 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel {
                                    tensor.place(), tensor.layout());
     }
   }
-
-  framework::KernelSignature GetExpectedPtenKernelArgs(
-      const framework::ExecutionContext& ctx) const override {
-    return framework::KernelSignature(
-        "matmul_grad", {"X", "Y", framework::GradVarName("Out")},
-        {"trans_x", "trans_y"},
-        {framework::GradVarName("X"), framework::GradVarName("Y")});
-  }
 };
 
 template <typename T>
@@ -439,13 +431,6 @@ class MatMulV2OpDoubleGrad : public framework::OperatorWithKernel {
       context->ShareDim("DOut", "DDOut");
     }
   }
-
-  framework::KernelSignature GetExpectedPtenKernelArgs(
-      const framework::ExecutionContext& ctx) const override {
-    return framework::KernelSignature(
-        "matmul_double_grad", {"X", "Y", "DOut", "DDX", "DDY"},
-        {"trans_x", "trans_y"}, {"DX", "DY", "DDOut"});
-  }
 };
 
 template <typename T>
@@ -515,15 +500,6 @@ class MatMulV2OpTripleGrad : public framework::OperatorWithKernel {
       context->ShareDim("Y", "D_DDY_out");
     }
   }
-
-  framework::KernelSignature GetExpectedPtenKernelArgs(
-      const framework::ExecutionContext& ctx) const override {
-    return framework::KernelSignature(
-        "matmul_triple_grad",
-        {"X", "Y", "DOut", "DDX", "DDY", "D_DX", "D_DY", "D_DDOut"},
-        {"trans_x", "trans_y"},
-        {"D_X_out", "D_Y_out", "D_DOut_out", "D_DDX_out", "D_DDY_out"});
-  }
 };
 
 template <typename T>
diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h
index 4f275717bced8..d7a6da1b33d31 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_op.h
+++ b/paddle/fluid/operators/reduce_ops/reduce_op.h
@@ -546,34 +546,6 @@ class ReduceOp : public framework::OperatorWithKernel {
     }
     return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
-
-  framework::KernelSignature GetExpectedPtenKernelArgs(
-      const framework::ExecutionContext& ctx) const override {
-    bool reduce_all = ctx.Attr<bool>("reduce_all");
-    if (Type() == "reduce_sum") {
-      if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
-        if (!reduce_all) {
-          return framework::KernelSignature(
-              "sum", {"X"}, {"dim", "out_dtype", "keep_dim"}, {"Out"});
-        }
-        return framework::KernelSignature(
-            "sum_raw", {"X"}, {"dim", "keep_dim", "reduce_all", "out_dtype"},
-            {"Out"});
-      }
-    }
-    if (Type() == "reduce_mean") {
-      if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
-        if (!reduce_all) {
-          return framework::KernelSignature("mean", {"X"}, {"dim", "keep_dim"},
-                                            {"Out"});
-        }
-        return framework::KernelSignature(
-            "mean_raw", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"});
-      }
-    }
-    // TODO(chentianyu03): support other cases after selected rows added
-    return framework::KernelSignature("reduce.unregistered", {}, {}, {});
-  }
 };
 
 class ReduceOpUseInputPlace : public ReduceOp {
diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc
index 43da63aae73e1..0dfaa8256f6a9 100644
--- a/paddle/fluid/operators/reshape_op.cc
+++ b/paddle/fluid/operators/reshape_op.cc
@@ -579,13 +579,6 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
     return framework::OpKernelType(expected_kernel_type.data_type_,
                                    tensor.place(), tensor.layout());
   }
-
-  framework::KernelSignature GetExpectedPtenKernelArgs(
-      const framework::ExecutionContext &ctx) const override {
-    return framework::KernelSignature("reshape_grad",
-                                      {framework::GradVarName("Out")}, {},
-                                      {framework::GradVarName("X")});
-  }
 };
 
 class Reshape2DoubleGradOp : public framework::OperatorWithKernel {
@@ -622,11 +615,6 @@ class Reshape2DoubleGradOp : public framework::OperatorWithKernel {
     return framework::OpKernelType(expected_kernel_type.data_type_,
                                    tensor.place(), tensor.layout());
   }
-  framework::KernelSignature GetExpectedPtenKernelArgs(
-      const framework::ExecutionContext &ctx) const override {
-    return framework::KernelSignature("reshape_double_grad", {"DDX"}, {},
-                                      {"DDOut"});
-  }
 };
 
 DECLARE_INPLACE_OP_INFERER(ReshapeOpInplaceInferer, {"X", "Out"});
diff --git a/paddle/pten/ops/compat/digamma_sig.cc b/paddle/pten/ops/compat/digamma_sig.cc
new file mode 100644
index 0000000000000..d437133b592fd
--- /dev/null
+++ b/paddle/pten/ops/compat/digamma_sig.cc
@@ -0,0 +1,27 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/pten/core/compat/op_utils.h"
+
+namespace pten {
+
+KernelSignature DigammaGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "digamma_grad", {GradVarName("Out"), "X"}, {}, {GradVarName("X")});
+}
+
+}  // namespace pten
+
+PT_REGISTER_ARG_MAPPING_FN(digamma_grad, pten::DigammaGradOpArgumentMapping);
diff --git a/paddle/pten/ops/compat/dot_sig.cc b/paddle/pten/ops/compat/dot_sig.cc
new file mode 100644
index 0000000000000..5e2b0bd0e543d
--- /dev/null
+++ b/paddle/pten/ops/compat/dot_sig.cc
@@ -0,0 +1,28 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/pten/core/compat/op_utils.h"
+
+namespace pten {
+
+KernelSignature DotGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature("dot_grad",
+                         {"X", "Y", GradVarName("Out")},
+                         {},
+                         {GradVarName("X"), GradVarName("Y")});
+}
+
+}  // namespace pten
+
+PT_REGISTER_ARG_MAPPING_FN(dot_grad, pten::DotGradOpArgumentMapping);
diff --git a/paddle/pten/ops/compat/elementwise_sig.cc b/paddle/pten/ops/compat/elementwise_sig.cc
index 4c14a5d139e3c..fce5899b6a966 100644
--- a/paddle/pten/ops/compat/elementwise_sig.cc
+++ b/paddle/pten/ops/compat/elementwise_sig.cc
@@ -64,6 +64,17 @@ KernelSignature ElementwiseDivOpArgumentMapping(
   return KernelSignature("unregistered", {}, {}, {});
 }
 
+KernelSignature ElementwiseAddGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  if (ctx.IsDenseTensorInput("X")) {
+    return KernelSignature("add_grad",
+                           {"X", "Y", GradVarName("Out")},
+                           {"axis"},
+                           {GradVarName("X"), GradVarName("Y")});
+  }
+  return KernelSignature("unregistered", {}, {}, {});
+}
+
 }  // namespace pten
 
 PT_REGISTER_BASE_KERNEL_NAME(elementwise_add, add_raw);
@@ -81,3 +92,5 @@ PT_REGISTER_ARG_MAPPING_FN(elementwise_mul,
                            pten::ElementwiseMulOpArgumentMapping);
 PT_REGISTER_ARG_MAPPING_FN(elementwise_div,
                            pten::ElementwiseDivOpArgumentMapping);
+PT_REGISTER_ARG_MAPPING_FN(elementwise_add_grad,
+                           pten::ElementwiseAddGradOpArgumentMapping);
diff --git a/paddle/pten/ops/compat/flatten_sig.cc b/paddle/pten/ops/compat/flatten_sig.cc
index 1ef2977bf88d7..711a7a733cefe 100644
--- a/paddle/pten/ops/compat/flatten_sig.cc
+++ b/paddle/pten/ops/compat/flatten_sig.cc
@@ -28,6 +28,12 @@ KernelSignature FlattenOpArgumentMapping(const ArgumentMappingContext& ctx) {
   }
 }
 
+KernelSignature FlattenGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "flatten_grad", {GradVarName("Out"), "XShape"}, {}, {GradVarName("X")});
+}
+
 }  // namespace pten
 
 PT_REGISTER_BASE_KERNEL_NAME(flatten_contiguous_range, flatten);
@@ -35,3 +41,5 @@ PT_REGISTER_BASE_KERNEL_NAME(flatten_contiguous_range_grad, flatten_grad);
 
 PT_REGISTER_ARG_MAPPING_FN(flatten_contiguous_range,
                            pten::FlattenOpArgumentMapping);
+PT_REGISTER_ARG_MAPPING_FN(flatten_contiguous_range_grad,
+                           pten::FlattenGradOpArgumentMapping);
diff --git a/paddle/pten/ops/compat/matmul_sig.cc b/paddle/pten/ops/compat/matmul_sig.cc
index 67ef91b429e36..963d5d6656b04 100644
--- a/paddle/pten/ops/compat/matmul_sig.cc
+++ b/paddle/pten/ops/compat/matmul_sig.cc
@@ -14,9 +14,41 @@ limitations under the License. */
 
 #include "paddle/pten/core/compat/op_utils.h"
 
-namespace pten {}  // namespace pten
+namespace pten {
+
+KernelSignature MatmulGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature("matmul_grad",
+                         {"X", "Y", GradVarName("Out")},
+                         {"trans_x", "trans_y"},
+                         {GradVarName("X"), GradVarName("Y")});
+}
+
+KernelSignature MatmulDoubleGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("matmul_double_grad",
+                         {"X", "Y", "DOut", "DDX", "DDY"},
+                         {"trans_x", "trans_y"},
+                         {"DX", "DY", "DDOut"});
+}
+
+KernelSignature MatmulTripleGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "matmul_triple_grad",
+      {"X", "Y", "DOut", "DDX", "DDY", "D_DX", "D_DY", "D_DDOut"},
+      {"trans_x", "trans_y"},
+      {"D_X_out", "D_Y_out", "D_DOut_out", "D_DDX_out", "D_DDY_out"});
+}
+
+}  // namespace pten
 
 PT_REGISTER_BASE_KERNEL_NAME(matmul_v2, matmul);
 PT_REGISTER_BASE_KERNEL_NAME(matmul_v2_grad, matmul_grad);
 PT_REGISTER_BASE_KERNEL_NAME(matmul_v2_grad_grad, matmul_double_grad);
 PT_REGISTER_BASE_KERNEL_NAME(matmul_v2_triple_grad, matmul_triple_grad);
+
+PT_REGISTER_ARG_MAPPING_FN(matmul_v2_grad, pten::MatmulGradOpArgumentMapping);
+PT_REGISTER_ARG_MAPPING_FN(matmul_v2_grad_grad,
+                           pten::MatmulDoubleGradOpArgumentMapping);
+PT_REGISTER_ARG_MAPPING_FN(matmul_v2_triple_grad,
+                           pten::MatmulTripleGradOpArgumentMapping);
diff --git a/paddle/pten/ops/compat/reshape_sig.cc b/paddle/pten/ops/compat/reshape_sig.cc
index 031b6875867a5..823fb5d3cdd41 100644
--- a/paddle/pten/ops/compat/reshape_sig.cc
+++ b/paddle/pten/ops/compat/reshape_sig.cc
@@ -26,6 +26,17 @@ KernelSignature ReshapeOpArgumentMapping(const ArgumentMappingContext& ctx) {
   }
 }
 
+KernelSignature ReshapeGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "reshape_grad", {GradVarName("Out")}, {}, {GradVarName("X")});
+}
+
+KernelSignature ReshapeDoubleGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("reshape_double_grad", {"DDX"}, {}, {"DDOut"});
+}
+
 }  // namespace pten
 
 PT_REGISTER_BASE_KERNEL_NAME(reshape2, reshape);
@@ -33,3 +44,6 @@ PT_REGISTER_BASE_KERNEL_NAME(reshape2_grad, reshape_grad);
 PT_REGISTER_BASE_KERNEL_NAME(reshape2_grad_grad, reshape_double_grad);
 
 PT_REGISTER_ARG_MAPPING_FN(reshape2, pten::ReshapeOpArgumentMapping);
+PT_REGISTER_ARG_MAPPING_FN(reshape2_grad, pten::ReshapeGradOpArgumentMapping);
+PT_REGISTER_ARG_MAPPING_FN(reshape2_grad_grad,
+                           pten::ReshapeDoubleGradOpArgumentMapping);

From f03247a8866f227db8dd0eed3083c76c4f9ccbda Mon Sep 17 00:00:00 2001
From: Chen Weihang
Date: Thu, 10 Feb 2022 05:18:06 +0000
Subject: [PATCH 2/4] fix reduce sum error

---
 paddle/pten/ops/compat/reduce_sig.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/paddle/pten/ops/compat/reduce_sig.cc b/paddle/pten/ops/compat/reduce_sig.cc
index a8a2b517d3e9d..7913673475f96 100644
--- a/paddle/pten/ops/compat/reduce_sig.cc
+++ b/paddle/pten/ops/compat/reduce_sig.cc
@@ -21,7 +21,7 @@ KernelSignature ReduceSumOpArgumentMapping(const ArgumentMappingContext& ctx) {
   if (ctx.IsDenseTensorInput("X")) {
     if (!reduce_all) {
       return KernelSignature(
-          "sum", {"X"}, {"dim", "keep_dim", "out_dtype"}, {"Out"});
+          "sum", {"X"}, {"dim", "out_dtype", "keep_dim"}, {"Out"});
     }
     return KernelSignature("sum_raw",
                            {"X"},

From b3a32017cfb39c3d13ce66be93ee1a5eac086f23 Mon Sep 17 00:00:00 2001
From: Chen Weihang
Date: Thu, 10 Feb 2022 08:45:18 +0000
Subject: [PATCH 3/4] fix element_sub_grad failed

---
paddle/pten/ops/compat/elementwise_sig.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/paddle/pten/ops/compat/elementwise_sig.cc b/paddle/pten/ops/compat/elementwise_sig.cc index 0fe8472bbe6e3..c1941f6dde30b 100644 --- a/paddle/pten/ops/compat/elementwise_sig.cc +++ b/paddle/pten/ops/compat/elementwise_sig.cc @@ -82,7 +82,6 @@ PT_REGISTER_BASE_KERNEL_NAME(elementwise_sub, subtract); PT_REGISTER_BASE_KERNEL_NAME(elementwise_mul, multiply); PT_REGISTER_BASE_KERNEL_NAME(elementwise_div, divide); PT_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad, add_grad); -PT_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad, subtract_grad); PT_REGISTER_ARG_MAPPING_FN(elementwise_add, pten::ElementwiseAddOpArgumentMapping); From 4d5e0d069e37352e0d5383d8bb6c30dfd9621638 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Thu, 10 Feb 2022 11:09:38 +0000 Subject: [PATCH 4/4] revert kernel judge change --- paddle/fluid/framework/operator.cc | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index f4c20b768a7cc..3993ae842cb32 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1171,8 +1171,28 @@ void OperatorWithKernel::RunImpl(const Scope& scope, std::string pt_kernel_name; if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(type_)) { if (pt_kernel_signature_ == nullptr || pt_kernel_ == nullptr) { - pt_kernel_key = ChoosePtenKernel(exe_ctx); - pt_kernel_name = PtenKernelSignature()->name; + pt_kernel_signature_.reset( + new KernelSignature(std::move(GetExpectedPtenKernelArgs(exe_ctx)))); + VLOG(6) << *pt_kernel_signature_.get(); + + kernel_type_.reset( + new OpKernelType(std::move(InnerGetExpectedKernelType(exe_ctx)))); + dev_ctx = pool.Get(kernel_type_->place_); + + pt_kernel_name = pt_kernel_signature_->name; + pt_kernel_key = TransOpKernelTypeToPtenKernelKey(*kernel_type_.get()); + pt_kernel_.reset( + new pten::Kernel(pten::KernelFactory::Instance().SelectKernel( + pt_kernel_name, pt_kernel_key))); + + if (pt_kernel_->IsValid()) { + VLOG(6) << "Static mode ChoosePtenKernel - kernel name: " + << pt_kernel_name << " | kernel key: " << pt_kernel_key + << " | kernel: " << *pt_kernel_; + } else { + VLOG(6) << "Static mode ChoosePtenKernel - kernel `" << pt_kernel_name + << "` not found."; + } } if (pt_kernel_->IsValid()) { run_pten_kernel_ = true;
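
Note on the resulting pattern: after this series, a grad op's pten kernel arguments live in a standalone mapping function under paddle/pten/ops/compat/ instead of a GetExpectedPtenKernelArgs override on the operator class, and the runtime looks the mapping up by op type when choosing a kernel. A minimal sketch of wiring up a new grad op the same way (the op name "foo_grad" and its "axis" attribute are hypothetical placeholders, not part of these patches), modeled on digamma_sig.cc and elementwise_sig.cc above:

#include "paddle/pten/core/compat/op_utils.h"

namespace pten {

// Hypothetical example only: maps the fluid grad op's inputs/attrs/outputs
// onto the pten kernel signature "foo_grad".
KernelSignature FooGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
  if (ctx.IsDenseTensorInput("X")) {
    return KernelSignature("foo_grad",
                           {"X", GradVarName("Out")},
                           {"axis"},
                           {GradVarName("X")});
  }
  // Other input types (e.g. SelectedRows) are not mapped yet; fall back to
  // an unregistered signature, as elementwise_sig.cc does.
  return KernelSignature("unregistered", {}, {}, {});
}

}  // namespace pten

PT_REGISTER_ARG_MAPPING_FN(foo_grad, pten::FooGradOpArgumentMapping);

Here the IsDenseTensorInput check plays the role of the old ctx.InputVar("X")->IsType<framework::LoDTensor>() guard, so a different signature can be registered later without touching the operator class again.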