diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 79512fb69ea95..f1dc55b5a1faf 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -599,6 +599,17 @@ bool NonzeroOpInferSymbolicShape( return true; } +bool NumelOpInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { + std::vector out_shape = {}; + infer_context->SetShapeOrDataForValue( + op->result(0), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(out_shape)}); + + return true; +} + bool PadOpInferSymbolicShape(pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { // input(0): Tensor x diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h index cf3e4926c964c..32f85b3cadc90 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h @@ -47,6 +47,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logsumexp) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Max) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Min) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nonzero) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Numel) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pad) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pad3d) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Prod) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 769c5bb3a61df..f5bc89932e5f7 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -3342,6 +3342,7 @@ data_transform: skip_transform : x no_need_buffer : x + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : one_hot args : (Tensor x, Scalar(int) num_classes)