Quantize elementwise mul (PaddlePaddle#40546)
* Quantize elementwise mul op

* Parametrize elementwise functions

* Fix code formatting
Zuza authored and liqitong-a committed Mar 17, 2022
1 parent 6715e09 commit 0b70359
Showing 7 changed files with 163 additions and 128 deletions.
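In short: the pattern and quantization code that previously hard-coded elementwise_add is parametrized by an operator type string, and the quantization pass now runs it for both add and mul. The two call sites added to cpu_quantize_pass.cc below summarize the change:

    QuantizeElementwise(graph, "elementwise_add");
    QuantizeElementwise(graph, "elementwise_mul");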
21 changes: 11 additions & 10 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -2052,18 +2052,19 @@ PDNode *patterns::Pool::operator()() {
return output_var;
}

PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var, PDNode *y_var) {
auto elementwise_add_op = pattern->NewNode(elementwise_add_op_repr())
->assert_is_op("elementwise_add");

x_var->AsInput()->assert_is_op_input("elementwise_add", "X");
y_var->AsInput()->assert_is_op_input("elementwise_add", "Y");
auto out_var = pattern->NewNode(elementwise_add_out_repr())
PDNode *patterns::Elementwise::operator()(PDNode *x_var, PDNode *y_var,
const std::string elementwise_type) {
auto elementwise_op =
pattern->NewNode(elementwise_op_repr())->assert_is_op(elementwise_type);

x_var->AsInput()->assert_is_op_input(elementwise_type, "X");
y_var->AsInput()->assert_is_op_input(elementwise_type, "Y");
auto out_var = pattern->NewNode(elementwise_out_repr())
->AsOutput()
->assert_is_op_output("elementwise_add", "Out");
->assert_is_op_output(elementwise_type, "Out");

elementwise_add_op->LinksFrom({x_var, y_var});
elementwise_add_op->LinksTo({out_var});
elementwise_op->LinksFrom({x_var, y_var});
elementwise_op->LinksTo({out_var});

return out_var;
}
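For reference, a minimal sketch of how the parametrized pattern is built for elementwise_mul, mirroring the call sites in cpu_quantize_pass.cc further down (the name-scope string here is illustrative only):

    GraphPatternDetector gpd;
    auto pattern = gpd.mutable_pattern();
    // Match X, Y -> elementwise_mul -> Out anywhere in the graph.
    patterns::Elementwise elementwise_pattern{pattern, "quantize_elementwise"};
    elementwise_pattern(
        pattern->NewNode(elementwise_pattern.elementwise_x_repr()),
        pattern->NewNode(elementwise_pattern.elementwise_y_repr()),
        "elementwise_mul");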
28 changes: 14 additions & 14 deletions paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -1016,20 +1016,20 @@ struct Pool : public PatternBase {
PATTERN_DECL_NODE(pool_output);
};

// ElementwiseAdd used in residual connections.
// y_var is used and convolution output.
// The operator is removed, when residual
// connection fusion is on.
struct ElementwiseAdd : public PatternBase {
ElementwiseAdd(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "elementwise_add") {}

PDNode* operator()(PDNode* x_var, PDNode* y_var);

PATTERN_DECL_NODE(elementwise_add_op);
PATTERN_DECL_NODE(elementwise_add_x);
PATTERN_DECL_NODE(elementwise_add_y);
PATTERN_DECL_NODE(elementwise_add_out);
// Elementwise ops
// Forward pass for element-wise operators (add, mul)
// elementwise_out is the result of the operator
struct Elementwise : public PatternBase {
Elementwise(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "elementwise") {}

PDNode* operator()(PDNode* x_var, PDNode* y_var,
const std::string elementwise_type);

PATTERN_DECL_NODE(elementwise_op);
PATTERN_DECL_NODE(elementwise_x);
PATTERN_DECL_NODE(elementwise_y);
PATTERN_DECL_NODE(elementwise_out);
};

// Transpose op
paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc
@@ -145,10 +145,10 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
patterns::Conv conv_pattern{pattern, name_scope};
auto conv_output = conv_pattern();

patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope};
elementwise_add_pattern(
conv_output,
pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr()));
patterns::Elementwise elementwise_pattern{pattern, name_scope};
elementwise_pattern(
conv_output, pattern->NewNode(elementwise_pattern.elementwise_y_repr()),
"elementwise_add");
conv_output->AsIntermediate();

int found_conv_as_x_count = 0;
@@ -160,16 +160,16 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);

GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_identity, elementwise_add_y,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_op, elementwise_op,
elementwise_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_identity, elementwise_y,
elementwise_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out,
elementwise_pattern);

if (FindFuseOption(*conv_op, *elementwise_add_op) != FUSE_MKLDNN) return;
if (FindFuseOption(*conv_op, *elementwise_op) != FUSE_MKLDNN) return;

if (!IsReachable(g, elementwise_add_identity, conv_output)) return;
if (!IsReachable(g, elementwise_identity, conv_output)) return;

if (HasFusedActivation(conv_op)) return;

@@ -179,14 +179,14 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
return;
}

conv_op->Op()->SetInput("ResidualData", {elementwise_add_identity->Name()});
conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()});
conv_op->Op()->SetInput("ResidualData", {elementwise_identity->Name()});
conv_op->Op()->SetOutput("Output", {elementwise_out->Name()});
conv_op->Op()->SetAttr("fuse_residual_connection", true);

GraphSafeRemoveNodes(g, {conv_output, elementwise_add_op});
GraphSafeRemoveNodes(g, {conv_output, elementwise_op});

IR_NODE_LINK_TO(elementwise_add_identity, conv_op);
IR_NODE_LINK_TO(conv_op, elementwise_add_out);
IR_NODE_LINK_TO(elementwise_identity, conv_op);
IR_NODE_LINK_TO(conv_op, elementwise_out);

found_conv_as_x_count++;
};
@@ -212,10 +212,10 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
patterns::Conv conv_pattern{pattern, name_scope};
auto conv_output = conv_pattern();

patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope};
elementwise_add_pattern(
pattern->NewNode(elementwise_add_pattern.elementwise_add_x_repr()),
conv_output);
patterns::Elementwise elementwise_pattern{pattern, name_scope};
elementwise_pattern(
pattern->NewNode(elementwise_pattern.elementwise_x_repr()), conv_output,
"elementwise_add");
conv_output->AsIntermediate();

int found_conv_as_y_count = 0;
@@ -227,16 +227,16 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);

GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_op, elementwise_op,
elementwise_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_x, elementwise_x,
elementwise_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out,
elementwise_pattern);

if (FindFuseOption(*conv_op, *elementwise_add_op) != FUSE_MKLDNN) return;
if (FindFuseOption(*conv_op, *elementwise_op) != FUSE_MKLDNN) return;

if (!IsReachable(g, elementwise_add_x, conv_output)) return;
if (!IsReachable(g, elementwise_x, conv_output)) return;

if (HasFusedActivation(conv_op)) return;

@@ -246,14 +246,14 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
return;
}

conv_op->Op()->SetInput("ResidualData", {elementwise_add_x->Name()});
conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()});
conv_op->Op()->SetInput("ResidualData", {elementwise_x->Name()});
conv_op->Op()->SetOutput("Output", {elementwise_out->Name()});
conv_op->Op()->SetAttr("fuse_residual_connection", true);

GraphSafeRemoveNodes(g, {conv_output, elementwise_add_op});
GraphSafeRemoveNodes(g, {conv_output, elementwise_op});

IR_NODE_LINK_TO(elementwise_add_x, conv_op);
IR_NODE_LINK_TO(conv_op, elementwise_add_out);
IR_NODE_LINK_TO(elementwise_x, conv_op);
IR_NODE_LINK_TO(conv_op, elementwise_out);

found_conv_as_y_count++;
};
@@ -282,8 +282,8 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
patterns::Conv conv_y_pattern{pattern, name_scope};
auto conv_y_output = conv_y_pattern();

patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope};
elementwise_add_pattern(conv_x_output, conv_y_output);
patterns::Elementwise elementwise_pattern{pattern, name_scope};
elementwise_pattern(conv_x_output, conv_y_output, "elementwise_add");
conv_x_output->AsIntermediate();
conv_y_output->AsIntermediate();

@@ -301,19 +301,19 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
GET_IR_NODE_FROM_SUBGRAPH(conv_y_filter, conv_filter, conv_y_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_y_output, conv_output, conv_y_pattern);

GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_op, elementwise_op,
elementwise_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out,
elementwise_pattern);

if (!IsCompat(subgraph, g)) {
LOG(WARNING)
<< "conv_elementwise_add_mkldnn_fuse_pass in op compat failed.";
return;
}

if (FindFuseOption(*conv_x_op, *elementwise_add_op) != FUSE_MKLDNN) return;
if (FindFuseOption(*conv_y_op, *elementwise_add_op) != FUSE_MKLDNN) return;
if (FindFuseOption(*conv_x_op, *elementwise_op) != FUSE_MKLDNN) return;
if (FindFuseOption(*conv_y_op, *elementwise_op) != FUSE_MKLDNN) return;

Node* projection_node;
Node* residual_conv_op;
@@ -333,14 +333,14 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
if (HasFusedActivation(residual_conv_op)) return;

residual_conv_op->Op()->SetInput("ResidualData", {projection_node->Name()});
residual_conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()});
residual_conv_op->Op()->SetOutput("Output", {elementwise_out->Name()});

residual_conv_op->Op()->SetAttr("fuse_residual_connection", true);

GraphSafeRemoveNodes(g, {residual_conv_output, elementwise_add_op});
GraphSafeRemoveNodes(g, {residual_conv_output, elementwise_op});

IR_NODE_LINK_TO(projection_node, residual_conv_op);
IR_NODE_LINK_TO(residual_conv_op, elementwise_add_out);
IR_NODE_LINK_TO(residual_conv_op, elementwise_out);

found_projection_conv_count++;
};
71 changes: 36 additions & 35 deletions paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc
@@ -807,74 +807,74 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const {
PrettyLogDetail("--- quantized %d matmul ops", quantize_matmul_count);
}

void CPUQuantizePass::QuantizeElementwiseAdd(Graph* graph) const {
void CPUQuantizePass::QuantizeElementwise(
Graph* graph, const std::string elementwise_type) const {
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope_};
patterns::Elementwise elementwise_pattern{pattern, name_scope_};

elementwise_add_pattern(
pattern->NewNode(elementwise_add_pattern.elementwise_add_x_repr()),
pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr()));
elementwise_pattern(
pattern->NewNode(elementwise_pattern.elementwise_x_repr()),
pattern->NewNode(elementwise_pattern.elementwise_y_repr()),
elementwise_type);

int quantize_elementwise_add_count = 0;
int quantize_elementwise_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "Quantize elementwise_add op";
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
VLOG(4) << "Quantize " + elementwise_type + " op";
GET_IR_NODE_FROM_SUBGRAPH(elementwise_op, elementwise_op,
elementwise_pattern);

// skip if should not be quantized
if (!platform::HasOpINT8DataType(elementwise_add_op->Op())) {
LogQuantizationDisabled(elementwise_add_op);
if (!platform::HasOpINT8DataType(elementwise_op->Op())) {
LogQuantizationDisabled(elementwise_op);
return;
}

GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_y, elementwise_add_y,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_x, elementwise_x,
elementwise_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_y, elementwise_y,
elementwise_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out,
elementwise_pattern);

if (!AreScalesPresentForNodes(
{elementwise_add_x, elementwise_add_y, elementwise_add_out})) {
LogCannotQuantizeOp(elementwise_add_op,
{elementwise_x, elementwise_y, elementwise_out})) {
LogCannotQuantizeOp(elementwise_op,
"No scale available for the operator");
return;
}

bool is_x_unsigned{false}, is_y_unsigned{false};
auto input_x_scale =
GetScaleValueForNode(elementwise_add_x, &is_x_unsigned);
auto input_y_scale =
GetScaleValueForNode(elementwise_add_y, &is_y_unsigned);
auto input_x_scale = GetScaleValueForNode(elementwise_x, &is_x_unsigned);
auto input_y_scale = GetScaleValueForNode(elementwise_y, &is_y_unsigned);

// TODO(sfraczek): add support for different signness
if (is_x_unsigned != is_y_unsigned) {
LogCannotQuantizeOp(elementwise_add_op,
"ElementwiseAdd inputs must be of the same type.");
LogCannotQuantizeOp(elementwise_op,
"Elementwise inputs must be of the same type.");
return;
}

QuantizeInput(g, elementwise_add_op, elementwise_add_x, "X", input_x_scale,
QuantizeInput(g, elementwise_op, elementwise_x, "X", input_x_scale,
is_x_unsigned, "Scale_x");
QuantizeInput(g, elementwise_add_op, elementwise_add_y, "Y", input_y_scale,
QuantizeInput(g, elementwise_op, elementwise_y, "Y", input_y_scale,
is_y_unsigned, "Scale_y");

bool is_output_unsigned{false};
auto output_scale =
GetScaleValueForNode(elementwise_add_out, &is_output_unsigned);
GetScaleValueForNode(elementwise_out, &is_output_unsigned);

DequantizeOutput(g, elementwise_add_op, elementwise_add_out, "Out",
output_scale, is_output_unsigned, "Scale_out");
DequantizeOutput(g, elementwise_op, elementwise_out, "Out", output_scale,
is_output_unsigned, "Scale_out");

++quantize_elementwise_add_count;
++quantize_elementwise_count;
};
gpd(graph, handler);
AddStatis(quantize_elementwise_add_count);
AddStatis(quantize_elementwise_count);

PrettyLogDetail("--- quantized %d elementwise_add ops",
quantize_elementwise_add_count);
PrettyLogDetail("--- quantized %d %s ops", quantize_elementwise_count,
elementwise_type);
}
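For orientation, a rough sketch (not part of the diff, node names illustrative) of what QuantizeElementwise leaves behind for an INT8-marked elementwise op: QuantizeInput prepends a quantize op per input and records its scale on the elementwise op, and DequantizeOutput appends the matching dequantize.

    // Illustrative only:
    // X (fp32) -> quantize(Scale_x) -> X (s8/u8) \
    //                                             elementwise_{add,mul} -> Out -> dequantize(Scale_out) -> Out (fp32)
    // Y (fp32) -> quantize(Scale_y) -> Y (s8/u8) /
    //
    // Scale_x, Scale_y and Scale_out are also stored as attributes on the
    // elementwise op, which the oneDNN INT8 kernel uses for rescaling.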

void CPUQuantizePass::QuantizeFusionGru(Graph* graph) const {
@@ -1146,7 +1146,8 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
QuantizeFc(graph);
QuantizeReshape(graph);
QuantizeMatmul(graph);
QuantizeElementwiseAdd(graph);
QuantizeElementwise(graph, "elementwise_add");
QuantizeElementwise(graph, "elementwise_mul");
QuantizeFusionGru(graph);
QuantizeMultiGru(graph);
QuantizeFusionLSTM(graph);
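A practical consequence of the parametrization: quantizing a further elementwise op should, presumably, only need one more call in ApplyImpl (plus calibration scales and an INT8 oneDNN kernel for that op), along the lines of:

    QuantizeElementwise(graph, "elementwise_sub");  // hypothetical, not part of this commit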
3 changes: 2 additions & 1 deletion paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h
@@ -57,7 +57,8 @@ class CPUQuantizePass : public FusePassBase {
void QuantizeTranspose(Graph* graph) const;
void QuantizeReshape(Graph* graph) const;
void QuantizeMatmul(Graph* graph) const;
void QuantizeElementwiseAdd(Graph* graph) const;
void QuantizeElementwise(Graph* graph,
const std::string elementwise_type) const;
void QuantizeFusionGru(Graph* graph) const;
void QuantizeMultiGru(Graph* graph) const;
void QuantizeFusionLSTM(Graph* graph) const;