From 9225bd8fe3fb367fa04fc9edaa2e4c7ddc515c0c Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Mon, 29 Apr 2024 13:12:22 +0800 Subject: [PATCH 1/3] [CINN]Fix GatherOpPattern in pd_to_cinn_pass --- .../operator/transforms/pd_to_cinn_pass.cc | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc index 8ef16fb55057f..962279f922d94 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc @@ -941,9 +941,12 @@ class SigmoidOpPattern : public pir::OpRewritePattern { public: using pir::OpRewritePattern::OpRewritePattern; + bool Match(paddle::dialect::SigmoidOp op) const override { + return !CompatibleInfo::IsDeniedForCinn(*op.operation()); + } - bool MatchAndRewrite(paddle::dialect::SigmoidOp op, - pir::PatternRewriter &rewriter) const override { + void Rewrite(paddle::dialect::SigmoidOp op, + pir::PatternRewriter &rewriter) const override { auto input_dtype = paddle::dialect::TransToPhiDataType( op->operand_source(0) .type() @@ -987,8 +990,15 @@ class GatherOpPattern public: using pir::OpRewritePattern::OpRewritePattern; - bool MatchAndRewrite(paddle::dialect::GatherOp op, - pir::PatternRewriter &rewriter) const override { + bool Match(paddle::dialect::GatherOp op) const override { + const bool is_denied = CompatibleInfo::IsDeniedForCinn(*op.operation()); + auto axis_gen_op = op->operand_source(2).defining_op(); + auto full_op = axis_gen_op->dyn_cast(); + return !is_denied && full_op; + } + + void Rewrite(paddle::dialect::GatherOp op, + pir::PatternRewriter &rewriter) const override { auto gather_op = op->dyn_cast(); auto x = op.operand_source(0); auto index = op->operand_source(1); From 52cd5ba33cce4730788000da233187f86b9ad574 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Mon, 29 Apr 2024 13:56:56 +0800 Subject: [PATCH 2/3] fix return --- .../cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc | 4 ---- 1 file changed, 4 deletions(-) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc index 962279f922d94..f2ec386247e1c 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc @@ -979,10 +979,7 @@ class SigmoidOpPattern } rewriter.ReplaceAllUsesWith(op.result(0), div); - rewriter.EraseOp(op); - - return true; } }; class GatherOpPattern @@ -1025,7 +1022,6 @@ class GatherOpPattern rewriter.Build(x, index, axis)->result(0); rewriter.ReplaceAllUsesWith(op->result(0), out); rewriter.EraseOp(op); - return true; } }; From 1a13325003e35bf76b333b95d6c0f5c40e290859 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Mon, 29 Apr 2024 14:25:56 +0800 Subject: [PATCH 3/3] fix attribute --- .../operator/transforms/pd_to_cinn_pass.cc | 26 +++++++------------ 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc index f2ec386247e1c..f11d66e1911f8 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc @@ -1000,23 +1000,15 @@ class GatherOpPattern auto x = op.operand_source(0); auto index = op->operand_source(1); const int axis = [&]() -> int { - int axis = 0; - if (gather_op->attributes().count("index")) { - axis = - gather_op.attribute("index").dyn_cast().data(); - } else { - auto axis_gen_op = op.operand_source(2).defining_op(); - PADDLE_ENFORCE_EQ(axis_gen_op->isa(), - true, - ::phi::errors::InvalidArgument( - "Not Supported: The gather operator for CINN " - "only supports constant value")); - auto full_op = axis_gen_op->dyn_cast(); - axis = static_cast(full_op.attribute("value") - .dyn_cast<::pir::FloatAttribute>() - .data()); - return axis; - } + auto axis_gen_op = op.operand_source(2).defining_op(); + PADDLE_ENFORCE_EQ(axis_gen_op->isa(), + true, + ::phi::errors::InvalidArgument( + "Not Supported: The gather operator for CINN " + "only supports constant value")); + auto full_op = axis_gen_op->dyn_cast(); + return static_cast( + full_op.attribute("value").dyn_cast<::pir::FloatAttribute>().data()); }(); auto out = rewriter.Build(x, index, axis)->result(0);