diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 44fb34ba1a5e6..8a44b8bade20f 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -660,6 +660,17 @@ bool NonzeroOpInferSymbolicShape( return true; } +bool NumelOpInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { + std::vector out_shape = {}; + infer_context->SetShapeOrDataForValue( + op->result(0), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(out_shape)}); + + return true; +} + bool PadOpInferSymbolicShape(pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { // input(0): Tensor x diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h index bfb00bbef9b90..2488b579929b2 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h @@ -51,6 +51,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Max) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Maxout) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Min) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nonzero) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Numel) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pad) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pad3d) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Prod) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index a3a602a1169c3..983635dfb09a8 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -3348,6 +3348,7 @@ data_transform: skip_transform : x no_need_buffer : x + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : one_hot args : (Tensor x, Scalar(int) num_classes)