diff --git a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc index 8ef16fb55057fd..f11d66e1911f8c 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc @@ -941,9 +941,12 @@ class SigmoidOpPattern : public pir::OpRewritePattern { public: using pir::OpRewritePattern::OpRewritePattern; + bool Match(paddle::dialect::SigmoidOp op) const override { + return !CompatibleInfo::IsDeniedForCinn(*op.operation()); + } - bool MatchAndRewrite(paddle::dialect::SigmoidOp op, - pir::PatternRewriter &rewriter) const override { + void Rewrite(paddle::dialect::SigmoidOp op, + pir::PatternRewriter &rewriter) const override { auto input_dtype = paddle::dialect::TransToPhiDataType( op->operand_source(0) .type() @@ -976,10 +979,7 @@ class SigmoidOpPattern } rewriter.ReplaceAllUsesWith(op.result(0), div); - rewriter.EraseOp(op); - - return true; } }; class GatherOpPattern @@ -987,35 +987,33 @@ class GatherOpPattern public: using pir::OpRewritePattern::OpRewritePattern; - bool MatchAndRewrite(paddle::dialect::GatherOp op, - pir::PatternRewriter &rewriter) const override { + bool Match(paddle::dialect::GatherOp op) const override { + const bool is_denied = CompatibleInfo::IsDeniedForCinn(*op.operation()); + auto axis_gen_op = op->operand_source(2).defining_op(); + auto full_op = axis_gen_op->dyn_cast(); + return !is_denied && full_op; + } + + void Rewrite(paddle::dialect::GatherOp op, + pir::PatternRewriter &rewriter) const override { auto gather_op = op->dyn_cast(); auto x = op.operand_source(0); auto index = op->operand_source(1); const int axis = [&]() -> int { - int axis = 0; - if (gather_op->attributes().count("index")) { - axis = - gather_op.attribute("index").dyn_cast().data(); - } else { - auto axis_gen_op = op.operand_source(2).defining_op(); - PADDLE_ENFORCE_EQ(axis_gen_op->isa(), - true, - ::phi::errors::InvalidArgument( - "Not Supported: The gather operator for CINN " - "only supports constant value")); - auto full_op = axis_gen_op->dyn_cast(); - axis = static_cast(full_op.attribute("value") - .dyn_cast<::pir::FloatAttribute>() - .data()); - return axis; - } + auto axis_gen_op = op.operand_source(2).defining_op(); + PADDLE_ENFORCE_EQ(axis_gen_op->isa(), + true, + ::phi::errors::InvalidArgument( + "Not Supported: The gather operator for CINN " + "only supports constant value")); + auto full_op = axis_gen_op->dyn_cast(); + return static_cast( + full_op.attribute("value").dyn_cast<::pir::FloatAttribute>().data()); }(); auto out = rewriter.Build(x, index, axis)->result(0); rewriter.ReplaceAllUsesWith(op->result(0), out); rewriter.EraseOp(op); - return true; } };