diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc
index 64f9dfdc0801a..a94ca438449fa 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc
@@ -132,6 +132,27 @@ bool CPUQuantizeSquashPass::IsDequantizeInputUint8(
   return false;
 }
 
+bool CPUQuantizeSquashPass::IsDequantizeQuantizeIncompatible(
+    Node* quant_op, Node* dequant_in, Node* next_op) const {
+  bool is_concat_signed =
+      quant_op->Op()->GetAttrIfExists<bool>("is_negative_input");
+  bool is_input_unsigned = IsDequantizeInputUint8(dequant_in);
+  /* TODO(sfraczek): remove elementwise from this condition once the
+     BinaryMKLDNN kernel supports two different input data types */
+  bool is_next_op_concat_or_elementwise =
+      next_op->Op()->Type() == "concat" ||
+      next_op->Op()->Type().find("elementwise") == 0;
+  if (is_next_op_concat_or_elementwise && is_concat_signed &&
+      is_input_unsigned) {
+    VLOG(4) << "Do not squash dequant-quant, because "
+            << "next_op is: " << next_op->Op()->Type()
+            << ", is_concat_signed: " << is_concat_signed
+            << ", is_input_unsigned: " << is_input_unsigned << ".";
+    return true;
+  }
+  return false;
+}
+
 void CPUQuantizeSquashPass::DequantQuantSquash(
     Graph* graph,
     std::unordered_map<const Node*, int>* nodes_keep_counter) const {
@@ -151,9 +172,7 @@ void CPUQuantizeSquashPass::DequantQuantSquash(
     GET_IR_NODE_FROM_SUBGRAPH(quant_out, quant_out, squash_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, squash_pattern);
 
-    // Don't squash if e.g. just one concat input is unsigned
-    if (IsDequantizeInputUint8(dequant_in) &&
-        !quant_op->Op()->GetAttrIfExists<bool>("is_negative_input")) {
+    if (IsDequantizeQuantizeIncompatible(quant_op, dequant_in, next_op)) {
       return;
     }
 
diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h
index d668c222a4ecd..7cf716c10e3d5 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h
@@ -48,6 +48,15 @@ class CPUQuantizeSquashPass : public FusePassBase {
    */
   bool IsDequantizeInputUint8(const Node* dequant_in) const;
 
+  /*
+   * Don't squash an unsigned dequantize with a signed quantize.
+   * This matters for concat and elementwise ops: when the inputs have
+   * different signs, concat assumes a signed type and elementwise
+   * assumes the type of its first input.
+   */
+  bool IsDequantizeQuantizeIncompatible(Node* quant_op, Node* dequant_in,
+                                        Node* next_op) const;
+
   /*
    * Squash dequantize-quantize ops pairs into requantize or nothing
    */
diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc
index 4077700b28f2e..e00bb84e35c09 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc
@@ -12,8 +12,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h" #include + +#include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h" #include "paddle/fluid/framework/naive_executor.h" #include "paddle/fluid/platform/place.h" @@ -234,11 +235,70 @@ ProgramDesc BuildConvMultiRequantProgramDesc(bool use_mkldnn, float scale_out, return prog; } +/* a->relu->b->Dequant->c(u8)->Quant->d-\ + * e->relu->f->Dequant->g(u8)->Quant->h--Concat1->x + * i->relu->j->Dequant->k(u8)->Quant->l-/ + */ +ProgramDesc BuildU8U8U8ConcatProgramDesc(float scale_out, float scale) { + ProgramDesc prog; + for (auto& v : variable_names) { + prog.MutableBlock(0)->Var(v); + } + SetOp(&prog, "relu", "Relu1", {"a"}, {"b"}, true, {scale, scale_out}); + SetOp(&prog, "relu", "Relu2", {"e"}, {"f"}, true, {scale, scale_out}); + SetOp(&prog, "relu", "Relu3", {"i"}, {"j"}, true, {scale, scale_out}); + + SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, true, + {scale, scale_out}); + SetOp(&prog, "dequantize", "Dequant2", {"f"}, {"g"}, true, + {scale, scale_out}); + SetOp(&prog, "dequantize", "Dequant3", {"j"}, {"k"}, true, + {scale, scale_out}); + + SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, true, {scale, scale_out}, + 0.0f, "float32", false, 1, false); // is_negative_input = false + SetOp(&prog, "quantize", "Quant2", {"g"}, {"h"}, true, {scale, scale_out}, + 0.0f, "float32", false, 1, false); // is_negative_input = false + SetOp(&prog, "quantize", "Quant3", {"k"}, {"l"}, true, {scale, scale_out}, + 0.0f, "float32", false, 1, false); // is_negative_input = false + + SetOp(&prog, "concat", "Concat1", {"d", "h", "l"}, {"x"}, true); + return prog; +} + +/* a->relu->b->Dequant->c(u8)->Quant->d-\ + * e->relu->f->Dequant->g(u8)->Quant->h--Concat1->x + * i->pool2d->j->Dequant->k(s8)->Quant->l-/ + */ +ProgramDesc BuildU8U8S8ConcatProgramDesc(float scale_out, float scale) { + ProgramDesc prog; + for (auto& v : variable_names) { + prog.MutableBlock(0)->Var(v); + } + SetOp(&prog, "relu", "Pool2d1", {"a"}, {"b"}, true, {scale, scale_out}); + SetOp(&prog, "relu", "Relu1", {"e"}, {"f"}, true, {scale, scale_out}); + SetOp(&prog, "pool2d", "Pool2d2", {"i"}, {"j"}, true, {scale, scale_out}); + + SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, true, + {scale, scale_out}); + SetOp(&prog, "dequantize", "Dequant2", {"f"}, {"g"}, true, + {scale, scale_out}); + SetOp(&prog, "dequantize", "Dequant3", {"j"}, {"k"}, true, + {scale, scale_out}); + + SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, true, {scale, scale_out}); + SetOp(&prog, "quantize", "Quant2", {"g"}, {"h"}, true, {scale, scale_out}); + SetOp(&prog, "quantize", "Quant3", {"k"}, {"l"}, true, {scale, scale_out}); + + SetOp(&prog, "concat", "Concat1", {"d", "h", "l"}, {"x"}, true); + return prog; +} + /* a->pool2d->b->Dequant->c(s8)->Quant->d-\ * e->relu->f->Dequant->g(u8)->Quant->h--Concat1->x * i->pool2d->j->Dequant->k(s8)->Quant->l-/ */ -ProgramDesc BuildConvS8U8S8ConcatProgramDesc(float scale_out, float scale) { +ProgramDesc BuildS8U8S8ConcatProgramDesc(float scale_out, float scale) { ProgramDesc prog; for (auto& v : variable_names) { prog.MutableBlock(0)->Var(v); @@ -255,8 +315,35 @@ ProgramDesc BuildConvS8U8S8ConcatProgramDesc(float scale_out, float scale) { {scale, scale_out}); SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, true, {scale, scale_out}); - SetOp(&prog, "quantize", "Quant2", {"g"}, {"h"}, true, {scale, scale_out}, - 0.0, "float32", false, 1, false); + SetOp(&prog, "quantize", "Quant2", {"g"}, {"h"}, true, {scale, scale_out}); + SetOp(&prog, 
"quantize", "Quant3", {"k"}, {"l"}, true, {scale, scale_out}); + + SetOp(&prog, "concat", "Concat1", {"d", "h", "l"}, {"x"}, true); + return prog; +} + +/* a->pool2d->b->Dequant->c(s8)->Quant->d-\ + * e->pool2d->f->Dequant->g(s8)->Quant->h--Concat1->x + * i->pool2d->j->Dequant->k(s8)->Quant->l-/ + */ +ProgramDesc BuildS8S8S8ConcatProgramDesc(float scale_out, float scale) { + ProgramDesc prog; + for (auto& v : variable_names) { + prog.MutableBlock(0)->Var(v); + } + SetOp(&prog, "pool2d", "Pool2d1", {"a"}, {"b"}, true, {scale, scale_out}); + SetOp(&prog, "pool2d", "Pool2d2", {"e"}, {"f"}, true, {scale, scale_out}); + SetOp(&prog, "pool2d", "Pool2d3", {"i"}, {"j"}, true, {scale, scale_out}); + + SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, true, + {scale, scale_out}); + SetOp(&prog, "dequantize", "Dequant2", {"f"}, {"g"}, true, + {scale, scale_out}); + SetOp(&prog, "dequantize", "Dequant3", {"j"}, {"k"}, true, + {scale, scale_out}); + + SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, true, {scale, scale_out}); + SetOp(&prog, "quantize", "Quant2", {"g"}, {"h"}, true, {scale, scale_out}); SetOp(&prog, "quantize", "Quant3", {"k"}, {"l"}, true, {scale, scale_out}); SetOp(&prog, "concat", "Concat1", {"d", "h", "l"}, {"x"}, true); @@ -834,7 +921,7 @@ TEST(CpuQuantizeSquashPass, quant_bf16_conv2d) { remove_nodes); } -TEST(CpuQuantizeSquashPass, dont_squash_u8_dequant_s8_quant_input_to_concat) { +TEST(CpuQuantizeSquashPass, dont_squash_u8_dequant_s8_quant_input_to_concat1) { // removed 2 x 4 (dequantize_op, dequantize_out, quantize, quantize_out) auto remove_nodes = 8; std::unordered_map expected_operators = {{"concat", 1}, @@ -842,8 +929,38 @@ TEST(CpuQuantizeSquashPass, dont_squash_u8_dequant_s8_quant_input_to_concat) { {"dequantize", 1}, {"relu", 1}, {"pool2d", 2}}; - CheckNodesTest(BuildConvS8U8S8ConcatProgramDesc(1.2f, 1.2f), - expected_operators, remove_nodes); + CheckNodesTest(BuildS8U8S8ConcatProgramDesc(1.2f, 1.2f), expected_operators, + remove_nodes); +} + +TEST(CpuQuantizeSquashPass, dont_squash_u8_dequant_s8_quant_input_to_concat2) { + // removed 1 x 4 (dequantize_op, dequantize_out, quantize, quantize_out) + auto remove_nodes = 4; + std::unordered_map expected_operators = {{"concat", 1}, + {"quantize", 2}, + {"dequantize", 2}, + {"relu", 2}, + {"pool2d", 1}}; + CheckNodesTest(BuildU8U8S8ConcatProgramDesc(1.2f, 1.2f), expected_operators, + remove_nodes); +} + +TEST(CpuQuantizeSquashPass, squash_all_s8_input_to_concat1) { + // removed 3 x 4 (dequantize_op, dequantize_out, quantize, quantize_out) + auto remove_nodes = 12; + std::unordered_map expected_operators = { + {"concat", 1}, {"quantize", 0}, {"dequantize", 0}, {"pool2d", 3}}; + CheckNodesTest(BuildS8S8S8ConcatProgramDesc(1.2f, 1.2f), expected_operators, + remove_nodes); +} + +TEST(CpuQuantizeSquashPass, squash_all_u8_input_to_concat2) { + // removed 3 x 4 (dequantize_op, dequantize_out, quantize, quantize_out) + auto remove_nodes = 12; + std::unordered_map expected_operators = { + {"concat", 1}, {"quantize", 0}, {"dequantize", 0}, {"relu", 3}}; + CheckNodesTest(BuildU8U8U8ConcatProgramDesc(1.2f, 1.2f), expected_operators, + remove_nodes); } } // namespace ir diff --git a/paddle/fluid/inference/api/mkldnn_quantizer.cc b/paddle/fluid/inference/api/mkldnn_quantizer.cc index 9d22e1b4b520c..ef9d03d1dcbaf 100644 --- a/paddle/fluid/inference/api/mkldnn_quantizer.cc +++ b/paddle/fluid/inference/api/mkldnn_quantizer.cc @@ -116,15 +116,11 @@ void AnalysisPredictor::MkldnnQuantizer::CalculateScalesForOpOutputs( // force unsigned type 
   // force unsigned type if already know it
   bool is_unsigned = false;
   bool compute_scale = true;
-  if (op->Type() == "conv2d") {
+  if (op->Type() == "conv2d" || op->Type() == "fc") {
     // output of conv2d with relu must be unsigned
     std::string fuse_activation =
         op->GetAttrIfExists<std::string>("fuse_activation");
     is_unsigned = (fuse_activation == "relu" || fuse_activation == "relu6");
-  } else if (op->Type() == "fc") {
-    std::string activation_type =
-        op->GetAttrIfExists<std::string>("activation_type");
-    is_unsigned = (activation_type == "relu" || activation_type == "relu6");
   } else if (op->Type() == "relu") {
     is_unsigned = true;
   } else if (op->Type() == "transpose2" || op->Type() == "reshape2" ||
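
For context, the rule the new IsDequantizeQuantizeIncompatible check encodes can be restated outside the pass as a small predicate: a u8 dequantize must not be squashed with an s8 quantize when the squashed pair would feed a concat op (or, until the oneDNN binary kernel handles mixed input types, an elementwise op). The sketch below is illustrative only and not part of the patch; the free function ShouldSkipDequantQuantSquash and its parameters are hypothetical names introduced here for clarity.

#include <string>

// Illustrative restatement of the squash-skip rule (hypothetical helper, not
// part of the patch): concat assumes a single signedness for all of its
// inputs and the oneDNN binary (elementwise) kernel assumes the type of its
// first input, so a u8 dequantize is not merged with an s8 quantize that
// feeds either of them.
bool ShouldSkipDequantQuantSquash(const std::string& next_op_type,
                                  bool quantize_is_signed,  // "is_negative_input"
                                  bool dequantize_input_is_u8) {
  bool next_is_concat_or_elementwise =
      next_op_type == "concat" ||
      next_op_type.rfind("elementwise", 0) == 0;  // starts with "elementwise"
  return next_is_concat_or_elementwise && quantize_is_signed &&
         dequantize_input_is_u8;
}

Under this rule, the S8U8S8 and U8U8S8 test graphs above keep their u8 dequantize / s8 quantize pairs intact, while the all-s8 and all-u8 graphs have every dequantize-quantize pair squashed away.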