[Inference] Add the quant_linear_fuse_pass (#58637)
* add the quant_linear_fuse_pass
Wanglongzhi2001 authored Nov 22, 2023
1 parent 4b05d5a commit 6747af6
Showing 6 changed files with 783 additions and 0 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
@@ -103,6 +103,7 @@ pass_library(delete_quant_dequant_filter_op_pass inference)
pass_library(trt_delete_weight_dequant_linear_op_pass inference)
pass_library(delete_op_device_pass inference)
pass_library(delete_weight_dequant_linear_op_pass inference)
pass_library(quant_linear_fuse_pass inference)
pass_library(delete_quant_dequant_linear_op_pass inference)
pass_library(delete_assign_op_pass inference)
pass_library(delete_dropout_op_pass inference)
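Registering the pass with pass_library() makes it selectable by the inference pass pipeline under the name quant_linear_fuse_pass. As a minimal sketch (not part of this commit, and assuming the standard paddle_infer::Config / pass-builder C++ API), the pass could be opted into when building a predictor like this:

#include <string>

#include "paddle_inference_api.h"  // Paddle Inference C++ header; install path may differ

// Sketch only: builds a predictor with the newly registered fuse pass enabled.
void BuildPredictorWithQuantLinearFuse(const std::string& model_dir) {
  paddle_infer::Config config;
  config.SetModel(model_dir);    // directory holding the quantized inference model
  config.SwitchIrOptim(true);    // run IR optimization passes
  // The pass name must match the pass_library() registration above.
  config.pass_builder()->AppendPass("quant_linear_fuse_pass");
  auto predictor = paddle_infer::CreatePredictor(config);
  (void)predictor;  // predictor->Run() would then execute the optimized graph
}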
109 changes: 109 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -3331,6 +3331,115 @@ void patterns::DeleteWeightDequantLinearOpEncoderPattern::operator()() {
  any_op2->LinksFrom({weight_dequantize_linear_op_out});
}

PDNode *patterns::QuantLinearFusePattern::operator()(bool with_bias,
                                                     bool with_relu) {
  auto *quantize_linear_op_x = pattern->NewNode(quantize_linear_op_x_repr())
                                   ->AsInput()
                                   ->assert_is_op_input("quantize_linear", "X");

  auto *quantize_linear_op_scale =
      pattern->NewNode(quantize_linear_op_scale_repr())
          ->AsInput()
          ->assert_is_op_input("quantize_linear", "Scale")
          ->assert_is_persistable_var();

  auto *quantize_linear_op = pattern->NewNode(quantize_linear_op_repr())
                                 ->assert_is_op("quantize_linear");

  auto *quantize_linear_op_out =
      pattern->NewNode(quantize_linear_op_out_repr())
          ->AsIntermediate()
          ->assert_is_op_output("quantize_linear", "Y")
          ->assert_is_op_input("dequantize_linear", "X")
          ->assert_var_not_persistable();

  auto *dequantize_linear_op = pattern->NewNode(dequantize_linear_op_repr())
                                   ->assert_is_op("dequantize_linear");

  auto *dequantize_linear_op_out =
      pattern->NewNode(dequantize_linear_op_out_repr())
          ->AsIntermediate()
          ->assert_is_op_output("dequantize_linear", "Y")
          ->AsOutput();
  // Add links.
  quantize_linear_op
      ->LinksFrom({quantize_linear_op_x, quantize_linear_op_scale})
      .LinksTo({quantize_linear_op_out});
  dequantize_linear_op->LinksFrom({quantize_linear_op_out})
      .LinksTo({dequantize_linear_op_out});

  auto *weight_dequantize_linear_op_x =
      pattern->NewNode(weight_dequantize_linear_op_x_repr())
          ->AsInput()
          ->assert_is_op_input("dequantize_linear", "X")
          ->assert_is_persistable_var();

  auto *weight_dequantize_linear_op_scale =
      pattern->NewNode(weight_dequantize_linear_op_scale_repr())
          ->AsInput()
          ->assert_is_op_input("dequantize_linear", "Scale")
          ->assert_is_persistable_var();

  auto *weight_dequantize_linear_op =
      pattern->NewNode(weight_dequantize_linear_op_repr())
          ->assert_is_op("dequantize_linear");

  auto *weight_dequantize_linear_op_out =
      pattern->NewNode(weight_dequantize_linear_op_out_repr())
          ->AsIntermediate()
          ->assert_is_op_output("dequantize_linear", "Y")
          ->assert_is_op_input("matmul_v2", "Y");

  // Add links.
  weight_dequantize_linear_op
      ->LinksFrom(
          {weight_dequantize_linear_op_x, weight_dequantize_linear_op_scale})
      .LinksTo({weight_dequantize_linear_op_out});

  auto *mul = pattern->NewNode(mul_repr())->assert_is_op("matmul_v2");

  auto *mul_out =
      pattern->NewNode(mul_out_repr())->assert_is_op_output("matmul_v2");

  // Add links.
  mul->LinksFrom({dequantize_linear_op_out, weight_dequantize_linear_op_out})
      .LinksTo({mul_out});

  if (!with_bias) {  // not with bias
    return mul_out;
  } else {  // with bias
    mul_out->AsIntermediate()->assert_is_op_input("elementwise_add", "X");

    auto *elementwise_add = pattern->NewNode(elementwise_add_repr())
                                ->assert_is_op("elementwise_add");

    auto *bias = pattern->NewNode(bias_repr())
                     ->assert_is_op_input("elementwise_add", "Y")
                     ->assert_is_persistable_var();

    auto *elementwise_add_out =
        pattern->NewNode(elementwise_add_out_repr())
            ->AsOutput()
            ->assert_is_op_output("elementwise_add", "Out");

    elementwise_add->LinksFrom({mul_out, bias}).LinksTo({elementwise_add_out});

    if (!with_relu) {
      return elementwise_add_out;
    } else {
      elementwise_add_out->AsIntermediate()->assert_is_op_input("relu");
      // Create operators.
      auto *relu = pattern->NewNode(relu_repr())->assert_is_op("relu");
      auto *relu_out = pattern->NewNode(relu_out_repr())
                           ->AsOutput()
                           ->assert_is_op_output("relu", "Out");

      relu->LinksFrom({elementwise_add_out}).LinksTo({relu_out});
      return relu_out;
    }
  }
}

void patterns::DeleteWeightDequantLinearOpDecoderPattern::operator()() {
  auto weight_dequantize_linear_op_x =
      pattern->NewNode(weight_dequantize_linear_op_x_repr())
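For context, a PatternBase subclass like QuantLinearFusePattern is normally driven from a fuse pass via GraphPatternDetector and a subgraph handler. The outline below is illustrative only (the actual quant_linear_fuse_pass.cc is among the changed files not shown on this page); the function name and handler body are assumptions, and the snippet assumes the surrounding paddle::framework::ir namespace and headers:

// Illustrative outline, not the committed pass code.
static int ApplyQuantLinearFusePattern(Graph* graph,
                                       const std::string& name_scope,
                                       bool with_bias,
                                       bool with_relu) {
  GraphPatternDetector gpd;
  patterns::QuantLinearFusePattern pattern(gpd.mutable_pattern(), name_scope);
  pattern(with_bias, with_relu);  // build the pattern defined above

  int found_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    // Retrieve matched nodes by the names declared with PATTERN_DECL_NODE.
    GET_IR_NODE_FROM_SUBGRAPH(mul, mul, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, pattern);
    // ... dequantize the weight, create the fused quant_linear op,
    //     relink inputs/outputs, and remove the matched nodes ...
    ++found_count;
  };
  gpd(graph, handler);
  return found_count;
}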
30 changes: 30 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -1841,6 +1841,36 @@ struct DeleteWeightDequantLinearOpEncoderPattern : public PatternBase {
  PATTERN_DECL_NODE(any_op2);
};

struct QuantLinearFusePattern : public PatternBase {
  QuantLinearFusePattern(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "quant_linear_fuse_pattern") {}

  PDNode* operator()(bool with_bias, bool with_relu);

  PATTERN_DECL_NODE(quantize_linear_op_x);
  PATTERN_DECL_NODE(quantize_linear_op_scale);
  PATTERN_DECL_NODE(quantize_linear_op);
  PATTERN_DECL_NODE(quantize_linear_op_out);

  PATTERN_DECL_NODE(dequantize_linear_op);
  PATTERN_DECL_NODE(dequantize_linear_op_out);

  PATTERN_DECL_NODE(weight_dequantize_linear_op_x);
  PATTERN_DECL_NODE(weight_dequantize_linear_op_scale);
  PATTERN_DECL_NODE(weight_dequantize_linear_op);
  PATTERN_DECL_NODE(weight_dequantize_linear_op_out);

  PATTERN_DECL_NODE(mul);
  PATTERN_DECL_NODE(mul_out);

  PATTERN_DECL_NODE(bias);
  PATTERN_DECL_NODE(elementwise_add);
  PATTERN_DECL_NODE(elementwise_add_out);

  PATTERN_DECL_NODE(relu);
  PATTERN_DECL_NODE(relu_out);
};

struct DeleteWeightDequantLinearOpDecoderPattern : public PatternBase {
  DeleteWeightDequantLinearOpDecoderPattern(PDPattern* pattern,
                                            const std::string& name_scope)
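The header change above only declares the pattern; the pass that consumes it lives in quant_linear_fuse_pass.h/.cc, which are among the changed files not shown on this page. A hypothetical outline of such a pass class, with every name here being an assumption rather than the committed code, would derive from FusePassBase:

// Hypothetical outline; the real class is in the files not shown here.
class QuantLinearFusePass : public FusePassBase {
 public:
  QuantLinearFusePass() = default;
  virtual ~QuantLinearFusePass() = default;

 protected:
  // Entry point invoked by the pass framework; would run the
  // QuantLinearFusePattern variants (with/without bias and relu) and
  // rewrite each matched subgraph into a fused quant_linear op.
  void ApplyImpl(ir::Graph* graph) const override;
};
// The .cc would typically end with a REGISTER_PASS(quant_linear_fuse_pass, ...)
// registration matching the name used in CMakeLists.txt.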