From b49394624e086b45bdb7d8b18d8da35e7de2fea7 Mon Sep 17 00:00:00 2001 From: uanu2002 Date: Fri, 26 Jul 2024 15:20:59 +0800 Subject: [PATCH 1/2] add numel --- .../infer_symbolic_shape/unary_infer_sym.cc | 14 ++++++++++++++ .../infer_symbolic_shape/unary_infer_sym.h | 1 + paddle/phi/ops/yaml/ops.yaml | 1 + 3 files changed, 16 insertions(+) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 79512fb69ea95..9f8ae78761f43 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -599,6 +599,20 @@ bool NonzeroOpInferSymbolicShape( return true; } +bool NumelOpInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { + const auto &input_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + + std::vector out_shape = {}; + infer_context->SetShapeOrDataForValue( + op->result(0), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(out_shape)}); + + return true; +} + bool PadOpInferSymbolicShape(pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { // input(0): Tensor x diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h index cf3e4926c964c..32f85b3cadc90 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h @@ -47,6 +47,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logsumexp) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Max) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Min) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nonzero) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Numel) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pad) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pad3d) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Prod) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 769c5bb3a61df..f5bc89932e5f7 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -3342,6 +3342,7 @@ data_transform: skip_transform : x no_need_buffer : x + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : one_hot args : (Tensor x, Scalar(int) num_classes) From 6e26a069f8de7a8d40877e8b8ba39d5a28b0620a Mon Sep 17 00:00:00 2001 From: uanu2002 Date: Tue, 30 Jul 2024 17:37:58 +0800 Subject: [PATCH 2/2] remove unused var --- .../operator/interface/infer_symbolic_shape/unary_infer_sym.cc | 3 --- 1 file changed, 3 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 9f8ae78761f43..f1dc55b5a1faf 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -601,9 +601,6 @@ bool NonzeroOpInferSymbolicShape( bool NumelOpInferSymbolicShape(pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { - const auto &input_shape_or_data = - infer_context->GetShapeOrDataForValue(op->operand_source(0)); - std::vector out_shape = {}; infer_context->SetShapeOrDataForValue( op->result(0),