From 34b9557e4d5a1589793e16ee19b842f6da373525 Mon Sep 17 00:00:00 2001 From: MufanColin <76479709+MufanColin@users.noreply.github.com> Date: Wed, 31 Jul 2024 16:49:11 +0800 Subject: [PATCH 1/5] Fixed cinn edit_distance --- .../multiary_infer_sym.cc | 69 +++++++++++++++++++ .../infer_symbolic_shape/multiary_infer_sym.h | 1 + paddle/phi/ops/yaml/ops.yaml | 1 + 3 files changed, 71 insertions(+) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index e399ead76e7e8..0ecfd59e11b99 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -637,6 +637,75 @@ bool ConcatOpInferSymbolicShape(pir::Operation *op, return true; } +bool EditDistanceOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + const auto &hyps_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + const auto &refs_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(1)); + + const auto &hyps_dims = hyps_shape_or_data.shape(); + const auto &refs_dims = refs_shape_or_data.shape(); + + PADDLE_ENFORCE_EQ( + hyps_dims.size(), + 2, + phi::errors::InvalidArgument( + "Input(Hyps) must be a 2-D Tensor, but received rank %u.", + hyps_dims.size())); + PADDLE_ENFORCE_EQ( + refs_dims.size(), + 2, + phi::errors::InvalidArgument( + "Input(Refs) must be a 2-D Tensor, but received rank %u.", + refs_dims.size())); + PADDLE_ENFORCE_EQ( + hyps_dims[0], + refs_dims[0], + phi::errors::InvalidArgument("The first dimension of Input(Hyps) and " + "Input(Refs) must be the same, " + "but received: %d vs %d.", + hyps_dims[0], + refs_dims[0])); + + bool has_lengths = op->operand_source(2) && op->operand_source(3); + if (has_lengths) { + const auto &hypslength_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(2)); + const auto &refslength_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(3)); + + PADDLE_ENFORCE_EQ( + hypslength_shape_or_data.shape()[0], + hyps_dims[0], + phi::errors::InvalidArgument("Input(HypsLength) and Input(Hyps) should " + "have identical first dimension.")); + PADDLE_ENFORCE_EQ( + refslength_shape_or_data.shape()[0], + refs_dims[0], + phi::errors::InvalidArgument("Input(RefsLength) and Input(Refs) should " + "have identical first dimension.")); + } else { + PADDLE_ENFORCE_EQ( + hyps_dims[1], + 1, + phi::errors::InvalidArgument("Input(Hyps) must be a 2-D Tensor with " + "the 2nd dimension equal to 1.")); + PADDLE_ENFORCE_EQ( + refs_dims[1], + 1, + phi::errors::InvalidArgument("Input(Refs) must be a 2-D Tensor with " + "the 2nd dimension equal to 1.")); + } + + infer_context->SetShapeOrDataForValue(op->result(0), refs_shape_or_data); + symbol::ShapeOrData single_dim_expr({symbol::DimExpr(1)}); + infer_context->SetShapeOrDataForValue( + op->result(1), symbol::ShapeOrDataDimExprs({single_dim_expr})); + + return true; +} + bool FullWithTensorOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { pir::Value operand_source = op->operand_source(1); diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h index 2d0eab8d15c74..7c625d82d66d6 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h @@ -29,6 +29,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(BilinearInterp) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Concat) OP_DECLARE_INFER_SYMBOLIC_SHAPE(CrossEntropyWithSoftmax) OP_DECLARE_INFER_SYMBOLIC_SHAPE(CrossEntropyWithSoftmax_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(EditDistance) OP_DECLARE_INFER_SYMBOLIC_SHAPE(FullWithTensor) OP_DECLARE_INFER_SYMBOLIC_SHAPE(FlashAttn) OP_DECLARE_INFER_SYMBOLIC_SHAPE(GroupNorm) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index a3a602a1169c3..25419642a8ad3 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -1446,6 +1446,7 @@ func : edit_distance data_type : DataType::FLOAT32 optional : hypslength, refslength + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : eig args: (Tensor x) From 365b88964df467cafef6ae4f139de4d1aeb1889c Mon Sep 17 00:00:00 2001 From: MufanColin <76479709+MufanColin@users.noreply.github.com> Date: Fri, 2 Aug 2024 11:10:29 +0800 Subject: [PATCH 2/5] Try to resolve review suggestions --- .../infer_symbolic_shape/multiary_infer_sym.cc | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index 0ecfd59e11b99..a2f979686a3b0 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -668,6 +668,8 @@ bool EditDistanceOpInferSymbolicShape( hyps_dims[0], refs_dims[0])); + infer_context->AddEqualCstr(hyps_dims[0], refs_dims[0]); + bool has_lengths = op->operand_source(2) && op->operand_source(3); if (has_lengths) { const auto &hypslength_shape_or_data = @@ -680,22 +682,37 @@ bool EditDistanceOpInferSymbolicShape( hyps_dims[0], phi::errors::InvalidArgument("Input(HypsLength) and Input(Hyps) should " "have identical first dimension.")); + + infer_context->AddEqualCstr(hypslength_shape_or_data.shape()[0], + hyps_dims[0]); + PADDLE_ENFORCE_EQ( refslength_shape_or_data.shape()[0], refs_dims[0], phi::errors::InvalidArgument("Input(RefsLength) and Input(Refs) should " "have identical first dimension.")); + + infer_context->AddEqualCstr(refslength_shape_or_data.shape()[0], + refs_dims[0]); + } else { PADDLE_ENFORCE_EQ( hyps_dims[1], 1, phi::errors::InvalidArgument("Input(Hyps) must be a 2-D Tensor with " "the 2nd dimension equal to 1.")); + + pir::DimExpr one_dim_expr = pir::DimExpr::CreateConstant(1); + infer_context->AddEqualCstr(hyps_dims[1], one_dim_expr); + PADDLE_ENFORCE_EQ( refs_dims[1], 1, phi::errors::InvalidArgument("Input(Refs) must be a 2-D Tensor with " "the 2nd dimension equal to 1.")); + + pir::DimExpr one_dim_expr = pir::DimExpr::CreateConstant(1); + infer_context->AddEqualCstr(hyps_dims[1], one_dim_expr); } infer_context->SetShapeOrDataForValue(op->result(0), refs_shape_or_data); From 9349e618718a67279d3bc7bf16d074c41ae7a549 Mon Sep 17 00:00:00 2001 From: MufanColin <76479709+MufanColin@users.noreply.github.com> Date: Fri, 2 Aug 2024 11:16:53 +0800 Subject: [PATCH 3/5] Fixed some typos --- .../interface/infer_symbolic_shape/multiary_infer_sym.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index a2f979686a3b0..cde429e1e089a 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -712,7 +712,7 @@ bool EditDistanceOpInferSymbolicShape( "the 2nd dimension equal to 1.")); pir::DimExpr one_dim_expr = pir::DimExpr::CreateConstant(1); - infer_context->AddEqualCstr(hyps_dims[1], one_dim_expr); + infer_context->AddEqualCstr(refs_dims[1], one_dim_expr); } infer_context->SetShapeOrDataForValue(op->result(0), refs_shape_or_data); From 15464c418efe51213a369f2acb8f9907f06a9f7b Mon Sep 17 00:00:00 2001 From: MufanColin <76479709+MufanColin@users.noreply.github.com> Date: Fri, 2 Aug 2024 16:14:32 +0800 Subject: [PATCH 4/5] Removed unnecessary PADDLE_ENFORCE_EQ --- .../multiary_infer_sym.cc | 32 ------------------- 1 file changed, 32 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index cde429e1e089a..82938386c842d 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -659,14 +659,6 @@ bool EditDistanceOpInferSymbolicShape( phi::errors::InvalidArgument( "Input(Refs) must be a 2-D Tensor, but received rank %u.", refs_dims.size())); - PADDLE_ENFORCE_EQ( - hyps_dims[0], - refs_dims[0], - phi::errors::InvalidArgument("The first dimension of Input(Hyps) and " - "Input(Refs) must be the same, " - "but received: %d vs %d.", - hyps_dims[0], - refs_dims[0])); infer_context->AddEqualCstr(hyps_dims[0], refs_dims[0]); @@ -677,40 +669,16 @@ bool EditDistanceOpInferSymbolicShape( const auto &refslength_shape_or_data = infer_context->GetShapeOrDataForValue(op->operand_source(3)); - PADDLE_ENFORCE_EQ( - hypslength_shape_or_data.shape()[0], - hyps_dims[0], - phi::errors::InvalidArgument("Input(HypsLength) and Input(Hyps) should " - "have identical first dimension.")); - infer_context->AddEqualCstr(hypslength_shape_or_data.shape()[0], hyps_dims[0]); - PADDLE_ENFORCE_EQ( - refslength_shape_or_data.shape()[0], - refs_dims[0], - phi::errors::InvalidArgument("Input(RefsLength) and Input(Refs) should " - "have identical first dimension.")); - infer_context->AddEqualCstr(refslength_shape_or_data.shape()[0], refs_dims[0]); } else { - PADDLE_ENFORCE_EQ( - hyps_dims[1], - 1, - phi::errors::InvalidArgument("Input(Hyps) must be a 2-D Tensor with " - "the 2nd dimension equal to 1.")); - pir::DimExpr one_dim_expr = pir::DimExpr::CreateConstant(1); infer_context->AddEqualCstr(hyps_dims[1], one_dim_expr); - PADDLE_ENFORCE_EQ( - refs_dims[1], - 1, - phi::errors::InvalidArgument("Input(Refs) must be a 2-D Tensor with " - "the 2nd dimension equal to 1.")); - pir::DimExpr one_dim_expr = pir::DimExpr::CreateConstant(1); infer_context->AddEqualCstr(refs_dims[1], one_dim_expr); } From 3d0311426c6bb3863214144f4db9cc552bd71b33 Mon Sep 17 00:00:00 2001 From: MufanColin <76479709+MufanColin@users.noreply.github.com> Date: Tue, 6 Aug 2024 13:56:13 +0800 Subject: [PATCH 5/5] Replace with symbol --- .../interface/infer_symbolic_shape/multiary_infer_sym.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index 82938386c842d..783810d0f82d5 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -676,10 +676,10 @@ bool EditDistanceOpInferSymbolicShape( refs_dims[0]); } else { - pir::DimExpr one_dim_expr = pir::DimExpr::CreateConstant(1); + symbol::DimExpr one_dim_expr = symbol::DimExpr::CreateConstant(1); infer_context->AddEqualCstr(hyps_dims[1], one_dim_expr); - pir::DimExpr one_dim_expr = pir::DimExpr::CreateConstant(1); + symbol::DimExpr one_dim_expr = symbol::DimExpr::CreateConstant(1); infer_context->AddEqualCstr(refs_dims[1], one_dim_expr); }