diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index 409fb3b713d3b..6077a2b92881f 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -782,6 +782,67 @@ bool ConcatOpInferSymbolicShape(pir::Operation *op, return true; } +bool EditDistanceOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + const auto &hyps_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + const auto &refs_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(1)); + + const auto &hyps_dims = hyps_shape_or_data.shape(); + const auto &refs_dims = refs_shape_or_data.shape(); + + PADDLE_ENFORCE_EQ( + hyps_dims.size(), + 2, + phi::errors::InvalidArgument( + "Input(Hyps) must be a 2-D Tensor, but received rank %u.", + hyps_dims.size())); + PADDLE_ENFORCE_EQ( + refs_dims.size(), + 2, + phi::errors::InvalidArgument( + "Input(Refs) must be a 2-D Tensor, but received rank %u.", + refs_dims.size())); + + infer_context->AddEqualCstr(hyps_dims[0], refs_dims[0]); + + bool has_lengths = op->operand_source(2) && op->operand_source(3); + if (has_lengths) { + const auto &hypslength_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(2)); + const auto &refslength_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(3)); + + infer_context->AddEqualCstr(hypslength_shape_or_data.shape()[0], + hyps_dims[0]); + + infer_context->AddEqualCstr(refslength_shape_or_data.shape()[0], + refs_dims[0]); + + infer_context->AddEqualCstr(hypslength_shape_or_data.shape()[0], + refslength_shape_or_data.shape()[0]); + } else { + symbol::DimExpr one = symbol::DimExpr(1); + + infer_context->AddEqualCstr(hyps_dims[1], one); + infer_context->AddEqualCstr(refs_dims[1], one); + } + + symbol::ShapeOrDataDimExprs refs_shape_or_data_exprs( + symbol::TensorShapeOrDataDimExprs( + std::vector{refs_dims})); + infer_context->SetShapeOrDataForValue(op->result(0), + refs_shape_or_data_exprs); + + symbol::ShapeOrDataDimExprs single_dim_expr(symbol::TensorShapeOrDataDimExprs( + std::vector{symbol::DimExpr(1)})); + infer_context->SetShapeOrDataForValue( + op->result(1), symbol::ShapeOrDataDimExprs({single_dim_expr})); + + return true; +} + bool FakeQuantizeMovingAverageAbsMaxOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const symbol::ShapeOrDataDimExprs &x_shape = diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h index 650b374b8cec1..7a88120d00bb3 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h @@ -41,6 +41,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(FakeQuantizeMovingAverageAbsMax) OP_DECLARE_INFER_SYMBOLIC_SHAPE(FakeQuantizeMovingAverageAbsMax_) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(CoalesceTensor) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(CoalesceTensor_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(EditDistance) OP_DECLARE_INFER_SYMBOLIC_SHAPE(FullWithTensor) OP_DECLARE_INFER_SYMBOLIC_SHAPE(FlashAttn) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(FlashAttnUnpadded) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index cca48987ca11d..07d108c6b4ad4 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -1467,6 +1467,7 @@ func : edit_distance data_type : DataType::FLOAT32 optional : hypslength, refslength + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : eig args: (Tensor x)