Skip to content

Commit

Permalink
[bugfix] to concat input squash (#39593)
Browse files Browse the repository at this point in the history
* fix and add more tests

* remove unwanted changes

* check only concat and elementwise

* move check to a function

* add todo comment

* Revert "fix ptq fc attr name fuse_activation->activation_type"

This reverts commit ffd0233.
  • Loading branch information
sfraczek authored Feb 17, 2022
1 parent 2d2f11d commit f29da15
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 15 deletions.
25 changes: 22 additions & 3 deletions paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,27 @@ bool CPUQuantizeSquashPass::IsDequantizeInputUint8(
return false;
}

bool CPUQuantizeSquashPass::IsDequantizeQuantizeIncompatible(
Node* quant_op, Node* dequant_in, Node* next_op) const {
bool is_concat_signed =
quant_op->Op()->GetAttrIfExists<bool>("is_negative_input");
bool is_input_unsigned = IsDequantizeInputUint8(dequant_in);
/* TODO(sfraczek): remove elementwise from this condition when BinaryMKLDNN
kernel will support two different input data types */
bool is_next_op_concat_or_elementwise =
next_op->Op()->Type() == "concat" ||
next_op->Op()->Type().find("elementwise") == 0;
if (is_next_op_concat_or_elementwise && is_concat_signed &&
is_input_unsigned) {
VLOG(4) << "Do not squash dequant-quant, because "
<< "next_op is: " << next_op->Op()->Type()
<< ", is_concat_signed: " << is_concat_signed
<< ", is_input_unsigned: " << is_input_unsigned << ".";
return true;
}
return false;
}

void CPUQuantizeSquashPass::DequantQuantSquash(
Graph* graph,
std::unordered_map<const Node*, int>* nodes_keep_counter) const {
Expand All @@ -151,9 +172,7 @@ void CPUQuantizeSquashPass::DequantQuantSquash(
GET_IR_NODE_FROM_SUBGRAPH(quant_out, quant_out, squash_pattern);
GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, squash_pattern);

// Don't squash if e.g. just one concat input is unsigned
if (IsDequantizeInputUint8(dequant_in) &&
!quant_op->Op()->GetAttrIfExists<bool>("is_negative_input")) {
if (IsDequantizeQuantizeIncompatible(quant_op, dequant_in, next_op)) {
return;
}

Expand Down
9 changes: 9 additions & 0 deletions paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,15 @@ class CPUQuantizeSquashPass : public FusePassBase {
*/
bool IsDequantizeInputUint8(const Node* dequant_in) const;

/*
* Don't squash unsigned dequantize with signed quantize.
* This is important for concat and elementwise ops.
* When inputs have different sign, concat will assume signed type and
* elementwise assumes first input type.
*/
bool IsDequantizeQuantizeIncompatible(Node* quant_op, Node* dequant_in,
Node* next_op) const;

/*
* Squash dequantize-quantize ops pairs into requantize or nothing
*/
Expand Down
131 changes: 124 additions & 7 deletions paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h"
#include <gtest/gtest.h>

#include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h"
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/platform/place.h"

Expand Down Expand Up @@ -234,11 +235,70 @@ ProgramDesc BuildConvMultiRequantProgramDesc(bool use_mkldnn, float scale_out,
return prog;
}

/* a->relu->b->Dequant->c(u8)->Quant->d-\
* e->relu->f->Dequant->g(u8)->Quant->h--Concat1->x
* i->relu->j->Dequant->k(u8)->Quant->l-/
*/
ProgramDesc BuildU8U8U8ConcatProgramDesc(float scale_out, float scale) {
ProgramDesc prog;
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "relu", "Relu1", {"a"}, {"b"}, true, {scale, scale_out});
SetOp(&prog, "relu", "Relu2", {"e"}, {"f"}, true, {scale, scale_out});
SetOp(&prog, "relu", "Relu3", {"i"}, {"j"}, true, {scale, scale_out});

SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, true,
{scale, scale_out});
SetOp(&prog, "dequantize", "Dequant2", {"f"}, {"g"}, true,
{scale, scale_out});
SetOp(&prog, "dequantize", "Dequant3", {"j"}, {"k"}, true,
{scale, scale_out});

SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, true, {scale, scale_out},
0.0f, "float32", false, 1, false); // is_negative_input = false
SetOp(&prog, "quantize", "Quant2", {"g"}, {"h"}, true, {scale, scale_out},
0.0f, "float32", false, 1, false); // is_negative_input = false
SetOp(&prog, "quantize", "Quant3", {"k"}, {"l"}, true, {scale, scale_out},
0.0f, "float32", false, 1, false); // is_negative_input = false

SetOp(&prog, "concat", "Concat1", {"d", "h", "l"}, {"x"}, true);
return prog;
}

/* a->relu->b->Dequant->c(u8)->Quant->d-\
* e->relu->f->Dequant->g(u8)->Quant->h--Concat1->x
* i->pool2d->j->Dequant->k(s8)->Quant->l-/
*/
ProgramDesc BuildU8U8S8ConcatProgramDesc(float scale_out, float scale) {
ProgramDesc prog;
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "relu", "Pool2d1", {"a"}, {"b"}, true, {scale, scale_out});
SetOp(&prog, "relu", "Relu1", {"e"}, {"f"}, true, {scale, scale_out});
SetOp(&prog, "pool2d", "Pool2d2", {"i"}, {"j"}, true, {scale, scale_out});

SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, true,
{scale, scale_out});
SetOp(&prog, "dequantize", "Dequant2", {"f"}, {"g"}, true,
{scale, scale_out});
SetOp(&prog, "dequantize", "Dequant3", {"j"}, {"k"}, true,
{scale, scale_out});

SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, true, {scale, scale_out});
SetOp(&prog, "quantize", "Quant2", {"g"}, {"h"}, true, {scale, scale_out});
SetOp(&prog, "quantize", "Quant3", {"k"}, {"l"}, true, {scale, scale_out});

SetOp(&prog, "concat", "Concat1", {"d", "h", "l"}, {"x"}, true);
return prog;
}

/* a->pool2d->b->Dequant->c(s8)->Quant->d-\
* e->relu->f->Dequant->g(u8)->Quant->h--Concat1->x
* i->pool2d->j->Dequant->k(s8)->Quant->l-/
*/
ProgramDesc BuildConvS8U8S8ConcatProgramDesc(float scale_out, float scale) {
ProgramDesc BuildS8U8S8ConcatProgramDesc(float scale_out, float scale) {
ProgramDesc prog;
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
Expand All @@ -255,8 +315,35 @@ ProgramDesc BuildConvS8U8S8ConcatProgramDesc(float scale_out, float scale) {
{scale, scale_out});

SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, true, {scale, scale_out});
SetOp(&prog, "quantize", "Quant2", {"g"}, {"h"}, true, {scale, scale_out},
0.0, "float32", false, 1, false);
SetOp(&prog, "quantize", "Quant2", {"g"}, {"h"}, true, {scale, scale_out});
SetOp(&prog, "quantize", "Quant3", {"k"}, {"l"}, true, {scale, scale_out});

SetOp(&prog, "concat", "Concat1", {"d", "h", "l"}, {"x"}, true);
return prog;
}

/* a->pool2d->b->Dequant->c(s8)->Quant->d-\
* e->pool2d->f->Dequant->g(s8)->Quant->h--Concat1->x
* i->pool2d->j->Dequant->k(s8)->Quant->l-/
*/
ProgramDesc BuildS8S8S8ConcatProgramDesc(float scale_out, float scale) {
ProgramDesc prog;
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "pool2d", "Pool2d1", {"a"}, {"b"}, true, {scale, scale_out});
SetOp(&prog, "pool2d", "Pool2d2", {"e"}, {"f"}, true, {scale, scale_out});
SetOp(&prog, "pool2d", "Pool2d3", {"i"}, {"j"}, true, {scale, scale_out});

SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, true,
{scale, scale_out});
SetOp(&prog, "dequantize", "Dequant2", {"f"}, {"g"}, true,
{scale, scale_out});
SetOp(&prog, "dequantize", "Dequant3", {"j"}, {"k"}, true,
{scale, scale_out});

SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, true, {scale, scale_out});
SetOp(&prog, "quantize", "Quant2", {"g"}, {"h"}, true, {scale, scale_out});
SetOp(&prog, "quantize", "Quant3", {"k"}, {"l"}, true, {scale, scale_out});

SetOp(&prog, "concat", "Concat1", {"d", "h", "l"}, {"x"}, true);
Expand Down Expand Up @@ -834,16 +921,46 @@ TEST(CpuQuantizeSquashPass, quant_bf16_conv2d) {
remove_nodes);
}

TEST(CpuQuantizeSquashPass, dont_squash_u8_dequant_s8_quant_input_to_concat) {
TEST(CpuQuantizeSquashPass, dont_squash_u8_dequant_s8_quant_input_to_concat1) {
// removed 2 x 4 (dequantize_op, dequantize_out, quantize, quantize_out)
auto remove_nodes = 8;
std::unordered_map<std::string, int> expected_operators = {{"concat", 1},
{"quantize", 1},
{"dequantize", 1},
{"relu", 1},
{"pool2d", 2}};
CheckNodesTest(BuildConvS8U8S8ConcatProgramDesc(1.2f, 1.2f),
expected_operators, remove_nodes);
CheckNodesTest(BuildS8U8S8ConcatProgramDesc(1.2f, 1.2f), expected_operators,
remove_nodes);
}

TEST(CpuQuantizeSquashPass, dont_squash_u8_dequant_s8_quant_input_to_concat2) {
// removed 1 x 4 (dequantize_op, dequantize_out, quantize, quantize_out)
auto remove_nodes = 4;
std::unordered_map<std::string, int> expected_operators = {{"concat", 1},
{"quantize", 2},
{"dequantize", 2},
{"relu", 2},
{"pool2d", 1}};
CheckNodesTest(BuildU8U8S8ConcatProgramDesc(1.2f, 1.2f), expected_operators,
remove_nodes);
}

TEST(CpuQuantizeSquashPass, squash_all_s8_input_to_concat1) {
// removed 3 x 4 (dequantize_op, dequantize_out, quantize, quantize_out)
auto remove_nodes = 12;
std::unordered_map<std::string, int> expected_operators = {
{"concat", 1}, {"quantize", 0}, {"dequantize", 0}, {"pool2d", 3}};
CheckNodesTest(BuildS8S8S8ConcatProgramDesc(1.2f, 1.2f), expected_operators,
remove_nodes);
}

TEST(CpuQuantizeSquashPass, squash_all_u8_input_to_concat2) {
// removed 3 x 4 (dequantize_op, dequantize_out, quantize, quantize_out)
auto remove_nodes = 12;
std::unordered_map<std::string, int> expected_operators = {
{"concat", 1}, {"quantize", 0}, {"dequantize", 0}, {"relu", 3}};
CheckNodesTest(BuildU8U8U8ConcatProgramDesc(1.2f, 1.2f), expected_operators,
remove_nodes);
}

} // namespace ir
Expand Down
6 changes: 1 addition & 5 deletions paddle/fluid/inference/api/mkldnn_quantizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,15 +116,11 @@ void AnalysisPredictor::MkldnnQuantizer::CalculateScalesForOpOutputs(
// force unsigned type if already know it
bool is_unsigned = false;
bool compute_scale = true;
if (op->Type() == "conv2d") {
if (op->Type() == "conv2d" || op->Type() == "fc") {
// output of conv2d with relu must be unsigned
std::string fuse_activation =
op->GetAttrIfExists<std::string>("fuse_activation");
is_unsigned = (fuse_activation == "relu" || fuse_activation == "relu6");
} else if (op->Type() == "fc") {
std::string activation_type =
op->GetAttrIfExists<std::string>("activation_type");
is_unsigned = (activation_type == "relu" || activation_type == "relu6");
} else if (op->Type() == "relu") {
is_unsigned = true;
} else if (op->Type() == "transpose2" || op->Type() == "reshape2" ||
Expand Down

0 comments on commit f29da15

Please sign in to comment.