Skip to content

Commit

Permalink
Add ElementwiseOpInferVarType
Browse files Browse the repository at this point in the history
  • Loading branch information
chengduoZH committed Mar 8, 2018
1 parent 0d49b92 commit 53d19f5
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
7 changes: 5 additions & 2 deletions paddle/fluid/operators/elementwise_add_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,11 @@ class ElementwiseAddOpMaker : public ElementwiseOpMaker {
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(elementwise_add, ops::ElementwiseOp, ops::ElementwiseAddOpMaker,
elementwise_add_grad, ops::ElementwiseOpGrad);
REGISTER_OPERATOR(elementwise_add, ops::ElementwiseOp,
ops::ElementwiseAddOpMaker, ops::ElementwiseOpInferVarType,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(elementwise_add_grad, ops::ElementwiseOpGrad);

REGISTER_OP_CPU_KERNEL(
elementwise_add,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, float>,
Expand Down
10 changes: 10 additions & 0 deletions paddle/fluid/operators/elementwise_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ class ElementwiseOp : public framework::OperatorWithKernel {
}
};

class ElementwiseOpInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto x_var = op_desc.Input("X")[0];
auto out_var = op_desc.Output("Out")[0];
block->Var(out_var)->SetType(block->Var(x_var)->GetType());
}
};

class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ElementwiseOpMaker(OpProto* proto, OpAttrChecker* op_checker)
Expand Down

0 comments on commit 53d19f5

Please sign in to comment.