From f54b5df7fea49dfb597043be2b78287d57392b2c Mon Sep 17 00:00:00 2001 From: zhoujianqian <15205085056@163.com> Date: Wed, 2 Mar 2022 11:21:03 +0000 Subject: [PATCH] move gather_tree infer shape --- paddle/fluid/operators/gather_tree_op.cc | 23 ++++++++--------------- paddle/phi/infermeta/binary.cc | 13 +++++++++++++ paddle/phi/infermeta/binary.h | 4 ++++ 3 files changed, 25 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/operators/gather_tree_op.cc b/paddle/fluid/operators/gather_tree_op.cc index 2868c3697eda1..7f6c82032fe39 100644 --- a/paddle/fluid/operators/gather_tree_op.cc +++ b/paddle/fluid/operators/gather_tree_op.cc @@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/binary.h" namespace paddle { namespace operators { @@ -21,20 +24,6 @@ class GatherTreeOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Ids"), "Input", "Ids", "GatherTree"); - OP_INOUT_CHECK(ctx->HasInput("Parents"), "Input", "Parents", "GatherTree"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "GatherTree"); - - auto ids_dims = ctx->GetInputDim("Ids"); - auto parents_dims = ctx->GetInputDim("Parents"); - PADDLE_ENFORCE_EQ(ids_dims == parents_dims, true, - platform::errors::InvalidArgument( - "The shape of Input(Parents) must be same with the " - "shape of Input(Ids).")); - ctx->SetOutputDim("Out", ids_dims); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -72,4 +61,8 @@ selected ids. } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(gather_tree, ops::GatherTreeOp, ops::GatherTreeOpMaker); +DELCARE_INFER_SHAPE_FUNCTOR(gather_tree, GatherTreeInferShapeFunctor, + PT_INFER_META(phi::GatherTreeMeta)); + +REGISTER_OPERATOR(gather_tree, ops::GatherTreeOp, ops::GatherTreeOpMaker, + GatherTreeInferShapeFunctor); diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 675e68af74339..7682f6b3d49b9 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -348,4 +348,17 @@ void BCELossInferMeta(const MetaTensor& input, out->share_lod(input); } +void GatherTreeMeta(const MetaTensor& ids, + const MetaTensor& parents, + MetaTensor* out) { + auto ids_dims = ids.dims(); + auto parents_dims = parents.dims(); + PADDLE_ENFORCE_EQ(ids_dims == parents_dims, + true, + phi::errors::InvalidArgument( + "The shape of Input(Parents) must be same with the " + "shape of Input(Ids).")); + out->set_dims(ids_dims); +} + } // namespace phi diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index a0140c9a5799f..5906e06b29355 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -68,4 +68,8 @@ void BCELossInferMeta(const MetaTensor& input, const MetaTensor& label, MetaTensor* out, MetaConfig config = MetaConfig()); + +void GatherTreeMeta(const MetaTensor& ids, + const MetaTensor& parents, + MetaTensor* out); } // namespace phi