From 53d19f5b1e985f288cdf8b963ab05b9a06c546c3 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Thu, 8 Mar 2018 22:30:46 +0800 Subject: [PATCH] Add ElementwiseOpInferVarType --- paddle/fluid/operators/elementwise_add_op.cc | 7 +++++-- paddle/fluid/operators/elementwise_op.h | 10 ++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/elementwise_add_op.cc b/paddle/fluid/operators/elementwise_add_op.cc index e9068fcd50ba9..4aab54f60236e 100644 --- a/paddle/fluid/operators/elementwise_add_op.cc +++ b/paddle/fluid/operators/elementwise_add_op.cc @@ -29,8 +29,11 @@ class ElementwiseAddOpMaker : public ElementwiseOpMaker { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(elementwise_add, ops::ElementwiseOp, ops::ElementwiseAddOpMaker, - elementwise_add_grad, ops::ElementwiseOpGrad); +REGISTER_OPERATOR(elementwise_add, ops::ElementwiseOp, + ops::ElementwiseAddOpMaker, ops::ElementwiseOpInferVarType, + paddle::framework::DefaultGradOpDescMaker); +REGISTER_OPERATOR(elementwise_add_grad, ops::ElementwiseOpGrad); + REGISTER_OP_CPU_KERNEL( elementwise_add, ops::ElementwiseAddKernel, diff --git a/paddle/fluid/operators/elementwise_op.h b/paddle/fluid/operators/elementwise_op.h index fe31bbaed44fc..f04d8d8fd82ed 100644 --- a/paddle/fluid/operators/elementwise_op.h +++ b/paddle/fluid/operators/elementwise_op.h @@ -41,6 +41,16 @@ class ElementwiseOp : public framework::OperatorWithKernel { } }; +class ElementwiseOpInferVarType : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc& op_desc, + framework::BlockDesc* block) const override { + auto x_var = op_desc.Input("X")[0]; + auto out_var = op_desc.Output("Out")[0]; + block->Var(out_var)->SetType(block->Var(x_var)->GetType()); + } +}; + class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker { public: ElementwiseOpMaker(OpProto* proto, OpAttrChecker* op_checker)